code cleaned

This commit is contained in:
Nestor Demeure 2019-11-02 19:04:07 +01:00
parent 3ae88127ee
commit 3123da5529
2 changed files with 23 additions and 19 deletions

View File

@ -147,7 +147,7 @@ where
} }
/// Given the Cholesky decomposition of a matrix `M`, a scalar `sigma` and a vector `v`, /// Given the Cholesky decomposition of a matrix `M`, a scalar `sigma` and a vector `v`,
/// performs a rank one update such that we end up with the decomposition of `M + sigma * v*v^*`. /// performs a rank one update such that we end up with the decomposition of `M + sigma * v*v.adjoint()`.
pub fn rank_one_update<R2: Dim, S2>(&mut self, x: &Matrix<N, R2, U1, S2>, sigma: N::RealField) pub fn rank_one_update<R2: Dim, S2>(&mut self, x: &Matrix<N, R2, U1, S2>, sigma: N::RealField)
where where
S2: Storage<N, R2, U1>, S2: Storage<N, R2, U1>,
@ -156,27 +156,31 @@ where
{ {
// for a description of the operation, see https://en.wikipedia.org/wiki/Cholesky_decomposition#Updating_the_decomposition // for a description of the operation, see https://en.wikipedia.org/wiki/Cholesky_decomposition#Updating_the_decomposition
// heavily inspired by Eigen's implementation https://eigen.tuxfamily.org/dox/LLT_8h_source.html // heavily inspired by Eigen's implementation https://eigen.tuxfamily.org/dox/LLT_8h_source.html
// TODO use unsafe { *matrix.get_unchecked((j, j)) }
let n = x.nrows(); let n = x.nrows();
let mut temp = x.clone_owned(); let mut x = x.clone_owned();
let mut beta = crate::one::<N::RealField>(); let mut beta = crate::one::<N::RealField>();
for j in 0..n { for j in 0..n {
let ljj = N::real(self.chol[(j, j)]); let diag = N::real(unsafe { *self.chol.get_unchecked((j, j)) });
let dj = ljj * ljj; let diag2 = diag * diag;
let wj = temp[j]; let xj = unsafe { *x.get_unchecked(j) };
let swj2 = sigma * N::modulus_squared(wj); let sigma_xj2 = sigma * N::modulus_squared(xj);
let gamma = dj * beta + swj2; let gamma = diag2 * beta + sigma_xj2;
let nljj = (dj + swj2 / beta).sqrt(); let new_diag = (diag2 + sigma_xj2 / beta).sqrt();
self.chol[(j, j)] = N::from_real(nljj); unsafe { *self.chol.get_unchecked_mut((j, j)) = N::from_real(new_diag) };
beta += swj2 / dj; beta += sigma_xj2 / diag2;
// Update the terms of L // Update the terms of L
if j < n { if j < n {
for k in (j + 1)..n { let mut xjplus = x.rows_range_mut(j + 1..);
temp[k] -= (wj / N::from_real(ljj)) * self.chol[(k, j)]; let mut col_j = self.chol.slice_range_mut(j + 1.., j);
// temp_jplus -= (wj / N::from_real(diag)) * col_j;
xjplus.axpy(-xj / N::from_real(diag), &col_j, N::one());
if gamma != crate::zero::<N::RealField>() { if gamma != crate::zero::<N::RealField>() {
self.chol[(k, j)] = N::from_real(nljj / ljj) * self.chol[(k, j)] // col_j = N::from_real(nljj / diag) * col_j + (N::from_real(nljj * sigma / gamma) * N::conjugate(wj)) * temp_jplus;
+ (N::from_real(nljj * sigma / gamma) * N::conjugate(wj)) * temp[k]; col_j.axpy(
} N::from_real(new_diag * sigma / gamma) * N::conjugate(xj),
&xjplus,
N::from_real(new_diag / diag),
);
} }
} }
} }

View File

@ -82,13 +82,13 @@ macro_rules! gen_tests(
let mut m = RandomSDP::new(U4, || random::<$scalar>().0).unwrap(); let mut m = RandomSDP::new(U4, || random::<$scalar>().0).unwrap();
let x = Vector4::<$scalar>::new_random().map(|e| e.0); let x = Vector4::<$scalar>::new_random().map(|e| e.0);
// TODO this is dirty but $scalar appears to not be a scalar type in this file // this is dirty but $scalar is not a scalar type (its a Rand) in this file
let zero = random::<$scalar>().0 * 0.; let zero = random::<$scalar>().0 * 0.;
let one = zero + 1.; let one = zero + 1.;
let sigma = random::<f64>(); // needs to be a real let sigma = random::<f64>(); // needs to be a real
let sigma_scalar = zero + sigma; let sigma_scalar = zero + sigma;
// updates cholesky decomposition and reconstructs m // updates cholesky decomposition and reconstructs m updated
let mut chol = m.clone().cholesky().unwrap(); let mut chol = m.clone().cholesky().unwrap();
chol.rank_one_update(&x, sigma); chol.rank_one_update(&x, sigma);
let m_chol_updated = chol.l() * chol.l().adjoint(); let m_chol_updated = chol.l() * chol.l().adjoint();