Implement lower/upper triangular solve for CSC matrices
This commit is contained in:
parent
c988ebb4e7
commit
3b1303d1e0
@ -4,7 +4,7 @@ mod impl_std_ops;
|
||||
pub mod serial;
|
||||
|
||||
/// TODO
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum Op<T> {
|
||||
/// TODO
|
||||
NoOp(T),
|
||||
|
@ -2,7 +2,7 @@ use crate::csc::CscMatrix;
|
||||
use crate::ops::Op;
|
||||
use crate::ops::serial::cs::{spmm_cs_prealloc, spmm_cs_dense, spadd_cs_prealloc};
|
||||
use crate::ops::serial::OperationError;
|
||||
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice};
|
||||
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice, RealField};
|
||||
use num_traits::{Zero, One};
|
||||
|
||||
use std::borrow::Cow;
|
||||
@ -89,4 +89,158 @@ pub fn spmm_csc_prealloc<T>(
|
||||
spmm_csc_prealloc(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// TODO
|
||||
#[non_exhaustive]
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum SolveErrorKind {
|
||||
/// TODO
|
||||
Singular,
|
||||
}
|
||||
|
||||
/// TODO
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct SolveError {
|
||||
kind: SolveErrorKind,
|
||||
message: String
|
||||
}
|
||||
|
||||
impl SolveError {
|
||||
fn from_type_and_message(kind: SolveErrorKind, message: String) -> Self {
|
||||
Self {
|
||||
kind,
|
||||
message
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Solve the lower triangular system `op(L) X = B`.
|
||||
///
|
||||
/// Only the lower triangular part of L is read, and the result is stored in B.
|
||||
///
|
||||
/// ## Panics
|
||||
///
|
||||
/// Panics if `L` is not square, or if `L` and `B` are not dimensionally compatible.
|
||||
pub fn spsolve_csc_lower_triangular<'a, T: RealField>(
|
||||
l: Op<&CscMatrix<T>>,
|
||||
b: impl Into<DMatrixSliceMut<'a, T>>)
|
||||
-> Result<(), SolveError>
|
||||
{
|
||||
let b = b.into();
|
||||
let l_matrix = l.into_inner();
|
||||
assert_eq!(l_matrix.nrows(), l_matrix.ncols(), "Matrix must be square for triangular solve.");
|
||||
assert_eq!(l_matrix.nrows(), b.nrows(), "Dimension mismatch in sparse lower triangular solver.");
|
||||
match l {
|
||||
Op::NoOp(a) => spsolve_csc_lower_triangular_no_transpose(a, b),
|
||||
Op::Transpose(a) => spsolve_csc_lower_triangular_transpose(a, b),
|
||||
}
|
||||
}
|
||||
|
||||
fn spsolve_csc_lower_triangular_no_transpose<'a, T: RealField>(
|
||||
l: &CscMatrix<T>,
|
||||
b: DMatrixSliceMut<'a, T>)
|
||||
-> Result<(), SolveError>
|
||||
{
|
||||
let mut x = b;
|
||||
|
||||
// Solve column-by-column
|
||||
for j in 0 .. x.ncols() {
|
||||
let mut x_col_j = x.column_mut(j);
|
||||
|
||||
for k in 0 .. l.ncols() {
|
||||
let l_col_k = l.col(k);
|
||||
|
||||
// Skip entries above the diagonal
|
||||
// TODO: Can use exponential search here to quickly skip entries
|
||||
// (we'd like to avoid using binary search as it's very cache unfriendly
|
||||
// and the matrix might actually *be* lower triangular, which would induce
|
||||
// a severe penalty)
|
||||
let diag_csc_index = l_col_k.row_indices().iter().position(|&i| i == k);
|
||||
if let Some(diag_csc_index) = diag_csc_index {
|
||||
let l_kk = l_col_k.values()[diag_csc_index];
|
||||
|
||||
if l_kk != T::zero() {
|
||||
// Update entry associated with diagonal
|
||||
x_col_j[k] /= l_kk;
|
||||
// Copy value after updating (so we don't run into the borrow checker)
|
||||
let x_kj = x_col_j[k];
|
||||
|
||||
let row_indices = &l_col_k.row_indices()[(diag_csc_index + 1) ..];
|
||||
let l_values = &l_col_k.values()[(diag_csc_index + 1) ..];
|
||||
|
||||
// Note: The remaining entries are below the diagonal
|
||||
for (&i, l_ik) in row_indices.iter().zip(l_values) {
|
||||
let x_ij = &mut x_col_j[i];
|
||||
*x_ij -= l_ik.inlined_clone() * x_kj;
|
||||
}
|
||||
|
||||
x_col_j[k] = x_kj;
|
||||
} else {
|
||||
return spsolve_encountered_zero_diagonal();
|
||||
}
|
||||
} else {
|
||||
return spsolve_encountered_zero_diagonal();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn spsolve_encountered_zero_diagonal() -> Result<(), SolveError> {
|
||||
let message = "Matrix contains at least one diagonal entry that is zero.";
|
||||
Err(SolveError::from_type_and_message(SolveErrorKind::Singular, String::from(message)))
|
||||
}
|
||||
|
||||
fn spsolve_csc_lower_triangular_transpose<'a, T: RealField>(
|
||||
l: &CscMatrix<T>,
|
||||
b: DMatrixSliceMut<'a, T>)
|
||||
-> Result<(), SolveError>
|
||||
{
|
||||
let mut x = b;
|
||||
|
||||
// Solve column-by-column
|
||||
for j in 0 .. x.ncols() {
|
||||
let mut x_col_j = x.column_mut(j);
|
||||
|
||||
// Due to the transposition, we're essentially solving an upper triangular system,
|
||||
// and the columns in our matrix become rows
|
||||
|
||||
for i in (0 .. l.ncols()).rev() {
|
||||
let l_col_i = l.col(i);
|
||||
|
||||
// Skip entries above the diagonal
|
||||
// TODO: Can use exponential search here to quickly skip entries
|
||||
let diag_csc_index = l_col_i.row_indices().iter().position(|&k| i == k);
|
||||
if let Some(diag_csc_index) = diag_csc_index {
|
||||
let l_ii = l_col_i.values()[diag_csc_index];
|
||||
|
||||
if l_ii != T::zero() {
|
||||
// // Update entry associated with diagonal
|
||||
// x_col_j[k] /= a_kk;
|
||||
|
||||
// Copy value after updating (so we don't run into the borrow checker)
|
||||
let mut x_ii = x_col_j[i];
|
||||
|
||||
let row_indices = &l_col_i.row_indices()[(diag_csc_index + 1) ..];
|
||||
let a_values = &l_col_i.values()[(diag_csc_index + 1) ..];
|
||||
|
||||
// Note: The remaining entries are below the diagonal
|
||||
for (&k, &l_ki) in row_indices.iter().zip(a_values) {
|
||||
let x_kj = x_col_j[k];
|
||||
x_ii -= l_ki * x_kj;
|
||||
}
|
||||
|
||||
x_col_j[i] = x_ii / l_ii;
|
||||
} else {
|
||||
return spsolve_encountered_zero_diagonal();
|
||||
}
|
||||
} else {
|
||||
return spsolve_encountered_zero_diagonal();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
@ -10,3 +10,5 @@ cc dbaef9886eaad28be7cd48326b857f039d695bc0b19e9ada3304e812e984d2c3 # shrinks to
|
||||
cc 99e312beb498ffa79194f41501ea312dce1911878eba131282904ac97205aaa9 # shrinks to SpmmCsrDenseArgs { c, beta, alpha, trans_a, a, trans_b, b } = SpmmCsrDenseArgs { c: Matrix { data: VecStorage { data: [-1, 4, -1, -4, 2, 1, 4, -2, 1, 3, -2, 5], nrows: Dynamic { value: 2 }, ncols: Dynamic { value: 6 } } }, beta: 0, alpha: 0, trans_a: Transpose, a: CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 1, 1, 1, 1, 1, 1], minor_indices: [0], minor_dim: 2 }, values: [0] }, trans_b: Transpose, b: Matrix { data: VecStorage { data: [-1, 1, 0, -5, 4, -5, 2, 2, 4, -4, -3, -1, 1, -1, 0, 1, -3, 4, -5, 0, 1, -5, 0, 1, 1, -3, 5, 3, 5, -3, -5, 3, -1, -4, -4, -3], nrows: Dynamic { value: 6 }, ncols: Dynamic { value: 6 } } } }
|
||||
cc bf74259df2db6eda24eb42098e57ea1c604bb67d6d0023fa308c321027b53a43 # shrinks to (alpha, beta, c, a, b, trans_a, trans_b) = (0, 0, Matrix { data: VecStorage { data: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], nrows: Dynamic { value: 4 }, ncols: Dynamic { value: 5 } } }, CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 3, 6, 9, 12], minor_indices: [0, 1, 3, 1, 2, 3, 0, 1, 2, 1, 2, 3], minor_dim: 4 }, values: [-3, 3, -3, 1, -3, 0, 2, 1, 3, 0, -4, -1] }, Matrix { data: VecStorage { data: [3, 1, 4, -5, 5, -2, -5, -1, 1, -1, 3, -3, -2, 4, 2, -1, -1, 3, -5, 5], nrows: Dynamic { value: 4 }, ncols: Dynamic { value: 5 } } }, NoTranspose, NoTranspose)
|
||||
cc cbd6dac45a2f610e10cf4c15d4614cdbf7dfedbfcd733e4cc65c2e79829d14b3 # shrinks to SpmmCsrArgs { c, beta, alpha, trans_a, a, trans_b, b } = SpmmCsrArgs { c: CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 0, 1, 1, 1, 1], minor_indices: [0], minor_dim: 1 }, values: [0] }, beta: 0, alpha: 1, trans_a: Transpose(true), a: CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 0, 0, 1, 1, 1], minor_indices: [1], minor_dim: 5 }, values: [-1] }, trans_b: Transpose(true), b: CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 2], minor_indices: [2, 4], minor_dim: 5 }, values: [-1, 0] } }
|
||||
cc 8af78e2e41087743c8696c4d5563d59464f284662ccf85efc81ac56747d528bb # shrinks to (a, b) = (CscMatrix { cs: CsMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 6, 12, 18, 24, 30, 33], minor_indices: [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 1, 2, 5], minor_dim: 6 }, values: [0.4566433975117654, -0.5109683327713039, 0.0, -3.276901622678194, 0.0, -2.2065487385437095, 0.0, -0.42643054427847016, -2.9232369281581234, 0.0, 1.2913925579441763, 0.0, -1.4073766622090917, -4.795473113569459, 4.681765156869446, -0.821162215887913, 3.0315816068414794, -3.3986924718213407, -3.498903007282241, -3.1488953408335236, 3.458104636152161, -4.774694888508124, 2.603884664757498, 0.0, 0.0, -3.2650988857765535, 4.26699442646613, 0.0, -0.012223422086023561, 3.6899095325779285, -1.4264458042247958, 0.0, 3.4849193883471266] } }, Matrix { data: VecStorage { data: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.9513896933988457, -4.426942420881461, 0.0, 0.0, 0.0, -0.28264084049240257], nrows: Dynamic { value: 6 }, ncols: Dynamic { value: 2 } } })
|
||||
cc a4effd988fe352146fca365875e108ecf4f7d41f6ad54683e923ca6ce712e5d0 # shrinks to (a, b) = (CscMatrix { cs: CsMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 5, 11, 17, 22, 27, 31], minor_indices: [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 3, 4, 5, 1, 2, 3, 4, 5, 0, 1, 3, 5], minor_dim: 6 }, values: [-2.24935510943371, -2.2288203680206227, 0.0, -1.029740125494273, 0.0, 0.0, 0.22632926934348507, -0.9123245943877407, 0.0, 3.8564332876991827, 0.0, 0.0, 0.0, -0.8235065737081717, 1.9337984046721566, 0.11003468246027737, -3.422112890579867, -3.7824068893569196, 0.0, -0.021700572247226546, -4.914783069982362, 0.6227245544506541, 0.0, 0.0, -4.411368879922364, -0.00013623178651567258, -2.613658177661417, -2.2783292441548637, 0.0, 1.351859435890189, -0.021345159183605134] } }, Matrix { data: VecStorage { data: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -4.519417607973404, 0.0, 0.0, 0.0, -0.21238483334481817], nrows: Dynamic { value: 6 }, ncols: Dynamic { value: 3 } } })
|
||||
|
@ -1,8 +1,8 @@
|
||||
use crate::common::{csc_strategy, csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
|
||||
PROPTEST_I32_VALUE_STRATEGY, non_zero_i32_value_strategy};
|
||||
use crate::common::{csc_strategy, csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ, PROPTEST_I32_VALUE_STRATEGY, non_zero_i32_value_strategy, value_strategy};
|
||||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spmm_csc_dense, spadd_pattern, spmm_pattern,
|
||||
spadd_csr_prealloc, spadd_csc_prealloc,
|
||||
spmm_csr_prealloc, spmm_csc_prealloc};
|
||||
spmm_csr_prealloc, spmm_csc_prealloc,
|
||||
spsolve_csc_lower_triangular};
|
||||
use nalgebra_sparse::ops::{Op};
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::csc::CscMatrix;
|
||||
@ -10,10 +10,12 @@ use nalgebra_sparse::proptest::{csc, csr, sparsity_pattern};
|
||||
use nalgebra_sparse::pattern::SparsityPattern;
|
||||
|
||||
use nalgebra::{DMatrix, Scalar, DMatrixSliceMut, DMatrixSlice};
|
||||
use nalgebra::proptest::matrix;
|
||||
use nalgebra::proptest::{matrix, vector};
|
||||
|
||||
use proptest::prelude::*;
|
||||
|
||||
use matrixcompare::prop_assert_matrix_eq;
|
||||
|
||||
use std::panic::catch_unwind;
|
||||
use std::sync::Arc;
|
||||
|
||||
@ -259,6 +261,36 @@ fn spmm_csc_prealloc_args_strategy() -> impl Strategy<Value=SpmmCscArgs<i32>> {
|
||||
})
|
||||
}
|
||||
|
||||
fn csc_invertible_diagonal() -> impl Strategy<Value=CscMatrix<f64>> {
|
||||
let non_zero_values = value_strategy::<f64>()
|
||||
.prop_filter("Only non-zeros values accepted", |x| x != &0.0);
|
||||
|
||||
vector(non_zero_values, PROPTEST_MATRIX_DIM)
|
||||
.prop_map(|d| {
|
||||
let mut matrix = CscMatrix::identity(d.len());
|
||||
matrix.values_mut().clone_from_slice(&d.as_slice());
|
||||
matrix
|
||||
})
|
||||
}
|
||||
|
||||
fn csc_square_with_non_zero_diagonals() -> impl Strategy<Value=CscMatrix<f64>> {
|
||||
csc_invertible_diagonal()
|
||||
.prop_flat_map(|d| {
|
||||
csc(value_strategy::<f64>(), Just(d.nrows()), Just(d.nrows()), PROPTEST_MAX_NNZ)
|
||||
.prop_map(move |mut c| {
|
||||
for (i, j, v) in c.triplet_iter_mut() {
|
||||
if i == j {
|
||||
*v = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
// Return the sum of a matrix with zero diagonals and an invertible diagonal
|
||||
// matrix
|
||||
c + &d
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper function to help us call dense GEMM with our `Op` type
|
||||
fn dense_gemm<'a>(beta: i32,
|
||||
c: impl Into<DMatrixSliceMut<'a, i32>>,
|
||||
@ -1115,4 +1147,38 @@ proptest! {
|
||||
prop_assert_eq!(&a * &b, &DMatrix::from(&a) * &b);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_solve_lower_triangular_no_transpose(
|
||||
// A CSC matrix `a` and a dimensionally compatible dense matrix `b`
|
||||
(a, b)
|
||||
in csc_square_with_non_zero_diagonals()
|
||||
.prop_flat_map(|a| {
|
||||
let nrows = a.nrows();
|
||||
(Just(a), matrix(value_strategy::<f64>(), nrows, PROPTEST_MATRIX_DIM))
|
||||
}))
|
||||
{
|
||||
let mut x = b.clone();
|
||||
spsolve_csc_lower_triangular(Op::NoOp(&a), &mut x).unwrap();
|
||||
|
||||
let a_lower = a.lower_triangle();
|
||||
prop_assert_matrix_eq!(&a_lower * &x, &b, comp = abs, tol = 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_solve_lower_triangular_transpose(
|
||||
// A CSC matrix `a` and a dimensionally compatible dense matrix `b` (with a transposed)
|
||||
(a, b)
|
||||
in csc_square_with_non_zero_diagonals()
|
||||
.prop_flat_map(|a| {
|
||||
let ncols = a.ncols();
|
||||
(Just(a), matrix(value_strategy::<f64>(), ncols, PROPTEST_MATRIX_DIM))
|
||||
}))
|
||||
{
|
||||
let mut x = b.clone();
|
||||
spsolve_csc_lower_triangular(Op::Transpose(&a), &mut x).unwrap();
|
||||
|
||||
let a_lower = a.lower_triangle();
|
||||
prop_assert_matrix_eq!(&a_lower.transpose() * &x, &b, comp = abs, tol = 1e-6);
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user