diff --git a/src/linalg/cholesky.rs b/src/linalg/cholesky.rs index da3549d7..cf338a95 100644 --- a/src/linalg/cholesky.rs +++ b/src/linalg/cholesky.rs @@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize}; use num::One; use simba::scalar::ComplexField; +use simba::simd::SimdComplexField; use crate::allocator::Allocator; use crate::base::{DefaultAllocator, Matrix, MatrixMN, MatrixN, SquareMatrix, Vector}; @@ -23,29 +24,26 @@ use crate::storage::{Storage, StorageMut}; MatrixN: Deserialize<'de>")) )] #[derive(Clone, Debug)] -pub struct Cholesky -where - DefaultAllocator: Allocator, +pub struct Cholesky +where DefaultAllocator: Allocator { chol: MatrixN, } -impl Copy for Cholesky +impl Copy for Cholesky where DefaultAllocator: Allocator, MatrixN: Copy, { } -impl> Cholesky -where - DefaultAllocator: Allocator, +impl Cholesky +where DefaultAllocator: Allocator { - /// Attempts to compute the Cholesky decomposition of `matrix`. + /// Computes the Cholesky decomposition of `matrix` without checking that the matrix is definite-positive. /// - /// Returns `None` if the input matrix is not definite-positive. The input matrix is assumed - /// to be symmetric and only the lower-triangular part is read. - pub fn new(mut matrix: MatrixN) -> Option { + /// If the input matrix is not definite-positive, the decomposition may contain trash values (Inf, NaN, etc.) + pub fn new_unchecked(mut matrix: MatrixN) -> Self { assert!(matrix.is_square(), "The input matrix must be square."); let n = matrix.nrows(); @@ -57,29 +55,21 @@ where let (mut col_j, col_k) = matrix.columns_range_pair_mut(j, k); let mut col_j = col_j.rows_range_mut(j..); let col_k = col_k.rows_range(j..); - - col_j.axpy(factor.conjugate(), &col_k, N::one()); + col_j.axpy(factor.simd_conjugate(), &col_k, N::one()); } let diag = unsafe { *matrix.get_unchecked((j, j)) }; - if !diag.is_zero() { - if let Some(denom) = diag.try_sqrt() { - unsafe { - *matrix.get_unchecked_mut((j, j)) = denom; - } + let denom = diag.simd_sqrt(); - let mut col = matrix.slice_range_mut(j + 1.., j); - col /= denom; - continue; - } + unsafe { + *matrix.get_unchecked_mut((j, j)) = denom; } - // The diagonal element is either zero or its square root could not - // be taken (e.g. for negative real numbers). - return None; + let mut col = matrix.slice_range_mut(j + 1.., j); + col /= denom; } - Some(Cholesky { chol: matrix }) + Cholesky { chol: matrix } } /// Retrieves the lower-triangular factor of the Cholesky decomposition with its strictly @@ -121,8 +111,8 @@ where S2: StorageMut, ShapeConstraint: SameNumberOfRows, { - let _ = self.chol.solve_lower_triangular_mut(b); - let _ = self.chol.ad_solve_lower_triangular_mut(b); + self.chol.solve_lower_triangular_unchecked_mut(b); + self.chol.ad_solve_lower_triangular_unchecked_mut(b); } /// Returns the solution of the system `self * x = b` where `self` is the decomposed matrix and @@ -146,6 +136,51 @@ where self.solve_mut(&mut res); res } +} + +impl Cholesky +where DefaultAllocator: Allocator +{ + /// Attempts to compute the Cholesky decomposition of `matrix`. + /// + /// Returns `None` if the input matrix is not definite-positive. The input matrix is assumed + /// to be symmetric and only the lower-triangular part is read. + pub fn new(mut matrix: MatrixN) -> Option { + assert!(matrix.is_square(), "The input matrix must be square."); + + let n = matrix.nrows(); + + for j in 0..n { + for k in 0..j { + let factor = unsafe { -*matrix.get_unchecked((j, k)) }; + + let (mut col_j, col_k) = matrix.columns_range_pair_mut(j, k); + let mut col_j = col_j.rows_range_mut(j..); + let col_k = col_k.rows_range(j..); + + col_j.axpy(factor.conjugate(), &col_k, N::one()); + } + + let diag = unsafe { *matrix.get_unchecked((j, j)) }; + if !diag.is_zero() { + if let Some(denom) = diag.try_sqrt() { + unsafe { + *matrix.get_unchecked_mut((j, j)) = denom; + } + + let mut col = matrix.slice_range_mut(j + 1.., j); + col /= denom; + continue; + } + } + + // The diagonal element is either zero or its square root could not + // be taken (e.g. for negative real numbers). + return None; + } + + Some(Cholesky { chol: matrix }) + } /// Given the Cholesky decomposition of a matrix `M`, a scalar `sigma` and a vector `v`, /// performs a rank one update such that we end up with the decomposition of `M + sigma * (v * v.adjoint())`. @@ -327,8 +362,7 @@ where } impl, S: Storage> SquareMatrix -where - DefaultAllocator: Allocator, +where DefaultAllocator: Allocator { /// Attempts to compute the Cholesky decomposition of this matrix. /// diff --git a/src/linalg/solve.rs b/src/linalg/solve.rs index 56db4ade..ac5bff46 100644 --- a/src/linalg/solve.rs +++ b/src/linalg/solve.rs @@ -1,4 +1,5 @@ use simba::scalar::ComplexField; +use simba::simd::SimdComplexField; use crate::base::allocator::Allocator; use crate::base::constraint::{SameNumberOfRows, ShapeConstraint}; @@ -432,3 +433,336 @@ impl> SquareMatrix { true } } + +/* + * + * SIMD-compatible unchecked versions. + * + */ + +impl> SquareMatrix { + /// Computes the solution of the linear system `self . x = b` where `x` is the unknown and only + /// the lower-triangular part of `self` (including the diagonal) is considered not-zero. + #[inline] + pub fn solve_lower_triangular_unchecked( + &self, + b: &Matrix, + ) -> MatrixMN + where + S2: Storage, + DefaultAllocator: Allocator, + ShapeConstraint: SameNumberOfRows, + { + let mut res = b.clone_owned(); + self.solve_lower_triangular_unchecked_mut(&mut res); + res + } + + /// Computes the solution of the linear system `self . x = b` where `x` is the unknown and only + /// the upper-triangular part of `self` (including the diagonal) is considered not-zero. + #[inline] + pub fn solve_upper_triangular_unchecked( + &self, + b: &Matrix, + ) -> MatrixMN + where + S2: Storage, + DefaultAllocator: Allocator, + ShapeConstraint: SameNumberOfRows, + { + let mut res = b.clone_owned(); + self.solve_upper_triangular_unchecked_mut(&mut res); + res + } + + /// Solves the linear system `self . x = b` where `x` is the unknown and only the + /// lower-triangular part of `self` (including the diagonal) is considered not-zero. + pub fn solve_lower_triangular_unchecked_mut( + &self, + b: &mut Matrix, + ) where + S2: StorageMut, + ShapeConstraint: SameNumberOfRows, + { + for i in 0..b.ncols() { + self.solve_lower_triangular_vector_unchecked_mut(&mut b.column_mut(i)); + } + } + + fn solve_lower_triangular_vector_unchecked_mut(&self, b: &mut Vector) + where + S2: StorageMut, + ShapeConstraint: SameNumberOfRows, + { + let dim = self.nrows(); + + for i in 0..dim { + let coeff; + + unsafe { + let diag = *self.get_unchecked((i, i)); + coeff = *b.vget_unchecked(i) / diag; + *b.vget_unchecked_mut(i) = coeff; + } + + b.rows_range_mut(i + 1..) + .axpy(-coeff, &self.slice_range(i + 1.., i), N::one()); + } + } + + // FIXME: add the same but for solving upper-triangular. + /// Solves the linear system `self . x = b` where `x` is the unknown and only the + /// lower-triangular part of `self` is considered not-zero. The diagonal is never read as it is + /// assumed to be equal to `diag`. Returns `false` and does not modify its inputs if `diag` is zero. + pub fn solve_lower_triangular_with_diag_unchecked_mut( + &self, + b: &mut Matrix, + diag: N, + ) where + S2: StorageMut, + ShapeConstraint: SameNumberOfRows, + { + let dim = self.nrows(); + let cols = b.ncols(); + + for k in 0..cols { + let mut bcol = b.column_mut(k); + + for i in 0..dim - 1 { + let coeff = unsafe { *bcol.vget_unchecked(i) } / diag; + bcol.rows_range_mut(i + 1..) + .axpy(-coeff, &self.slice_range(i + 1.., i), N::one()); + } + } + } + + /// Solves the linear system `self . x = b` where `x` is the unknown and only the + /// upper-triangular part of `self` (including the diagonal) is considered not-zero. + pub fn solve_upper_triangular_unchecked_mut( + &self, + b: &mut Matrix, + ) where + S2: StorageMut, + ShapeConstraint: SameNumberOfRows, + { + for i in 0..b.ncols() { + self.solve_upper_triangular_vector_unchecked_mut(&mut b.column_mut(i)) + } + } + + fn solve_upper_triangular_vector_unchecked_mut(&self, b: &mut Vector) + where + S2: StorageMut, + ShapeConstraint: SameNumberOfRows, + { + let dim = self.nrows(); + + for i in (0..dim).rev() { + let coeff; + + unsafe { + let diag = *self.get_unchecked((i, i)); + coeff = *b.vget_unchecked(i) / diag; + *b.vget_unchecked_mut(i) = coeff; + } + + b.rows_range_mut(..i) + .axpy(-coeff, &self.slice_range(..i, i), N::one()); + } + } + + /* + * + * Transpose and adjoint versions + * + */ + /// Computes the solution of the linear system `self.transpose() . x = b` where `x` is the unknown and only + /// the lower-triangular part of `self` (including the diagonal) is considered not-zero. + #[inline] + pub fn tr_solve_lower_triangular_unchecked( + &self, + b: &Matrix, + ) -> MatrixMN + where + S2: Storage, + DefaultAllocator: Allocator, + ShapeConstraint: SameNumberOfRows, + { + let mut res = b.clone_owned(); + self.tr_solve_lower_triangular_unchecked_mut(&mut res); + res + } + + /// Computes the solution of the linear system `self.transpose() . x = b` where `x` is the unknown and only + /// the upper-triangular part of `self` (including the diagonal) is considered not-zero. + #[inline] + pub fn tr_solve_upper_triangular_unchecked( + &self, + b: &Matrix, + ) -> MatrixMN + where + S2: Storage, + DefaultAllocator: Allocator, + ShapeConstraint: SameNumberOfRows, + { + let mut res = b.clone_owned(); + self.tr_solve_upper_triangular_unchecked_mut(&mut res); + res + } + + /// Solves the linear system `self.transpose() . x = b` where `x` is the unknown and only the + /// lower-triangular part of `self` (including the diagonal) is considered not-zero. + pub fn tr_solve_lower_triangular_unchecked_mut( + &self, + b: &mut Matrix, + ) where + S2: StorageMut, + ShapeConstraint: SameNumberOfRows, + { + for i in 0..b.ncols() { + self.xx_solve_lower_triangular_vector_unchecked_mut( + &mut b.column_mut(i), + |e| e, + |a, b| a.dot(b), + ) + } + } + + /// Solves the linear system `self.transpose() . x = b` where `x` is the unknown and only the + /// upper-triangular part of `self` (including the diagonal) is considered not-zero. + pub fn tr_solve_upper_triangular_unchecked_mut( + &self, + b: &mut Matrix, + ) where + S2: StorageMut, + ShapeConstraint: SameNumberOfRows, + { + for i in 0..b.ncols() { + self.xx_solve_upper_triangular_vector_unchecked_mut( + &mut b.column_mut(i), + |e| e, + |a, b| a.dot(b), + ) + } + } + + /// Computes the solution of the linear system `self.adjoint() . x = b` where `x` is the unknown and only + /// the lower-triangular part of `self` (including the diagonal) is considered not-zero. + #[inline] + pub fn ad_solve_lower_triangular_unchecked( + &self, + b: &Matrix, + ) -> MatrixMN + where + S2: Storage, + DefaultAllocator: Allocator, + ShapeConstraint: SameNumberOfRows, + { + let mut res = b.clone_owned(); + self.ad_solve_lower_triangular_unchecked_mut(&mut res); + res + } + + /// Computes the solution of the linear system `self.adjoint() . x = b` where `x` is the unknown and only + /// the upper-triangular part of `self` (including the diagonal) is considered not-zero. + #[inline] + pub fn ad_solve_upper_triangular_unchecked( + &self, + b: &Matrix, + ) -> MatrixMN + where + S2: Storage, + DefaultAllocator: Allocator, + ShapeConstraint: SameNumberOfRows, + { + let mut res = b.clone_owned(); + self.ad_solve_upper_triangular_unchecked_mut(&mut res); + res + } + + /// Solves the linear system `self.adjoint() . x = b` where `x` is the unknown and only the + /// lower-triangular part of `self` (including the diagonal) is considered not-zero. + pub fn ad_solve_lower_triangular_unchecked_mut( + &self, + b: &mut Matrix, + ) where + S2: StorageMut, + ShapeConstraint: SameNumberOfRows, + { + for i in 0..b.ncols() { + self.xx_solve_lower_triangular_vector_unchecked_mut( + &mut b.column_mut(i), + |e| e.simd_conjugate(), + |a, b| a.dotc(b), + ) + } + } + + /// Solves the linear system `self.adjoint() . x = b` where `x` is the unknown and only the + /// upper-triangular part of `self` (including the diagonal) is considered not-zero. + pub fn ad_solve_upper_triangular_unchecked_mut( + &self, + b: &mut Matrix, + ) where + S2: StorageMut, + ShapeConstraint: SameNumberOfRows, + { + for i in 0..b.ncols() { + self.xx_solve_upper_triangular_vector_unchecked_mut( + &mut b.column_mut(i), + |e| e.simd_conjugate(), + |a, b| a.dotc(b), + ) + } + } + + #[inline(always)] + fn xx_solve_lower_triangular_vector_unchecked_mut( + &self, + b: &mut Vector, + conjugate: impl Fn(N) -> N, + dot: impl Fn( + &DVectorSlice, + &DVectorSlice, + ) -> N, + ) where + S2: StorageMut, + ShapeConstraint: SameNumberOfRows, + { + let dim = self.nrows(); + + for i in (0..dim).rev() { + let dot = dot(&self.slice_range(i + 1.., i), &b.slice_range(i + 1.., 0)); + + unsafe { + let b_i = b.vget_unchecked_mut(i); + let diag = conjugate(*self.get_unchecked((i, i))); + *b_i = (*b_i - dot) / diag; + } + } + } + + #[inline(always)] + fn xx_solve_upper_triangular_vector_unchecked_mut( + &self, + b: &mut Vector, + conjugate: impl Fn(N) -> N, + dot: impl Fn( + &DVectorSlice, + &DVectorSlice, + ) -> N, + ) where + S2: StorageMut, + ShapeConstraint: SameNumberOfRows, + { + for i in 0..self.nrows() { + let dot = dot(&self.slice_range(..i, i), &b.slice_range(..i, 0)); + + unsafe { + let b_i = b.vget_unchecked_mut(i); + let diag = conjugate(*self.get_unchecked((i, i))); + *b_i = (*b_i - dot) / diag; + } + } + } +}