Add lower triangular solve with sparse right-hand-side.

This commit is contained in:
sebcrozet 2018-10-23 18:18:05 +02:00
parent e4e5659405
commit 34b20dc291
3 changed files with 257 additions and 54 deletions

View File

@ -6,9 +6,9 @@ use std::ops::{Add, Mul, Range};
use std::slice; use std::slice;
use allocator::Allocator; use allocator::Allocator;
use constraint::{AreMultipliable, DimEq, ShapeConstraint, SameNumberOfRows}; use constraint::{AreMultipliable, DimEq, SameNumberOfRows, ShapeConstraint};
use storage::{Storage, StorageMut}; use storage::{Storage, StorageMut};
use {Real, DefaultAllocator, Dim, Matrix, MatrixMN, Scalar, Vector, VectorN, U1}; use {DefaultAllocator, Dim, Matrix, MatrixMN, Real, Scalar, Vector, VectorN, U1};
// FIXME: this structure exists for now only because impl trait // FIXME: this structure exists for now only because impl trait
// cannot be used for trait method return types. // cannot be used for trait method return types.
@ -20,11 +20,12 @@ pub trait CsStorageIter<'a, N, R, C = U1> {
pub trait CsStorage<N, R, C = U1>: for<'a> CsStorageIter<'a, N, R, C> { pub trait CsStorage<N, R, C = U1>: for<'a> CsStorageIter<'a, N, R, C> {
fn shape(&self) -> (R, C); fn shape(&self) -> (R, C);
fn nvalues(&self) -> usize;
unsafe fn row_index_unchecked(&self, i: usize) -> usize; unsafe fn row_index_unchecked(&self, i: usize) -> usize;
unsafe fn get_value_unchecked(&self, i: usize) -> &N; unsafe fn get_value_unchecked(&self, i: usize) -> &N;
fn get_value(&self, i: usize) -> &N; fn get_value(&self, i: usize) -> &N;
fn row_index(&self, i: usize) -> usize; fn row_index(&self, i: usize) -> usize;
fn column_range(&self, i: usize) -> Range<usize>;
fn len(&self) -> usize;
} }
pub trait CsStorageMut<N, R, C = U1>: CsStorage<N, R, C> { pub trait CsStorageMut<N, R, C = U1>: CsStorage<N, R, C> {
@ -49,21 +50,7 @@ where
vals: Vec<N>, vals: Vec<N>,
} }
impl<N: Scalar, R: Dim, C: Dim> CsVecStorage<N, R, C> impl<N: Scalar, R: Dim, C: Dim> CsVecStorage<N, R, C> where DefaultAllocator: Allocator<usize, C> {}
where
DefaultAllocator: Allocator<usize, C>,
{
#[inline]
fn column_range(&self, j: usize) -> Range<usize> {
let end = if j + 1 == self.p.len() {
self.nvalues()
} else {
self.p[j + 1]
};
self.p[j]..end
}
}
impl<'a, N: Scalar, R: Dim, C: Dim> CsStorageIter<'a, N, R, C> for CsVecStorage<N, R, C> impl<'a, N: Scalar, R: Dim, C: Dim> CsStorageIter<'a, N, R, C> for CsVecStorage<N, R, C>
where where
@ -92,7 +79,7 @@ where
} }
#[inline] #[inline]
fn nvalues(&self) -> usize { fn len(&self) -> usize {
self.vals.len() self.vals.len()
} }
@ -115,6 +102,17 @@ where
fn get_value(&self, i: usize) -> &N { fn get_value(&self, i: usize) -> &N {
&self.vals[i] &self.vals[i]
} }
#[inline]
fn column_range(&self, j: usize) -> Range<usize> {
let end = if j + 1 == self.p.len() {
self.len()
} else {
self.p[j + 1]
};
self.p[j]..end
}
} }
/* /*
@ -154,7 +152,7 @@ where
CsMatrix { CsMatrix {
data: CsVecStorage { data: CsVecStorage {
shape: (nrows, ncols), shape: (nrows, ncols),
p: unsafe { VectorN::new_uninitialized_generic(ncols, U1) }, p: VectorN::zeros_generic(ncols, U1),
i, i,
vals, vals,
}, },
@ -180,8 +178,8 @@ where
} }
impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> { impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
pub fn nvalues(&self) -> usize { pub fn len(&self) -> usize {
self.data.nvalues() self.data.len()
} }
pub fn transpose(&self) -> CsMatrix<N, C, R> pub fn transpose(&self) -> CsMatrix<N, C, R>
@ -190,7 +188,7 @@ impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
{ {
let (nrows, ncols) = self.data.shape(); let (nrows, ncols) = self.data.shape();
let nvals = self.nvalues(); let nvals = self.len();
let mut res = CsMatrix::new_uninitialized_generic(ncols, nrows, nvals); let mut res = CsMatrix::new_uninitialized_generic(ncols, nrows, nvals);
let mut workspace = Vector::zeros_generic(nrows, U1); let mut workspace = Vector::zeros_generic(nrows, U1);
@ -324,7 +322,6 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
true true
} }
pub fn tr_solve_lower_triangular_mut<R2: Dim, C2: Dim, S2>( pub fn tr_solve_lower_triangular_mut<R2: Dim, C2: Dim, S2>(
&self, &self,
b: &mut Matrix<N, R2, C2, S2>, b: &mut Matrix<N, R2, C2, S2>,
@ -369,6 +366,106 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
true true
} }
pub fn solve_lower_triangular_cs<D2: Dim, S2>(
&self,
b: &CsVector<N, D2, S2>,
) -> Option<CsVector<N, D2>>
where
S2: CsStorage<N, D2>,
DefaultAllocator: Allocator<bool, D> + Allocator<N, D2> + Allocator<usize, D2>,
ShapeConstraint: SameNumberOfRows<D, D2>,
{
let mut reach = Vec::new();
self.lower_triangular_reach(b, &mut reach);
let mut workspace = unsafe { VectorN::new_uninitialized_generic(b.data.shape().0, U1) };
for i in reach.iter().cloned() {
workspace[i] = N::zero();
}
for (i, val) in b.data.column_entries(0) {
workspace[i] = val;
}
for j in reach.iter().cloned().rev() {
let mut column = self.data.column_entries(j);
let mut diag_found = false;
while let Some((i, val)) = column.next() {
if i == j {
if val.is_zero() {
break;
}
workspace[j] /= val;
diag_found = true;
break;
}
}
if !diag_found {
return None;
}
for (i, val) in column {
workspace[i] -= workspace[j] * val;
}
}
// Copy the result into a sparse vector.
let mut result = CsVector::new_uninitialized_generic(b.data.shape().0, U1, reach.len());
for (i, val) in reach.iter().zip(result.data.vals.iter_mut()) {
*val = workspace[*i];
}
result.data.i = reach;
Some(result)
}
fn lower_triangular_reach<D2: Dim, S2>(&self, b: &CsVector<N, D2, S2>, xi: &mut Vec<usize>)
where
S2: CsStorage<N, D2>,
DefaultAllocator: Allocator<bool, D>,
{
let mut visited = VectorN::repeat_generic(self.data.shape().1, U1, false);
let mut stack = Vec::new();
for i in b.data.column_range(0) {
let row_index = b.data.row_index(i);
if !visited[row_index] {
let rng = self.data.column_range(row_index);
stack.push((row_index, rng));
self.lower_triangular_dfs(visited.as_mut_slice(), &mut stack, xi);
}
}
}
fn lower_triangular_dfs(
&self,
visited: &mut [bool],
stack: &mut Vec<(usize, Range<usize>)>,
xi: &mut Vec<usize>,
) {
'recursion: while let Some((j, rng)) = stack.pop() {
visited[j] = true;
for i in rng.clone() {
let row_id = self.data.row_index(i);
if row_id > j && !visited[row_id] {
stack.push((j, (i + 1)..rng.end));
let row_id = self.data.row_index(i);
stack.push((row_id, self.data.column_range(row_id)));
continue 'recursion;
}
}
xi.push(j)
}
}
} }
/* /*
@ -381,8 +478,8 @@ impl<N: Scalar, R, S> CsVector<N, R, S> {
self.data.set_column_len(0, nnzero); self.data.set_column_len(0, nnzero);
// Fill with the axpy. // Fill with the axpy.
let mut i = self.nvalues(); let mut i = self.len();
let mut j = x.nvalues(); let mut j = x.len();
let mut k = nnzero - 1; let mut k = nnzero - 1;
let mut rid1 = self.data.row_index(0, i - 1); let mut rid1 = self.data.row_index(0, i - 1);
let mut rid2 = x.data.row_index(0, j - 1); let mut rid2 = x.data.row_index(0, j - 1);
@ -416,7 +513,7 @@ impl<N: Scalar + Zero + ClosedAdd + ClosedMul, D: Dim, S: StorageMut<N, D>> Vect
ShapeConstraint: DimEq<D, D2>, ShapeConstraint: DimEq<D, D2>,
{ {
if beta.is_zero() { if beta.is_zero() {
for i in 0..x.nvalues() { for i in 0..x.len() {
unsafe { unsafe {
let k = x.data.row_index_unchecked(i); let k = x.data.row_index_unchecked(i);
let y = self.vget_unchecked_mut(k); let y = self.vget_unchecked_mut(k);
@ -427,7 +524,7 @@ impl<N: Scalar + Zero + ClosedAdd + ClosedMul, D: Dim, S: StorageMut<N, D>> Vect
// Needed to be sure even components not present on `x` are multiplied. // Needed to be sure even components not present on `x` are multiplied.
*self *= beta; *self *= beta;
for i in 0..x.nvalues() { for i in 0..x.len() {
unsafe { unsafe {
let k = x.data.row_index_unchecked(i); let k = x.data.row_index_unchecked(i);
let y = self.vget_unchecked_mut(k); let y = self.vget_unchecked_mut(k);
@ -479,8 +576,7 @@ where
"Mismatched dimensions for matrix multiplication." "Mismatched dimensions for matrix multiplication."
); );
let mut res = let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len());
CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.nvalues() + rhs.nvalues());
let mut timestamps = VectorN::zeros_generic(nrows1, U1); let mut timestamps = VectorN::zeros_generic(nrows1, U1);
let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows1, U1) }; let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows1, U1) };
let mut nz = 0; let mut nz = 0;
@ -540,8 +636,7 @@ where
"Mismatched dimensions for matrix sum." "Mismatched dimensions for matrix sum."
); );
let mut res = let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len());
CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.nvalues() + rhs.nvalues());
let mut timestamps = VectorN::zeros_generic(nrows1, U1); let mut timestamps = VectorN::zeros_generic(nrows1, U1);
let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows1, U1) }; let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows1, U1) };
let mut nz = 0; let mut nz = 0;
@ -582,9 +677,10 @@ where
} }
} }
use std::fmt::Debug;
impl<'a, N: Scalar + Zero, R: Dim, C: Dim, S> From<CsMatrix<N, R, C, S>> for MatrixMN<N, R, C> impl<'a, N: Scalar + Zero, R: Dim, C: Dim, S> From<CsMatrix<N, R, C, S>> for MatrixMN<N, R, C>
where where
S: CsStorage<N, R, C>, S: CsStorage<N, R, C> + Debug,
DefaultAllocator: Allocator<N, R, C>, DefaultAllocator: Allocator<N, R, C>,
{ {
fn from(m: CsMatrix<N, R, C, S>) -> Self { fn from(m: CsMatrix<N, R, C, S>) -> Self {
@ -608,8 +704,8 @@ where
{ {
fn from(m: Matrix<N, R, C, S>) -> Self { fn from(m: Matrix<N, R, C, S>) -> Self {
let (nrows, ncols) = m.data.shape(); let (nrows, ncols) = m.data.shape();
let nvalues = m.iter().filter(|e| !e.is_zero()).count(); let len = m.iter().filter(|e| !e.is_zero()).count();
let mut res = CsMatrix::new_uninitialized_generic(nrows, ncols, nvalues); let mut res = CsMatrix::new_uninitialized_generic(nrows, ncols, len);
let mut nz = 0; let mut nz = 0;
for j in 0..ncols.value() { for j in 0..ncols.value() {

106
tests/sparse/cs_linalg.rs Normal file
View File

@ -0,0 +1,106 @@
#![cfg_attr(rustfmt, rustfmt_skip)]
use na::{CsMatrix, CsVector, Matrix5, Vector5};
#[test]
fn cs_lower_triangular_solve() {
let a = Matrix5::new(
4.0, 1.0, 4.0, 0.0, 9.0,
5.0, 6.0, 0.0, 8.0, 10.0,
9.0, 10.0, 11.0, 12.0, 0.0,
0.0, -8.0, 3.0, 5.0, 9.0,
0.0, 0.0, 1.0, 0.0, -10.0
);
let b = Vector5::new(1.0, 2.0, 3.0, 4.0, 5.0);
let cs_a: CsMatrix<_, _, _> = a.into();
assert_eq!(cs_a.solve_lower_triangular(&b), a.solve_lower_triangular(&b));
}
#[test]
fn cs_tr_lower_triangular_solve() {
let a = Matrix5::new(
4.0, 1.0, 4.0, 0.0, 9.0,
5.0, 6.0, 0.0, 8.0, 10.0,
9.0, 10.0, 11.0, 12.0, 0.0,
0.0, -8.0, 3.0, 5.0, 9.0,
0.0, 0.0, 1.0, 0.0, -10.0
);
let b = Vector5::new(1.0, 2.0, 3.0, 4.0, 5.0);
let cs_a: CsMatrix<_, _, _> = a.into();
assert!(cs_a.tr_solve_lower_triangular(&b).is_some());
assert_eq!(cs_a.tr_solve_lower_triangular(&b), a.tr_solve_lower_triangular(&b));
// Singular case.
let a = Matrix5::new(
4.0, 1.0, 4.0, 0.0, 9.0,
5.0, 6.0, 0.0, 8.0, 10.0,
9.0, 10.0, 0.0, 12.0, 0.0,
0.0, -8.0, 3.0, 5.0, 9.0,
0.0, 0.0, 1.0, 0.0, -10.0
);
let cs_a: CsMatrix<_, _, _> = a.into();
assert!(cs_a.tr_solve_lower_triangular(&b).is_none());
}
#[test]
fn cs_lower_triangular_solve_cs() {
let a = Matrix5::new(
4.0, 1.0, 4.0, 0.0, 9.0,
5.0, 6.0, 0.0, 8.0, 10.0,
9.0, 10.0, 11.0, 12.0, 0.0,
0.0, -8.0, 3.0, 5.0, 9.0,
0.0, 0.0, 1.0, 0.0, -10.0
);
let b1 = Vector5::zeros();
let b2 = Vector5::new(1.0, 2.0, 3.0, 4.0, 5.0);
let b3 = Vector5::new(1.0, 0.0, 0.0, 4.0, 0.0);
let b4 = Vector5::new(0.0, 1.0, 0.0, 4.0, 5.0);
let b5 = Vector5::x();
let b6 = Vector5::y();
let b7 = Vector5::z();
let b8 = Vector5::w();
let b9 = Vector5::a();
let cs_a: CsMatrix<_, _, _> = a.into();
let cs_b1: CsVector<_, _> = Vector5::zeros().into();
let cs_b2: CsVector<_, _> = Vector5::new(1.0, 2.0, 3.0, 4.0, 5.0).into();
let cs_b3: CsVector<_, _> = Vector5::new(1.0, 0.0, 0.0, 4.0, 0.0).into();
let cs_b4: CsVector<_, _> = Vector5::new(0.0, 1.0, 0.0, 4.0, 5.0).into();
let cs_b5: CsVector<_, _> = Vector5::x().into();
let cs_b6: CsVector<_, _> = Vector5::y().into();
let cs_b7: CsVector<_, _> = Vector5::z().into();
let cs_b8: CsVector<_, _> = Vector5::w().into();
let cs_b9: CsVector<_, _> = Vector5::a().into();
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b1).map(|v| v.into()), a.solve_lower_triangular(&b1));
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b5).map(|v| v.into()), a.solve_lower_triangular(&b5));
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b6).map(|v| v.into()), a.solve_lower_triangular(&b6));
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b7).map(|v| v.into()), a.solve_lower_triangular(&b7));
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b8).map(|v| v.into()), a.solve_lower_triangular(&b8));
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b9).map(|v| v.into()), a.solve_lower_triangular(&b9));
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b2).map(|v| v.into()), a.solve_lower_triangular(&b2));
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b3).map(|v| v.into()), a.solve_lower_triangular(&b3));
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b4).map(|v| v.into()), a.solve_lower_triangular(&b4));
// Singular case.
let a = Matrix5::new(
4.0, 1.0, 4.0, 0.0, 9.0,
5.0, 0.0, 0.0, 8.0, 10.0,
9.0, 10.0, 0.0, 12.0, 0.0,
0.0, -8.0, 3.0, 5.0, 9.0,
0.0, 0.0, 1.0, 0.0, -10.0
);
let cs_a: CsMatrix<_, _, _> = a.into();
assert!(cs_a.solve_lower_triangular_cs(&cs_b2).is_none());
assert!(cs_a.solve_lower_triangular_cs(&cs_b3).is_none());
assert!(cs_a.solve_lower_triangular_cs(&cs_b4).is_none());
}

View File

@ -2,3 +2,4 @@ mod cs_construction;
mod cs_conversion; mod cs_conversion;
mod cs_matrix; mod cs_matrix;
mod cs_ops; mod cs_ops;
mod cs_linalg;