Add lower triangular solve with sparse right-hand-side.
This commit is contained in:
parent
e4e5659405
commit
34b20dc291
|
@ -6,9 +6,9 @@ use std::ops::{Add, Mul, Range};
|
|||
use std::slice;
|
||||
|
||||
use allocator::Allocator;
|
||||
use constraint::{AreMultipliable, DimEq, ShapeConstraint, SameNumberOfRows};
|
||||
use constraint::{AreMultipliable, DimEq, SameNumberOfRows, ShapeConstraint};
|
||||
use storage::{Storage, StorageMut};
|
||||
use {Real, DefaultAllocator, Dim, Matrix, MatrixMN, Scalar, Vector, VectorN, U1};
|
||||
use {DefaultAllocator, Dim, Matrix, MatrixMN, Real, Scalar, Vector, VectorN, U1};
|
||||
|
||||
// FIXME: this structure exists for now only because impl trait
|
||||
// cannot be used for trait method return types.
|
||||
|
@ -20,11 +20,12 @@ pub trait CsStorageIter<'a, N, R, C = U1> {
|
|||
|
||||
pub trait CsStorage<N, R, C = U1>: for<'a> CsStorageIter<'a, N, R, C> {
|
||||
fn shape(&self) -> (R, C);
|
||||
fn nvalues(&self) -> usize;
|
||||
unsafe fn row_index_unchecked(&self, i: usize) -> usize;
|
||||
unsafe fn get_value_unchecked(&self, i: usize) -> &N;
|
||||
fn get_value(&self, i: usize) -> &N;
|
||||
fn row_index(&self, i: usize) -> usize;
|
||||
fn column_range(&self, i: usize) -> Range<usize>;
|
||||
fn len(&self) -> usize;
|
||||
}
|
||||
|
||||
pub trait CsStorageMut<N, R, C = U1>: CsStorage<N, R, C> {
|
||||
|
@ -49,21 +50,7 @@ where
|
|||
vals: Vec<N>,
|
||||
}
|
||||
|
||||
impl<N: Scalar, R: Dim, C: Dim> CsVecStorage<N, R, C>
|
||||
where
|
||||
DefaultAllocator: Allocator<usize, C>,
|
||||
{
|
||||
#[inline]
|
||||
fn column_range(&self, j: usize) -> Range<usize> {
|
||||
let end = if j + 1 == self.p.len() {
|
||||
self.nvalues()
|
||||
} else {
|
||||
self.p[j + 1]
|
||||
};
|
||||
|
||||
self.p[j]..end
|
||||
}
|
||||
}
|
||||
impl<N: Scalar, R: Dim, C: Dim> CsVecStorage<N, R, C> where DefaultAllocator: Allocator<usize, C> {}
|
||||
|
||||
impl<'a, N: Scalar, R: Dim, C: Dim> CsStorageIter<'a, N, R, C> for CsVecStorage<N, R, C>
|
||||
where
|
||||
|
@ -92,7 +79,7 @@ where
|
|||
}
|
||||
|
||||
#[inline]
|
||||
fn nvalues(&self) -> usize {
|
||||
fn len(&self) -> usize {
|
||||
self.vals.len()
|
||||
}
|
||||
|
||||
|
@ -115,6 +102,17 @@ where
|
|||
fn get_value(&self, i: usize) -> &N {
|
||||
&self.vals[i]
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn column_range(&self, j: usize) -> Range<usize> {
|
||||
let end = if j + 1 == self.p.len() {
|
||||
self.len()
|
||||
} else {
|
||||
self.p[j + 1]
|
||||
};
|
||||
|
||||
self.p[j]..end
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -154,7 +152,7 @@ where
|
|||
CsMatrix {
|
||||
data: CsVecStorage {
|
||||
shape: (nrows, ncols),
|
||||
p: unsafe { VectorN::new_uninitialized_generic(ncols, U1) },
|
||||
p: VectorN::zeros_generic(ncols, U1),
|
||||
i,
|
||||
vals,
|
||||
},
|
||||
|
@ -180,8 +178,8 @@ where
|
|||
}
|
||||
|
||||
impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
|
||||
pub fn nvalues(&self) -> usize {
|
||||
self.data.nvalues()
|
||||
pub fn len(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
|
||||
pub fn transpose(&self) -> CsMatrix<N, C, R>
|
||||
|
@ -190,7 +188,7 @@ impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
|
|||
{
|
||||
let (nrows, ncols) = self.data.shape();
|
||||
|
||||
let nvals = self.nvalues();
|
||||
let nvals = self.len();
|
||||
let mut res = CsMatrix::new_uninitialized_generic(ncols, nrows, nvals);
|
||||
let mut workspace = Vector::zeros_generic(nrows, U1);
|
||||
|
||||
|
@ -324,7 +322,6 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
|
|||
true
|
||||
}
|
||||
|
||||
|
||||
pub fn tr_solve_lower_triangular_mut<R2: Dim, C2: Dim, S2>(
|
||||
&self,
|
||||
b: &mut Matrix<N, R2, C2, S2>,
|
||||
|
@ -369,6 +366,106 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
|
|||
|
||||
true
|
||||
}
|
||||
|
||||
pub fn solve_lower_triangular_cs<D2: Dim, S2>(
|
||||
&self,
|
||||
b: &CsVector<N, D2, S2>,
|
||||
) -> Option<CsVector<N, D2>>
|
||||
where
|
||||
S2: CsStorage<N, D2>,
|
||||
DefaultAllocator: Allocator<bool, D> + Allocator<N, D2> + Allocator<usize, D2>,
|
||||
ShapeConstraint: SameNumberOfRows<D, D2>,
|
||||
{
|
||||
let mut reach = Vec::new();
|
||||
self.lower_triangular_reach(b, &mut reach);
|
||||
let mut workspace = unsafe { VectorN::new_uninitialized_generic(b.data.shape().0, U1) };
|
||||
|
||||
for i in reach.iter().cloned() {
|
||||
workspace[i] = N::zero();
|
||||
}
|
||||
|
||||
for (i, val) in b.data.column_entries(0) {
|
||||
workspace[i] = val;
|
||||
}
|
||||
|
||||
for j in reach.iter().cloned().rev() {
|
||||
let mut column = self.data.column_entries(j);
|
||||
let mut diag_found = false;
|
||||
|
||||
while let Some((i, val)) = column.next() {
|
||||
if i == j {
|
||||
if val.is_zero() {
|
||||
break;
|
||||
}
|
||||
|
||||
workspace[j] /= val;
|
||||
diag_found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !diag_found {
|
||||
return None;
|
||||
}
|
||||
|
||||
for (i, val) in column {
|
||||
workspace[i] -= workspace[j] * val;
|
||||
}
|
||||
}
|
||||
|
||||
// Copy the result into a sparse vector.
|
||||
let mut result = CsVector::new_uninitialized_generic(b.data.shape().0, U1, reach.len());
|
||||
|
||||
for (i, val) in reach.iter().zip(result.data.vals.iter_mut()) {
|
||||
*val = workspace[*i];
|
||||
}
|
||||
|
||||
result.data.i = reach;
|
||||
Some(result)
|
||||
}
|
||||
|
||||
fn lower_triangular_reach<D2: Dim, S2>(&self, b: &CsVector<N, D2, S2>, xi: &mut Vec<usize>)
|
||||
where
|
||||
S2: CsStorage<N, D2>,
|
||||
DefaultAllocator: Allocator<bool, D>,
|
||||
{
|
||||
let mut visited = VectorN::repeat_generic(self.data.shape().1, U1, false);
|
||||
let mut stack = Vec::new();
|
||||
|
||||
for i in b.data.column_range(0) {
|
||||
let row_index = b.data.row_index(i);
|
||||
|
||||
if !visited[row_index] {
|
||||
let rng = self.data.column_range(row_index);
|
||||
stack.push((row_index, rng));
|
||||
self.lower_triangular_dfs(visited.as_mut_slice(), &mut stack, xi);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn lower_triangular_dfs(
|
||||
&self,
|
||||
visited: &mut [bool],
|
||||
stack: &mut Vec<(usize, Range<usize>)>,
|
||||
xi: &mut Vec<usize>,
|
||||
) {
|
||||
'recursion: while let Some((j, rng)) = stack.pop() {
|
||||
visited[j] = true;
|
||||
|
||||
for i in rng.clone() {
|
||||
let row_id = self.data.row_index(i);
|
||||
if row_id > j && !visited[row_id] {
|
||||
stack.push((j, (i + 1)..rng.end));
|
||||
|
||||
let row_id = self.data.row_index(i);
|
||||
stack.push((row_id, self.data.column_range(row_id)));
|
||||
continue 'recursion;
|
||||
}
|
||||
}
|
||||
|
||||
xi.push(j)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -381,8 +478,8 @@ impl<N: Scalar, R, S> CsVector<N, R, S> {
|
|||
self.data.set_column_len(0, nnzero);
|
||||
|
||||
// Fill with the axpy.
|
||||
let mut i = self.nvalues();
|
||||
let mut j = x.nvalues();
|
||||
let mut i = self.len();
|
||||
let mut j = x.len();
|
||||
let mut k = nnzero - 1;
|
||||
let mut rid1 = self.data.row_index(0, i - 1);
|
||||
let mut rid2 = x.data.row_index(0, j - 1);
|
||||
|
@ -416,7 +513,7 @@ impl<N: Scalar + Zero + ClosedAdd + ClosedMul, D: Dim, S: StorageMut<N, D>> Vect
|
|||
ShapeConstraint: DimEq<D, D2>,
|
||||
{
|
||||
if beta.is_zero() {
|
||||
for i in 0..x.nvalues() {
|
||||
for i in 0..x.len() {
|
||||
unsafe {
|
||||
let k = x.data.row_index_unchecked(i);
|
||||
let y = self.vget_unchecked_mut(k);
|
||||
|
@ -427,7 +524,7 @@ impl<N: Scalar + Zero + ClosedAdd + ClosedMul, D: Dim, S: StorageMut<N, D>> Vect
|
|||
// Needed to be sure even components not present on `x` are multiplied.
|
||||
*self *= beta;
|
||||
|
||||
for i in 0..x.nvalues() {
|
||||
for i in 0..x.len() {
|
||||
unsafe {
|
||||
let k = x.data.row_index_unchecked(i);
|
||||
let y = self.vget_unchecked_mut(k);
|
||||
|
@ -479,8 +576,7 @@ where
|
|||
"Mismatched dimensions for matrix multiplication."
|
||||
);
|
||||
|
||||
let mut res =
|
||||
CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.nvalues() + rhs.nvalues());
|
||||
let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len());
|
||||
let mut timestamps = VectorN::zeros_generic(nrows1, U1);
|
||||
let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows1, U1) };
|
||||
let mut nz = 0;
|
||||
|
@ -540,8 +636,7 @@ where
|
|||
"Mismatched dimensions for matrix sum."
|
||||
);
|
||||
|
||||
let mut res =
|
||||
CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.nvalues() + rhs.nvalues());
|
||||
let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len());
|
||||
let mut timestamps = VectorN::zeros_generic(nrows1, U1);
|
||||
let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows1, U1) };
|
||||
let mut nz = 0;
|
||||
|
@ -582,9 +677,10 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
use std::fmt::Debug;
|
||||
impl<'a, N: Scalar + Zero, R: Dim, C: Dim, S> From<CsMatrix<N, R, C, S>> for MatrixMN<N, R, C>
|
||||
where
|
||||
S: CsStorage<N, R, C>,
|
||||
S: CsStorage<N, R, C> + Debug,
|
||||
DefaultAllocator: Allocator<N, R, C>,
|
||||
{
|
||||
fn from(m: CsMatrix<N, R, C, S>) -> Self {
|
||||
|
@ -608,8 +704,8 @@ where
|
|||
{
|
||||
fn from(m: Matrix<N, R, C, S>) -> Self {
|
||||
let (nrows, ncols) = m.data.shape();
|
||||
let nvalues = m.iter().filter(|e| !e.is_zero()).count();
|
||||
let mut res = CsMatrix::new_uninitialized_generic(nrows, ncols, nvalues);
|
||||
let len = m.iter().filter(|e| !e.is_zero()).count();
|
||||
let mut res = CsMatrix::new_uninitialized_generic(nrows, ncols, len);
|
||||
let mut nz = 0;
|
||||
|
||||
for j in 0..ncols.value() {
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
#![cfg_attr(rustfmt, rustfmt_skip)]
|
||||
|
||||
use na::{CsMatrix, CsVector, Matrix5, Vector5};
|
||||
|
||||
|
||||
#[test]
|
||||
fn cs_lower_triangular_solve() {
|
||||
let a = Matrix5::new(
|
||||
4.0, 1.0, 4.0, 0.0, 9.0,
|
||||
5.0, 6.0, 0.0, 8.0, 10.0,
|
||||
9.0, 10.0, 11.0, 12.0, 0.0,
|
||||
0.0, -8.0, 3.0, 5.0, 9.0,
|
||||
0.0, 0.0, 1.0, 0.0, -10.0
|
||||
);
|
||||
let b = Vector5::new(1.0, 2.0, 3.0, 4.0, 5.0);
|
||||
|
||||
let cs_a: CsMatrix<_, _, _> = a.into();
|
||||
|
||||
assert_eq!(cs_a.solve_lower_triangular(&b), a.solve_lower_triangular(&b));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cs_tr_lower_triangular_solve() {
|
||||
let a = Matrix5::new(
|
||||
4.0, 1.0, 4.0, 0.0, 9.0,
|
||||
5.0, 6.0, 0.0, 8.0, 10.0,
|
||||
9.0, 10.0, 11.0, 12.0, 0.0,
|
||||
0.0, -8.0, 3.0, 5.0, 9.0,
|
||||
0.0, 0.0, 1.0, 0.0, -10.0
|
||||
);
|
||||
let b = Vector5::new(1.0, 2.0, 3.0, 4.0, 5.0);
|
||||
|
||||
let cs_a: CsMatrix<_, _, _> = a.into();
|
||||
|
||||
assert!(cs_a.tr_solve_lower_triangular(&b).is_some());
|
||||
assert_eq!(cs_a.tr_solve_lower_triangular(&b), a.tr_solve_lower_triangular(&b));
|
||||
|
||||
// Singular case.
|
||||
let a = Matrix5::new(
|
||||
4.0, 1.0, 4.0, 0.0, 9.0,
|
||||
5.0, 6.0, 0.0, 8.0, 10.0,
|
||||
9.0, 10.0, 0.0, 12.0, 0.0,
|
||||
0.0, -8.0, 3.0, 5.0, 9.0,
|
||||
0.0, 0.0, 1.0, 0.0, -10.0
|
||||
);
|
||||
let cs_a: CsMatrix<_, _, _> = a.into();
|
||||
|
||||
assert!(cs_a.tr_solve_lower_triangular(&b).is_none());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn cs_lower_triangular_solve_cs() {
|
||||
let a = Matrix5::new(
|
||||
4.0, 1.0, 4.0, 0.0, 9.0,
|
||||
5.0, 6.0, 0.0, 8.0, 10.0,
|
||||
9.0, 10.0, 11.0, 12.0, 0.0,
|
||||
0.0, -8.0, 3.0, 5.0, 9.0,
|
||||
0.0, 0.0, 1.0, 0.0, -10.0
|
||||
);
|
||||
let b1 = Vector5::zeros();
|
||||
let b2 = Vector5::new(1.0, 2.0, 3.0, 4.0, 5.0);
|
||||
let b3 = Vector5::new(1.0, 0.0, 0.0, 4.0, 0.0);
|
||||
let b4 = Vector5::new(0.0, 1.0, 0.0, 4.0, 5.0);
|
||||
let b5 = Vector5::x();
|
||||
let b6 = Vector5::y();
|
||||
let b7 = Vector5::z();
|
||||
let b8 = Vector5::w();
|
||||
let b9 = Vector5::a();
|
||||
|
||||
let cs_a: CsMatrix<_, _, _> = a.into();
|
||||
let cs_b1: CsVector<_, _> = Vector5::zeros().into();
|
||||
let cs_b2: CsVector<_, _> = Vector5::new(1.0, 2.0, 3.0, 4.0, 5.0).into();
|
||||
let cs_b3: CsVector<_, _> = Vector5::new(1.0, 0.0, 0.0, 4.0, 0.0).into();
|
||||
let cs_b4: CsVector<_, _> = Vector5::new(0.0, 1.0, 0.0, 4.0, 5.0).into();
|
||||
let cs_b5: CsVector<_, _> = Vector5::x().into();
|
||||
let cs_b6: CsVector<_, _> = Vector5::y().into();
|
||||
let cs_b7: CsVector<_, _> = Vector5::z().into();
|
||||
let cs_b8: CsVector<_, _> = Vector5::w().into();
|
||||
let cs_b9: CsVector<_, _> = Vector5::a().into();
|
||||
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b1).map(|v| v.into()), a.solve_lower_triangular(&b1));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b5).map(|v| v.into()), a.solve_lower_triangular(&b5));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b6).map(|v| v.into()), a.solve_lower_triangular(&b6));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b7).map(|v| v.into()), a.solve_lower_triangular(&b7));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b8).map(|v| v.into()), a.solve_lower_triangular(&b8));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b9).map(|v| v.into()), a.solve_lower_triangular(&b9));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b2).map(|v| v.into()), a.solve_lower_triangular(&b2));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b3).map(|v| v.into()), a.solve_lower_triangular(&b3));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b4).map(|v| v.into()), a.solve_lower_triangular(&b4));
|
||||
|
||||
|
||||
// Singular case.
|
||||
let a = Matrix5::new(
|
||||
4.0, 1.0, 4.0, 0.0, 9.0,
|
||||
5.0, 0.0, 0.0, 8.0, 10.0,
|
||||
9.0, 10.0, 0.0, 12.0, 0.0,
|
||||
0.0, -8.0, 3.0, 5.0, 9.0,
|
||||
0.0, 0.0, 1.0, 0.0, -10.0
|
||||
);
|
||||
let cs_a: CsMatrix<_, _, _> = a.into();
|
||||
|
||||
assert!(cs_a.solve_lower_triangular_cs(&cs_b2).is_none());
|
||||
assert!(cs_a.solve_lower_triangular_cs(&cs_b3).is_none());
|
||||
assert!(cs_a.solve_lower_triangular_cs(&cs_b4).is_none());
|
||||
}
|
|
@ -2,3 +2,4 @@ mod cs_construction;
|
|||
mod cs_conversion;
|
||||
mod cs_matrix;
|
||||
mod cs_ops;
|
||||
mod cs_linalg;
|
Loading…
Reference in New Issue