diff --git a/src/sparse/cs_matrix.rs b/src/sparse/cs_matrix.rs index 2666425a..43e77863 100644 --- a/src/sparse/cs_matrix.rs +++ b/src/sparse/cs_matrix.rs @@ -6,9 +6,9 @@ use std::ops::{Add, Mul, Range}; use std::slice; use allocator::Allocator; -use constraint::{AreMultipliable, DimEq, ShapeConstraint}; +use constraint::{AreMultipliable, DimEq, ShapeConstraint, SameNumberOfRows}; use storage::{Storage, StorageMut}; -use {DefaultAllocator, Dim, Matrix, MatrixMN, Scalar, Vector, VectorN, U1}; +use {Real, DefaultAllocator, Dim, Matrix, MatrixMN, Scalar, Vector, VectorN, U1}; // FIXME: this structure exists for now only because impl trait // cannot be used for trait method return types. @@ -245,6 +245,132 @@ impl> CsMatrix { } } +impl> CsMatrix { + pub fn solve_lower_triangular( + &self, + b: &Matrix, + ) -> Option> + where + S2: Storage, + DefaultAllocator: Allocator, + ShapeConstraint: SameNumberOfRows, + { + let mut b = b.clone_owned(); + if self.solve_lower_triangular_mut(&mut b) { + Some(b) + } else { + None + } + } + + pub fn tr_solve_lower_triangular( + &self, + b: &Matrix, + ) -> Option> + where + S2: Storage, + DefaultAllocator: Allocator, + ShapeConstraint: SameNumberOfRows, + { + let mut b = b.clone_owned(); + if self.tr_solve_lower_triangular_mut(&mut b) { + Some(b) + } else { + None + } + } + + pub fn solve_lower_triangular_mut( + &self, + b: &mut Matrix, + ) -> bool + where + S2: StorageMut, + ShapeConstraint: SameNumberOfRows, + { + let (nrows, ncols) = self.data.shape(); + assert_eq!(nrows.value(), ncols.value(), "The matrix must be square."); + assert_eq!(nrows.value(), b.len(), "Mismatched matrix dimensions."); + + for j2 in 0..b.ncols() { + let mut b = b.column_mut(j2); + + for j in 0..ncols.value() { + let mut column = self.data.column_entries(j); + let mut diag_found = false; + + while let Some((i, val)) = column.next() { + if i == j { + if val.is_zero() { + return false; + } + + b[j] /= val; + diag_found = true; + break; + } + } + + if !diag_found { + return false; + } + + for (i, val) in column { + b[i] -= b[j] * val; + } + } + } + + true + } + + + pub fn tr_solve_lower_triangular_mut( + &self, + b: &mut Matrix, + ) -> bool + where + S2: StorageMut, + ShapeConstraint: SameNumberOfRows, + { + let (nrows, ncols) = self.data.shape(); + assert_eq!(nrows.value(), ncols.value(), "The matrix must be square."); + assert_eq!(nrows.value(), b.len(), "Mismatched matrix dimensions."); + + for j2 in 0..b.ncols() { + let mut b = b.column_mut(j2); + + for j in (0..ncols.value()).rev() { + let mut column = self.data.column_entries(j); + let mut diag = None; + + while let Some((i, val)) = column.next() { + if i == j { + if val.is_zero() { + return false; + } + + diag = Some(val); + break; + } + } + + if let Some(diag) = diag { + for (i, val) in column { + b[j] -= val * b[i]; + } + + b[j] /= diag; + } else { + return false; + } + } + } + + true + } +} + /* impl CsVector { pub fn axpy(&mut self, alpha: N, x: CsVector, beta: N) {