code cleaned
This commit is contained in:
parent
3ae88127ee
commit
3123da5529
|
@ -147,7 +147,7 @@ where
|
|||
}
|
||||
|
||||
/// Given the Cholesky decomposition of a matrix `M`, a scalar `sigma` and a vector `v`,
|
||||
/// performs a rank one update such that we end up with the decomposition of `M + sigma * v*v^*`.
|
||||
/// performs a rank one update such that we end up with the decomposition of `M + sigma * v*v.adjoint()`.
|
||||
pub fn rank_one_update<R2: Dim, S2>(&mut self, x: &Matrix<N, R2, U1, S2>, sigma: N::RealField)
|
||||
where
|
||||
S2: Storage<N, R2, U1>,
|
||||
|
@ -156,27 +156,31 @@ where
|
|||
{
|
||||
// for a description of the operation, see https://en.wikipedia.org/wiki/Cholesky_decomposition#Updating_the_decomposition
|
||||
// heavily inspired by Eigen's implementation https://eigen.tuxfamily.org/dox/LLT_8h_source.html
|
||||
// TODO use unsafe { *matrix.get_unchecked((j, j)) }
|
||||
let n = x.nrows();
|
||||
let mut temp = x.clone_owned();
|
||||
let mut x = x.clone_owned();
|
||||
let mut beta = crate::one::<N::RealField>();
|
||||
for j in 0..n {
|
||||
let ljj = N::real(self.chol[(j, j)]);
|
||||
let dj = ljj * ljj;
|
||||
let wj = temp[j];
|
||||
let swj2 = sigma * N::modulus_squared(wj);
|
||||
let gamma = dj * beta + swj2;
|
||||
let nljj = (dj + swj2 / beta).sqrt();
|
||||
self.chol[(j, j)] = N::from_real(nljj);
|
||||
beta += swj2 / dj;
|
||||
let diag = N::real(unsafe { *self.chol.get_unchecked((j, j)) });
|
||||
let diag2 = diag * diag;
|
||||
let xj = unsafe { *x.get_unchecked(j) };
|
||||
let sigma_xj2 = sigma * N::modulus_squared(xj);
|
||||
let gamma = diag2 * beta + sigma_xj2;
|
||||
let new_diag = (diag2 + sigma_xj2 / beta).sqrt();
|
||||
unsafe { *self.chol.get_unchecked_mut((j, j)) = N::from_real(new_diag) };
|
||||
beta += sigma_xj2 / diag2;
|
||||
// Update the terms of L
|
||||
if j < n {
|
||||
for k in (j + 1)..n {
|
||||
temp[k] -= (wj / N::from_real(ljj)) * self.chol[(k, j)];
|
||||
let mut xjplus = x.rows_range_mut(j + 1..);
|
||||
let mut col_j = self.chol.slice_range_mut(j + 1.., j);
|
||||
// temp_jplus -= (wj / N::from_real(diag)) * col_j;
|
||||
xjplus.axpy(-xj / N::from_real(diag), &col_j, N::one());
|
||||
if gamma != crate::zero::<N::RealField>() {
|
||||
self.chol[(k, j)] = N::from_real(nljj / ljj) * self.chol[(k, j)]
|
||||
+ (N::from_real(nljj * sigma / gamma) * N::conjugate(wj)) * temp[k];
|
||||
}
|
||||
// col_j = N::from_real(nljj / diag) * col_j + (N::from_real(nljj * sigma / gamma) * N::conjugate(wj)) * temp_jplus;
|
||||
col_j.axpy(
|
||||
N::from_real(new_diag * sigma / gamma) * N::conjugate(xj),
|
||||
&xjplus,
|
||||
N::from_real(new_diag / diag),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -82,13 +82,13 @@ macro_rules! gen_tests(
|
|||
let mut m = RandomSDP::new(U4, || random::<$scalar>().0).unwrap();
|
||||
let x = Vector4::<$scalar>::new_random().map(|e| e.0);
|
||||
|
||||
// TODO this is dirty but $scalar appears to not be a scalar type in this file
|
||||
// this is dirty but $scalar is not a scalar type (its a Rand) in this file
|
||||
let zero = random::<$scalar>().0 * 0.;
|
||||
let one = zero + 1.;
|
||||
let sigma = random::<f64>(); // needs to be a real
|
||||
let sigma_scalar = zero + sigma;
|
||||
|
||||
// updates cholesky decomposition and reconstructs m
|
||||
// updates cholesky decomposition and reconstructs m updated
|
||||
let mut chol = m.clone().cholesky().unwrap();
|
||||
chol.rank_one_update(&x, sigma);
|
||||
let m_chol_updated = chol.l() * chol.l().adjoint();
|
||||
|
|
Loading…
Reference in New Issue