diff --git a/src/linalg/cholesky.rs b/src/linalg/cholesky.rs index 6a2c9da8..cbbe5ff8 100644 --- a/src/linalg/cholesky.rs +++ b/src/linalg/cholesky.rs @@ -8,6 +8,7 @@ use crate::base::{DefaultAllocator, Matrix, MatrixMN, MatrixN, SquareMatrix}; use crate::constraint::{SameNumberOfRows, ShapeConstraint}; use crate::dimension::{Dim, DimSub, Dynamic, U1}; use crate::storage::{Storage, StorageMut}; +use crate::RealField; /// The Cholesky decomposition of a symmetric-definite-positive matrix. #[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))] @@ -151,12 +152,18 @@ where /// TODO rewrite comment (current version is taken verbatim from eigen) /// TODO insures that code is correct for complex numbers, eigen uses abs2 and conj /// https://eigen.tuxfamily.org/dox/LLT_8h_source.html - pub fn rank_one_update(&mut self, x: &Matrix, sigma: N) - where + /// TODO insure that sigma is a real + pub fn rank_one_update( + &mut self, + x: &Matrix, + sigma: N2, + ) where + N: From, S2: Storage, DefaultAllocator: Allocator, ShapeConstraint: SameNumberOfRows, { + let sigma = ::from(sigma); let n = x.nrows(); let mut temp = x.clone_owned(); for k in 0..n { diff --git a/tests/linalg/cholesky.rs b/tests/linalg/cholesky.rs index 823ec96f..bef5de95 100644 --- a/tests/linalg/cholesky.rs +++ b/tests/linalg/cholesky.rs @@ -83,9 +83,12 @@ macro_rules! gen_tests( use nalgebra::Vector3; let mut m = RandomSDP::new(U3, || random::<$scalar>().0).unwrap(); let x = Vector3::<$scalar>::new_random().map(|e| e.0); - let mut sigma = random::<$scalar>().0; // random::<$scalar>().0; - let one = sigma*0. + 1.; // TODO this is dirty but $scalar appears to not be a scalar type in this file - sigma = one; // TODO placeholder + + // TODO this is dirty but $scalar appears to not be a scalar type in this file + let zero = random::<$scalar>().0 * 0.; + let one = zero + 1.; + let sigma = random::(); // needs to be a real + let sigma_scalar = zero + sigma; // updates cholesky decomposition and reconstructs m let mut chol = m.clone().cholesky().unwrap(); @@ -93,7 +96,7 @@ macro_rules! gen_tests( let m_chol_updated = chol.l() * chol.l().adjoint(); // updates m manually - m.ger(sigma, &x, &x, one); // m += sigma * x * x.adjoint() + m.ger(sigma_scalar, &x, &x, one); // m += sigma * x * x.adjoint() println!("sigma : {}", sigma); println!("m updated : {}", m);