diff --git a/src/linalg/cholesky.rs b/src/linalg/cholesky.rs index 51da364f..f61a4e63 100644 --- a/src/linalg/cholesky.rs +++ b/src/linalg/cholesky.rs @@ -74,6 +74,14 @@ where Cholesky { chol: matrix } } + /// Uses the given matrix as-is without any checks or modifications as the + /// Cholesky decomposition. + /// + /// It is up to the user to ensure all invariants hold. + pub fn pack_dirty(matrix: OMatrix) -> Self { + Cholesky { chol: matrix } + } + /// Retrieves the lower-triangular factor of the Cholesky decomposition with its strictly /// upper-triangular part filled with zeros. pub fn unpack(mut self) -> OMatrix { @@ -163,7 +171,32 @@ where /// /// Returns `None` if the input matrix is not definite-positive. The input matrix is assumed /// to be symmetric and only the lower-triangular part is read. - pub fn new(mut matrix: OMatrix) -> Option { + pub fn new(matrix: OMatrix) -> Option { + Self::new_internal(matrix, None) + } + + /// Attempts to approximate the Cholesky decomposition of `matrix` by + /// replacing non-positive values on the diagonals during the decomposition + /// with the given `substitute`. + /// + /// [`try_sqrt`](ComplexField::try_sqrt) will be applied to the `substitute` + /// when it has to be used. + /// + /// If your input matrix results only in positive values on the diagonals + /// during the decomposition, `substitute` is unused and the result is just + /// the same as if you used [`new`](Cholesky::new). + /// + /// This method allows to compensate for matrices with very small or even + /// negative values due to numerical errors but necessarily results in only + /// an approximation: it is basically a hack. If you don't specifically need + /// Cholesky, it may be better to consider alternatives like the + /// [`LU`](crate::linalg::LU) decomposition/factorization. + pub fn new_with_substitute(matrix: OMatrix, substitute: T) -> Option { + Self::new_internal(matrix, Some(substitute)) + } + + /// Common implementation for `new` and `new_with_substitute`. + fn new_internal(mut matrix: OMatrix, substitute: Option) -> Option { assert!(matrix.is_square(), "The input matrix must be square."); let n = matrix.nrows(); @@ -179,17 +212,25 @@ where col_j.axpy(factor.conjugate(), &col_k, T::one()); } - let diag = unsafe { matrix.get_unchecked((j, j)).clone() }; - if !diag.is_zero() { - if let Some(denom) = diag.try_sqrt() { - unsafe { - *matrix.get_unchecked_mut((j, j)) = denom.clone(); - } - - let mut col = matrix.slice_range_mut(j + 1.., j); - col /= denom; - continue; + let sqrt_denom = |v: T| { + if v.is_zero() { + return None; } + v.try_sqrt() + }; + + let diag = unsafe { matrix.get_unchecked((j, j)).clone() }; + + if let Some(denom) = + sqrt_denom(diag).or_else(|| substitute.clone().and_then(sqrt_denom)) + { + unsafe { + *matrix.get_unchecked_mut((j, j)) = denom.clone(); + } + + let mut col = matrix.slice_range_mut(j + 1.., j); + col /= denom; + continue; } // The diagonal element is either zero or its square root could not diff --git a/tests/linalg/cholesky.rs b/tests/linalg/cholesky.rs index 6fd83912..891e54ca 100644 --- a/tests/linalg/cholesky.rs +++ b/tests/linalg/cholesky.rs @@ -1,5 +1,16 @@ #![cfg(all(feature = "proptest-support", feature = "debug"))] +#[test] +// #[rustfmt::skip] +fn cholesky_with_substitute() { + // Make a tiny covariance matrix with a small covariance value. + let m = na::Matrix2::new(1.0, f64::NAN, 1.0, 1e-32); + // Show that the cholesky fails for our matrix. We then try again with a substitute. + assert!(na::Cholesky::new(m).is_none()); + // ...and show that we get some result this time around. + assert!(na::Cholesky::new_with_substitute(m, 1e-8).is_some()); +} + macro_rules! gen_tests( ($module: ident, $scalar: ty) => { mod $module {