diff --git a/nalgebra-sparse/src/factorization/cholesky.rs b/nalgebra-sparse/src/factorization/cholesky.rs index 7195080f..2639424c 100644 --- a/nalgebra-sparse/src/factorization/cholesky.rs +++ b/nalgebra-sparse/src/factorization/cholesky.rs @@ -4,9 +4,11 @@ use crate::pattern::SparsityPattern; use crate::csc::CscMatrix; use core::{mem, iter}; -use nalgebra::{Scalar, RealField}; +use nalgebra::{Scalar, RealField, DMatrixSlice, DMatrixSliceMut, DMatrix}; use std::sync::Arc; use std::fmt::{Display, Formatter}; +use crate::ops::serial::spsolve_csc_lower_triangular; +use crate::ops::Op; pub struct CscSymbolicCholesky { // Pattern of the original matrix that was decomposed @@ -177,6 +179,27 @@ impl CscCholesky { Ok(()) } + pub fn solve<'a>(&'a self, b: impl Into>) -> DMatrix { + let b = b.into(); + let mut output = b.clone_owned(); + self.solve_mut(&mut output); + output + } + + pub fn solve_mut<'a>(&'a self, b: impl Into>) + { + let expect_msg = "If the Cholesky factorization succeeded,\ + then the triangular solve should never fail"; + // Solve LY = B + let mut y = b.into(); + spsolve_csc_lower_triangular(Op::NoOp(self.l()), &mut y) + .expect(expect_msg); + + // Solve L^T X = Y + let mut x = y; + spsolve_csc_lower_triangular(Op::Transpose(self.l()), &mut x) + .expect(expect_msg); + } } diff --git a/nalgebra-sparse/tests/unit_tests/cholesky.rs b/nalgebra-sparse/tests/unit_tests/cholesky.rs index 87517828..82cb2c42 100644 --- a/nalgebra-sparse/tests/unit_tests/cholesky.rs +++ b/nalgebra-sparse/tests/unit_tests/cholesky.rs @@ -4,6 +4,7 @@ use nalgebra_sparse::csc::CscMatrix; use nalgebra_sparse::factorization::{CscCholesky}; use nalgebra_sparse::proptest::csc; use nalgebra::{Matrix5, Vector5, Cholesky, DMatrix}; +use nalgebra::proptest::matrix; use proptest::prelude::*; use matrixcompare::{assert_matrix_eq, prop_assert_matrix_eq}; @@ -35,6 +36,30 @@ proptest! { prop_assert!(is_lower_triangular); } + #[test] + fn cholesky_solve_positive_definite( + (matrix, rhs) in positive_definite() + .prop_flat_map(|csc| { + let rhs = matrix(value_strategy::(), csc.nrows(), PROPTEST_MATRIX_DIM); + (Just(csc), rhs) + }) + ) { + let cholesky = CscCholesky::factor(&matrix).unwrap(); + + // solve_mut + { + let mut x = rhs.clone(); + cholesky.solve_mut(&mut x); + prop_assert_matrix_eq!(&matrix * &x, rhs, comp=abs, tol=1e-12); + } + + // solve + { + let x = cholesky.solve(&rhs); + prop_assert_matrix_eq!(&matrix * &x, rhs, comp=abs, tol=1e-12); + } + } + } // This is a test ported from nalgebra's "sparse" module, for the original CsCholesky impl