diff --git a/src/sparse/cs_matrix.rs b/src/sparse/cs_matrix.rs index 43e77863..4809fb54 100644 --- a/src/sparse/cs_matrix.rs +++ b/src/sparse/cs_matrix.rs @@ -6,9 +6,9 @@ use std::ops::{Add, Mul, Range}; use std::slice; use allocator::Allocator; -use constraint::{AreMultipliable, DimEq, ShapeConstraint, SameNumberOfRows}; +use constraint::{AreMultipliable, DimEq, SameNumberOfRows, ShapeConstraint}; use storage::{Storage, StorageMut}; -use {Real, DefaultAllocator, Dim, Matrix, MatrixMN, Scalar, Vector, VectorN, U1}; +use {DefaultAllocator, Dim, Matrix, MatrixMN, Real, Scalar, Vector, VectorN, U1}; // FIXME: this structure exists for now only because impl trait // cannot be used for trait method return types. @@ -20,11 +20,12 @@ pub trait CsStorageIter<'a, N, R, C = U1> { pub trait CsStorage: for<'a> CsStorageIter<'a, N, R, C> { fn shape(&self) -> (R, C); - fn nvalues(&self) -> usize; unsafe fn row_index_unchecked(&self, i: usize) -> usize; unsafe fn get_value_unchecked(&self, i: usize) -> &N; fn get_value(&self, i: usize) -> &N; fn row_index(&self, i: usize) -> usize; + fn column_range(&self, i: usize) -> Range; + fn len(&self) -> usize; } pub trait CsStorageMut: CsStorage { @@ -49,21 +50,7 @@ where vals: Vec, } -impl CsVecStorage -where - DefaultAllocator: Allocator, -{ - #[inline] - fn column_range(&self, j: usize) -> Range { - let end = if j + 1 == self.p.len() { - self.nvalues() - } else { - self.p[j + 1] - }; - - self.p[j]..end - } -} +impl CsVecStorage where DefaultAllocator: Allocator {} impl<'a, N: Scalar, R: Dim, C: Dim> CsStorageIter<'a, N, R, C> for CsVecStorage where @@ -92,7 +79,7 @@ where } #[inline] - fn nvalues(&self) -> usize { + fn len(&self) -> usize { self.vals.len() } @@ -115,6 +102,17 @@ where fn get_value(&self, i: usize) -> &N { &self.vals[i] } + + #[inline] + fn column_range(&self, j: usize) -> Range { + let end = if j + 1 == self.p.len() { + self.len() + } else { + self.p[j + 1] + }; + + self.p[j]..end + } } /* @@ -154,7 +152,7 @@ where CsMatrix { data: CsVecStorage { shape: (nrows, ncols), - p: unsafe { VectorN::new_uninitialized_generic(ncols, U1) }, + p: VectorN::zeros_generic(ncols, U1), i, vals, }, @@ -180,8 +178,8 @@ where } impl> CsMatrix { - pub fn nvalues(&self) -> usize { - self.data.nvalues() + pub fn len(&self) -> usize { + self.data.len() } pub fn transpose(&self) -> CsMatrix @@ -190,7 +188,7 @@ impl> CsMatrix { { let (nrows, ncols) = self.data.shape(); - let nvals = self.nvalues(); + let nvals = self.len(); let mut res = CsMatrix::new_uninitialized_generic(ncols, nrows, nvals); let mut workspace = Vector::zeros_generic(nrows, U1); @@ -250,27 +248,27 @@ impl> CsMatrix { &self, b: &Matrix, ) -> Option> - where - S2: Storage, - DefaultAllocator: Allocator, - ShapeConstraint: SameNumberOfRows, + where + S2: Storage, + DefaultAllocator: Allocator, + ShapeConstraint: SameNumberOfRows, { let mut b = b.clone_owned(); - if self.solve_lower_triangular_mut(&mut b) { - Some(b) - } else { - None - } + if self.solve_lower_triangular_mut(&mut b) { + Some(b) + } else { + None + } } pub fn tr_solve_lower_triangular( &self, b: &Matrix, ) -> Option> - where - S2: Storage, - DefaultAllocator: Allocator, - ShapeConstraint: SameNumberOfRows, + where + S2: Storage, + DefaultAllocator: Allocator, + ShapeConstraint: SameNumberOfRows, { let mut b = b.clone_owned(); if self.tr_solve_lower_triangular_mut(&mut b) { @@ -284,9 +282,9 @@ impl> CsMatrix { &self, b: &mut Matrix, ) -> bool - where - S2: StorageMut, - ShapeConstraint: SameNumberOfRows, + where + S2: StorageMut, + ShapeConstraint: SameNumberOfRows, { let (nrows, ncols) = self.data.shape(); assert_eq!(nrows.value(), ncols.value(), "The matrix must be square."); @@ -324,14 +322,13 @@ impl> CsMatrix { true } - pub fn tr_solve_lower_triangular_mut( &self, b: &mut Matrix, ) -> bool - where - S2: StorageMut, - ShapeConstraint: SameNumberOfRows, + where + S2: StorageMut, + ShapeConstraint: SameNumberOfRows, { let (nrows, ncols) = self.data.shape(); assert_eq!(nrows.value(), ncols.value(), "The matrix must be square."); @@ -369,6 +366,106 @@ impl> CsMatrix { true } + + pub fn solve_lower_triangular_cs( + &self, + b: &CsVector, + ) -> Option> + where + S2: CsStorage, + DefaultAllocator: Allocator + Allocator + Allocator, + ShapeConstraint: SameNumberOfRows, + { + let mut reach = Vec::new(); + self.lower_triangular_reach(b, &mut reach); + let mut workspace = unsafe { VectorN::new_uninitialized_generic(b.data.shape().0, U1) }; + + for i in reach.iter().cloned() { + workspace[i] = N::zero(); + } + + for (i, val) in b.data.column_entries(0) { + workspace[i] = val; + } + + for j in reach.iter().cloned().rev() { + let mut column = self.data.column_entries(j); + let mut diag_found = false; + + while let Some((i, val)) = column.next() { + if i == j { + if val.is_zero() { + break; + } + + workspace[j] /= val; + diag_found = true; + break; + } + } + + if !diag_found { + return None; + } + + for (i, val) in column { + workspace[i] -= workspace[j] * val; + } + } + + // Copy the result into a sparse vector. + let mut result = CsVector::new_uninitialized_generic(b.data.shape().0, U1, reach.len()); + + for (i, val) in reach.iter().zip(result.data.vals.iter_mut()) { + *val = workspace[*i]; + } + + result.data.i = reach; + Some(result) + } + + fn lower_triangular_reach(&self, b: &CsVector, xi: &mut Vec) + where + S2: CsStorage, + DefaultAllocator: Allocator, + { + let mut visited = VectorN::repeat_generic(self.data.shape().1, U1, false); + let mut stack = Vec::new(); + + for i in b.data.column_range(0) { + let row_index = b.data.row_index(i); + + if !visited[row_index] { + let rng = self.data.column_range(row_index); + stack.push((row_index, rng)); + self.lower_triangular_dfs(visited.as_mut_slice(), &mut stack, xi); + } + } + } + + fn lower_triangular_dfs( + &self, + visited: &mut [bool], + stack: &mut Vec<(usize, Range)>, + xi: &mut Vec, + ) { + 'recursion: while let Some((j, rng)) = stack.pop() { + visited[j] = true; + + for i in rng.clone() { + let row_id = self.data.row_index(i); + if row_id > j && !visited[row_id] { + stack.push((j, (i + 1)..rng.end)); + + let row_id = self.data.row_index(i); + stack.push((row_id, self.data.column_range(row_id))); + continue 'recursion; + } + } + + xi.push(j) + } + } } /* @@ -381,8 +478,8 @@ impl CsVector { self.data.set_column_len(0, nnzero); // Fill with the axpy. - let mut i = self.nvalues(); - let mut j = x.nvalues(); + let mut i = self.len(); + let mut j = x.len(); let mut k = nnzero - 1; let mut rid1 = self.data.row_index(0, i - 1); let mut rid2 = x.data.row_index(0, j - 1); @@ -416,7 +513,7 @@ impl> Vect ShapeConstraint: DimEq, { if beta.is_zero() { - for i in 0..x.nvalues() { + for i in 0..x.len() { unsafe { let k = x.data.row_index_unchecked(i); let y = self.vget_unchecked_mut(k); @@ -427,7 +524,7 @@ impl> Vect // Needed to be sure even components not present on `x` are multiplied. *self *= beta; - for i in 0..x.nvalues() { + for i in 0..x.len() { unsafe { let k = x.data.row_index_unchecked(i); let y = self.vget_unchecked_mut(k); @@ -479,8 +576,7 @@ where "Mismatched dimensions for matrix multiplication." ); - let mut res = - CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.nvalues() + rhs.nvalues()); + let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len()); let mut timestamps = VectorN::zeros_generic(nrows1, U1); let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows1, U1) }; let mut nz = 0; @@ -540,8 +636,7 @@ where "Mismatched dimensions for matrix sum." ); - let mut res = - CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.nvalues() + rhs.nvalues()); + let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len()); let mut timestamps = VectorN::zeros_generic(nrows1, U1); let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows1, U1) }; let mut nz = 0; @@ -582,9 +677,10 @@ where } } +use std::fmt::Debug; impl<'a, N: Scalar + Zero, R: Dim, C: Dim, S> From> for MatrixMN where - S: CsStorage, + S: CsStorage + Debug, DefaultAllocator: Allocator, { fn from(m: CsMatrix) -> Self { @@ -608,8 +704,8 @@ where { fn from(m: Matrix) -> Self { let (nrows, ncols) = m.data.shape(); - let nvalues = m.iter().filter(|e| !e.is_zero()).count(); - let mut res = CsMatrix::new_uninitialized_generic(nrows, ncols, nvalues); + let len = m.iter().filter(|e| !e.is_zero()).count(); + let mut res = CsMatrix::new_uninitialized_generic(nrows, ncols, len); let mut nz = 0; for j in 0..ncols.value() { diff --git a/tests/sparse/cs_linalg.rs b/tests/sparse/cs_linalg.rs new file mode 100644 index 00000000..d65b633c --- /dev/null +++ b/tests/sparse/cs_linalg.rs @@ -0,0 +1,106 @@ +#![cfg_attr(rustfmt, rustfmt_skip)] + +use na::{CsMatrix, CsVector, Matrix5, Vector5}; + + +#[test] +fn cs_lower_triangular_solve() { + let a = Matrix5::new( + 4.0, 1.0, 4.0, 0.0, 9.0, + 5.0, 6.0, 0.0, 8.0, 10.0, + 9.0, 10.0, 11.0, 12.0, 0.0, + 0.0, -8.0, 3.0, 5.0, 9.0, + 0.0, 0.0, 1.0, 0.0, -10.0 + ); + let b = Vector5::new(1.0, 2.0, 3.0, 4.0, 5.0); + + let cs_a: CsMatrix<_, _, _> = a.into(); + + assert_eq!(cs_a.solve_lower_triangular(&b), a.solve_lower_triangular(&b)); +} + +#[test] +fn cs_tr_lower_triangular_solve() { + let a = Matrix5::new( + 4.0, 1.0, 4.0, 0.0, 9.0, + 5.0, 6.0, 0.0, 8.0, 10.0, + 9.0, 10.0, 11.0, 12.0, 0.0, + 0.0, -8.0, 3.0, 5.0, 9.0, + 0.0, 0.0, 1.0, 0.0, -10.0 + ); + let b = Vector5::new(1.0, 2.0, 3.0, 4.0, 5.0); + + let cs_a: CsMatrix<_, _, _> = a.into(); + + assert!(cs_a.tr_solve_lower_triangular(&b).is_some()); + assert_eq!(cs_a.tr_solve_lower_triangular(&b), a.tr_solve_lower_triangular(&b)); + + // Singular case. + let a = Matrix5::new( + 4.0, 1.0, 4.0, 0.0, 9.0, + 5.0, 6.0, 0.0, 8.0, 10.0, + 9.0, 10.0, 0.0, 12.0, 0.0, + 0.0, -8.0, 3.0, 5.0, 9.0, + 0.0, 0.0, 1.0, 0.0, -10.0 + ); + let cs_a: CsMatrix<_, _, _> = a.into(); + + assert!(cs_a.tr_solve_lower_triangular(&b).is_none()); +} + + +#[test] +fn cs_lower_triangular_solve_cs() { + let a = Matrix5::new( + 4.0, 1.0, 4.0, 0.0, 9.0, + 5.0, 6.0, 0.0, 8.0, 10.0, + 9.0, 10.0, 11.0, 12.0, 0.0, + 0.0, -8.0, 3.0, 5.0, 9.0, + 0.0, 0.0, 1.0, 0.0, -10.0 + ); + let b1 = Vector5::zeros(); + let b2 = Vector5::new(1.0, 2.0, 3.0, 4.0, 5.0); + let b3 = Vector5::new(1.0, 0.0, 0.0, 4.0, 0.0); + let b4 = Vector5::new(0.0, 1.0, 0.0, 4.0, 5.0); + let b5 = Vector5::x(); + let b6 = Vector5::y(); + let b7 = Vector5::z(); + let b8 = Vector5::w(); + let b9 = Vector5::a(); + + let cs_a: CsMatrix<_, _, _> = a.into(); + let cs_b1: CsVector<_, _> = Vector5::zeros().into(); + let cs_b2: CsVector<_, _> = Vector5::new(1.0, 2.0, 3.0, 4.0, 5.0).into(); + let cs_b3: CsVector<_, _> = Vector5::new(1.0, 0.0, 0.0, 4.0, 0.0).into(); + let cs_b4: CsVector<_, _> = Vector5::new(0.0, 1.0, 0.0, 4.0, 5.0).into(); + let cs_b5: CsVector<_, _> = Vector5::x().into(); + let cs_b6: CsVector<_, _> = Vector5::y().into(); + let cs_b7: CsVector<_, _> = Vector5::z().into(); + let cs_b8: CsVector<_, _> = Vector5::w().into(); + let cs_b9: CsVector<_, _> = Vector5::a().into(); + + assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b1).map(|v| v.into()), a.solve_lower_triangular(&b1)); + assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b5).map(|v| v.into()), a.solve_lower_triangular(&b5)); + assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b6).map(|v| v.into()), a.solve_lower_triangular(&b6)); + assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b7).map(|v| v.into()), a.solve_lower_triangular(&b7)); + assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b8).map(|v| v.into()), a.solve_lower_triangular(&b8)); + assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b9).map(|v| v.into()), a.solve_lower_triangular(&b9)); + assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b2).map(|v| v.into()), a.solve_lower_triangular(&b2)); + assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b3).map(|v| v.into()), a.solve_lower_triangular(&b3)); + assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b4).map(|v| v.into()), a.solve_lower_triangular(&b4)); + + + // Singular case. + let a = Matrix5::new( + 4.0, 1.0, 4.0, 0.0, 9.0, + 5.0, 0.0, 0.0, 8.0, 10.0, + 9.0, 10.0, 0.0, 12.0, 0.0, + 0.0, -8.0, 3.0, 5.0, 9.0, + 0.0, 0.0, 1.0, 0.0, -10.0 + ); + let cs_a: CsMatrix<_, _, _> = a.into(); + + assert!(cs_a.solve_lower_triangular_cs(&cs_b2).is_none()); + assert!(cs_a.solve_lower_triangular_cs(&cs_b3).is_none()); + assert!(cs_a.solve_lower_triangular_cs(&cs_b4).is_none()); +} diff --git a/tests/sparse/mod.rs b/tests/sparse/mod.rs index 79cdaa0e..6c4d6d45 100644 --- a/tests/sparse/mod.rs +++ b/tests/sparse/mod.rs @@ -2,3 +2,4 @@ mod cs_construction; mod cs_conversion; mod cs_matrix; mod cs_ops; +mod cs_linalg; \ No newline at end of file