From cfa7bbdc7c4e0ee65f75425380df3708381a6867 Mon Sep 17 00:00:00 2001 From: Nestor Demeure Date: Sun, 3 Nov 2019 14:33:35 +0100 Subject: [PATCH] remove column is now working --- src/linalg/cholesky.rs | 57 +++++++++++++++++++++++++++++++++++++--- tests/linalg/cholesky.rs | 19 ++++++++++++++ 2 files changed, 72 insertions(+), 4 deletions(-) diff --git a/src/linalg/cholesky.rs b/src/linalg/cholesky.rs index f63ab826..e6a072de 100644 --- a/src/linalg/cholesky.rs +++ b/src/linalg/cholesky.rs @@ -211,7 +211,7 @@ where ); assert!(j < n, "j needs to be within the bound of the new matrix."); // TODO what is the fastest way to produce the new matrix ? - let chol= self.chol.insert_column(j, N::zero()).insert_row(j, N::zero()); + let chol= self.chol.clone().insert_column(j, N::zero()).insert_row(j, N::zero()); // TODO see https://en.wikipedia.org/wiki/Cholesky_decomposition#Updating_the_decomposition unimplemented!(); @@ -229,12 +229,16 @@ where DefaultAllocator: Reallocator> + Reallocator, DimDiff, DimDiff>, { let n = self.chol.nrows(); + assert!(n > 0, "The matrix needs at least one column."); assert!(j < n, "j needs to be within the bound of the matrix."); // TODO what is the fastest way to produce the new matrix ? - let chol= self.chol.remove_column(j).remove_row(j); + let mut chol= self.chol.clone().remove_column(j).remove_row(j); + + // updates the corner + let mut corner = chol.slice_range_mut(j.., j..); + let colj = self.chol.slice_range(j+1.., j); + rank_one_update_helper(&mut corner, &colj, N::real(N::one())); - // TODO see https://en.wikipedia.org/wiki/Cholesky_decomposition#Updating_the_decomposition - unimplemented!(); Cholesky { chol } } } @@ -251,3 +255,48 @@ where Cholesky::new(self.into_owned()) } } + +/// Given the Cholesky decomposition of a matrix `M`, a scalar `sigma` and a vector `v`, +/// performs a rank one update such that we end up with the decomposition of `M + sigma * v*v.adjoint()`. +fn rank_one_update_helper(chol : &mut Matrix, x: &Matrix, sigma: N::RealField) + where + N: ComplexField, D: DimSub, R2: Dim, + S: StorageMut, + S2: Storage, + DefaultAllocator: Allocator + Allocator, + ShapeConstraint: SameNumberOfRows, +{ + // heavily inspired by Eigen's `llt_rank_update_lower` implementation https://eigen.tuxfamily.org/dox/LLT_8h_source.html + let n = x.nrows(); + assert_eq!( + n, + chol.nrows(), + "The input vector must be of the same size as the factorized matrix." + ); + let mut x = x.clone_owned(); + let mut beta = crate::one::(); + for j in 0..n { + // updates the diagonal + let diag = N::real(unsafe { *chol.get_unchecked((j, j)) }); + let diag2 = diag * diag; + let xj = unsafe { *x.get_unchecked(j) }; + let sigma_xj2 = sigma * N::modulus_squared(xj); + let gamma = diag2 * beta + sigma_xj2; + let new_diag = (diag2 + sigma_xj2 / beta).sqrt(); + unsafe { *chol.get_unchecked_mut((j, j)) = N::from_real(new_diag) }; + beta += sigma_xj2 / diag2; + // updates the terms of L + let mut xjplus = x.rows_range_mut(j + 1..); + let mut col_j = chol.slice_range_mut(j + 1.., j); + // temp_jplus -= (wj / N::from_real(diag)) * col_j; + xjplus.axpy(-xj / N::from_real(diag), &col_j, N::one()); + if gamma != crate::zero::() { + // col_j = N::from_real(nljj / diag) * col_j + (N::from_real(nljj * sigma / gamma) * N::conjugate(wj)) * temp_jplus; + col_j.axpy( + N::from_real(new_diag * sigma / gamma) * N::conjugate(xj), + &xjplus, + N::from_real(new_diag / diag), + ); + } + } +} \ No newline at end of file diff --git a/tests/linalg/cholesky.rs b/tests/linalg/cholesky.rs index ea8402a3..aa411564 100644 --- a/tests/linalg/cholesky.rs +++ b/tests/linalg/cholesky.rs @@ -98,6 +98,25 @@ macro_rules! gen_tests( relative_eq!(m, m_chol_updated, epsilon = 1.0e-7) } + + fn cholesky_remove_column(n: usize) -> bool { + let n = n.max(1).min(5); + let j = random::() % n; + let m = RandomSDP::new(Dynamic::new(n), || random::<$scalar>().0).unwrap(); + + // remove column from cholesky decomposition and rebuild m + let chol = m.clone().cholesky().unwrap().remove_column(j); + let m_chol_updated = chol.l() * chol.l().adjoint(); + + // remove column from m + let m_updated = m.remove_column(j).remove_row(j); + + println!("n={} j={}", n, j); + println!("chol:{}", m_chol_updated); + println!("m up:{}", m_updated); + + relative_eq!(m_updated, m_chol_updated, epsilon = 1.0e-7) + } } } }