diff --git a/src/linalg/cholesky.rs b/src/linalg/cholesky.rs index e6a072de..45d232f2 100644 --- a/src/linalg/cholesky.rs +++ b/src/linalg/cholesky.rs @@ -195,7 +195,7 @@ where pub fn insert_column( self, j: usize, - c: &Matrix, + col: &Matrix, ) -> Cholesky> where D: DimAdd, @@ -203,7 +203,7 @@ where S2: Storage, ShapeConstraint: SameNumberOfRows>, { - let n = c.nrows(); + let n = col.nrows(); assert_eq!( n, self.chol.nrows() + 1, @@ -211,10 +211,26 @@ where ); assert!(j < n, "j needs to be within the bound of the new matrix."); // TODO what is the fastest way to produce the new matrix ? - let chol= self.chol.clone().insert_column(j, N::zero()).insert_row(j, N::zero()); + // TODO check for adjoint problems + let mut chol= self.chol.clone().insert_column(j, N::zero()).insert_row(j, N::zero()); + + // update the top center element S12 + let top_left_corner = chol.slice_range(..j-1, ..j-1); + let colj = col.rows_range(..j-1); // clone_owned needed to get storage mut for b in solve + let new_colj = top_left_corner.ad_solve_lower_triangular(&colj).unwrap(); + chol.slice_range_mut(..j-1, j).copy_from(&new_colj); + + // update the center element S22 + let rowj = chol.slice_range(j, ..j-1); + let center_element = N::sqrt(col[j] + rowj.dot(&rowj.adjoint())); // TODO is there a better way to multiply a vector by its adjoint ? norm_squared ? + chol[(j,j)] = center_element; + + // update the right center element S23 + //chol.slice_range_mut(j+1.., j).copy_from(&new_rowj); + + // update the bottom right corner // TODO see https://en.wikipedia.org/wiki/Cholesky_decomposition#Updating_the_decomposition - unimplemented!(); Cholesky { chol } } @@ -234,7 +250,7 @@ where // TODO what is the fastest way to produce the new matrix ? let mut chol= self.chol.clone().remove_column(j).remove_row(j); - // updates the corner + // updates the bottom right corner let mut corner = chol.slice_range_mut(j.., j..); let colj = self.chol.slice_range(j+1.., j); rank_one_update_helper(&mut corner, &colj, N::real(N::one())); diff --git a/src/linalg/solve.rs b/src/linalg/solve.rs index f10b1d00..a6b9196f 100644 --- a/src/linalg/solve.rs +++ b/src/linalg/solve.rs @@ -15,7 +15,7 @@ impl> SquareMatrix { b: &Matrix, ) -> Option> where - S2: StorageMut, + S2: Storage, DefaultAllocator: Allocator, ShapeConstraint: SameNumberOfRows, { @@ -35,7 +35,7 @@ impl> SquareMatrix { b: &Matrix, ) -> Option> where - S2: StorageMut, + S2: Storage, DefaultAllocator: Allocator, ShapeConstraint: SameNumberOfRows, { @@ -191,7 +191,7 @@ impl> SquareMatrix { b: &Matrix, ) -> Option> where - S2: StorageMut, + S2: Storage, DefaultAllocator: Allocator, ShapeConstraint: SameNumberOfRows, { @@ -211,7 +211,7 @@ impl> SquareMatrix { b: &Matrix, ) -> Option> where - S2: StorageMut, + S2: Storage, DefaultAllocator: Allocator, ShapeConstraint: SameNumberOfRows, { @@ -273,7 +273,7 @@ impl> SquareMatrix { b: &Matrix, ) -> Option> where - S2: StorageMut, + S2: Storage, DefaultAllocator: Allocator, ShapeConstraint: SameNumberOfRows, { @@ -293,7 +293,7 @@ impl> SquareMatrix { b: &Matrix, ) -> Option> where - S2: StorageMut, + S2: Storage, DefaultAllocator: Allocator, ShapeConstraint: SameNumberOfRows, { diff --git a/tests/linalg/cholesky.rs b/tests/linalg/cholesky.rs index aa411564..e3e5fdc7 100644 --- a/tests/linalg/cholesky.rs +++ b/tests/linalg/cholesky.rs @@ -99,6 +99,26 @@ macro_rules! gen_tests( relative_eq!(m, m_chol_updated, epsilon = 1.0e-7) } + fn cholesky_insert_column(n: usize) -> bool { + let n = n.max(1).min(5); + let j = random::() % n; + let m_updated = RandomSDP::new(Dynamic::new(n), || random::<$scalar>().0).unwrap(); + + // build m and col from m_updated + let col = m_updated.column(j); + let m = m_updated.clone().remove_column(j).remove_row(j); + + // remove column from cholesky decomposition and rebuild m + let chol = m.clone().cholesky().unwrap().insert_column(j, &col); + let m_chol_updated = chol.l() * chol.l().adjoint(); + + println!("n={} j={}", n, j); + println!("chol updated:{}", m_chol_updated); + println!("m updated:{}", m_updated); + + relative_eq!(m_updated, m_chol_updated, epsilon = 1.0e-7) + } + fn cholesky_remove_column(n: usize) -> bool { let n = n.max(1).min(5); let j = random::() % n; @@ -111,10 +131,6 @@ macro_rules! gen_tests( // remove column from m let m_updated = m.remove_column(j).remove_row(j); - println!("n={} j={}", n, j); - println!("chol:{}", m_chol_updated); - println!("m up:{}", m_updated); - relative_eq!(m_updated, m_chol_updated, epsilon = 1.0e-7) } }