diff --git a/src/linalg/cholesky.rs b/src/linalg/cholesky.rs index c4049504..d0a9918c 100644 --- a/src/linalg/cholesky.rs +++ b/src/linalg/cholesky.rs @@ -147,7 +147,7 @@ where } /// Given the Cholesky decomposition of a matrix `M`, a scalar `sigma` and a vector `v`, - /// performs a rank one update such that we end up with the decomposition of `M + sigma * v*v^*`. + /// performs a rank one update such that we end up with the decomposition of `M + sigma * v*v.adjoint()`. pub fn rank_one_update(&mut self, x: &Matrix, sigma: N::RealField) where S2: Storage, @@ -156,27 +156,31 @@ where { // for a description of the operation, see https://en.wikipedia.org/wiki/Cholesky_decomposition#Updating_the_decomposition // heavily inspired by Eigen's implementation https://eigen.tuxfamily.org/dox/LLT_8h_source.html - // TODO use unsafe { *matrix.get_unchecked((j, j)) } let n = x.nrows(); - let mut temp = x.clone_owned(); + let mut x = x.clone_owned(); let mut beta = crate::one::(); for j in 0..n { - let ljj = N::real(self.chol[(j, j)]); - let dj = ljj * ljj; - let wj = temp[j]; - let swj2 = sigma * N::modulus_squared(wj); - let gamma = dj * beta + swj2; - let nljj = (dj + swj2 / beta).sqrt(); - self.chol[(j, j)] = N::from_real(nljj); - beta += swj2 / dj; + let diag = N::real(unsafe { *self.chol.get_unchecked((j, j)) }); + let diag2 = diag * diag; + let xj = unsafe { *x.get_unchecked(j) }; + let sigma_xj2 = sigma * N::modulus_squared(xj); + let gamma = diag2 * beta + sigma_xj2; + let new_diag = (diag2 + sigma_xj2 / beta).sqrt(); + unsafe { *self.chol.get_unchecked_mut((j, j)) = N::from_real(new_diag) }; + beta += sigma_xj2 / diag2; // Update the terms of L if j < n { - for k in (j + 1)..n { - temp[k] -= (wj / N::from_real(ljj)) * self.chol[(k, j)]; - if gamma != crate::zero::() { - self.chol[(k, j)] = N::from_real(nljj / ljj) * self.chol[(k, j)] - + (N::from_real(nljj * sigma / gamma) * N::conjugate(wj)) * temp[k]; - } + let mut xjplus = x.rows_range_mut(j + 1..); + let mut col_j = self.chol.slice_range_mut(j + 1.., j); + // temp_jplus -= (wj / N::from_real(diag)) * col_j; + xjplus.axpy(-xj / N::from_real(diag), &col_j, N::one()); + if gamma != crate::zero::() { + // col_j = N::from_real(nljj / diag) * col_j + (N::from_real(nljj * sigma / gamma) * N::conjugate(wj)) * temp_jplus; + col_j.axpy( + N::from_real(new_diag * sigma / gamma) * N::conjugate(xj), + &xjplus, + N::from_real(new_diag / diag), + ); } } } diff --git a/tests/linalg/cholesky.rs b/tests/linalg/cholesky.rs index b04ed402..ea8402a3 100644 --- a/tests/linalg/cholesky.rs +++ b/tests/linalg/cholesky.rs @@ -82,13 +82,13 @@ macro_rules! gen_tests( let mut m = RandomSDP::new(U4, || random::<$scalar>().0).unwrap(); let x = Vector4::<$scalar>::new_random().map(|e| e.0); - // TODO this is dirty but $scalar appears to not be a scalar type in this file + // this is dirty but $scalar is not a scalar type (its a Rand) in this file let zero = random::<$scalar>().0 * 0.; let one = zero + 1.; let sigma = random::(); // needs to be a real let sigma_scalar = zero + sigma; - // updates cholesky decomposition and reconstructs m + // updates cholesky decomposition and reconstructs m updated let mut chol = m.clone().cholesky().unwrap(); chol.rank_one_update(&x, sigma); let m_chol_updated = chol.l() * chol.l().adjoint();