Add lower triangular solve with sparse right-hand-side.
This commit is contained in:
parent
e4e5659405
commit
34b20dc291
|
@ -6,9 +6,9 @@ use std::ops::{Add, Mul, Range};
|
||||||
use std::slice;
|
use std::slice;
|
||||||
|
|
||||||
use allocator::Allocator;
|
use allocator::Allocator;
|
||||||
use constraint::{AreMultipliable, DimEq, ShapeConstraint, SameNumberOfRows};
|
use constraint::{AreMultipliable, DimEq, SameNumberOfRows, ShapeConstraint};
|
||||||
use storage::{Storage, StorageMut};
|
use storage::{Storage, StorageMut};
|
||||||
use {Real, DefaultAllocator, Dim, Matrix, MatrixMN, Scalar, Vector, VectorN, U1};
|
use {DefaultAllocator, Dim, Matrix, MatrixMN, Real, Scalar, Vector, VectorN, U1};
|
||||||
|
|
||||||
// FIXME: this structure exists for now only because impl trait
|
// FIXME: this structure exists for now only because impl trait
|
||||||
// cannot be used for trait method return types.
|
// cannot be used for trait method return types.
|
||||||
|
@ -20,11 +20,12 @@ pub trait CsStorageIter<'a, N, R, C = U1> {
|
||||||
|
|
||||||
pub trait CsStorage<N, R, C = U1>: for<'a> CsStorageIter<'a, N, R, C> {
|
pub trait CsStorage<N, R, C = U1>: for<'a> CsStorageIter<'a, N, R, C> {
|
||||||
fn shape(&self) -> (R, C);
|
fn shape(&self) -> (R, C);
|
||||||
fn nvalues(&self) -> usize;
|
|
||||||
unsafe fn row_index_unchecked(&self, i: usize) -> usize;
|
unsafe fn row_index_unchecked(&self, i: usize) -> usize;
|
||||||
unsafe fn get_value_unchecked(&self, i: usize) -> &N;
|
unsafe fn get_value_unchecked(&self, i: usize) -> &N;
|
||||||
fn get_value(&self, i: usize) -> &N;
|
fn get_value(&self, i: usize) -> &N;
|
||||||
fn row_index(&self, i: usize) -> usize;
|
fn row_index(&self, i: usize) -> usize;
|
||||||
|
fn column_range(&self, i: usize) -> Range<usize>;
|
||||||
|
fn len(&self) -> usize;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait CsStorageMut<N, R, C = U1>: CsStorage<N, R, C> {
|
pub trait CsStorageMut<N, R, C = U1>: CsStorage<N, R, C> {
|
||||||
|
@ -49,21 +50,7 @@ where
|
||||||
vals: Vec<N>,
|
vals: Vec<N>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<N: Scalar, R: Dim, C: Dim> CsVecStorage<N, R, C>
|
impl<N: Scalar, R: Dim, C: Dim> CsVecStorage<N, R, C> where DefaultAllocator: Allocator<usize, C> {}
|
||||||
where
|
|
||||||
DefaultAllocator: Allocator<usize, C>,
|
|
||||||
{
|
|
||||||
#[inline]
|
|
||||||
fn column_range(&self, j: usize) -> Range<usize> {
|
|
||||||
let end = if j + 1 == self.p.len() {
|
|
||||||
self.nvalues()
|
|
||||||
} else {
|
|
||||||
self.p[j + 1]
|
|
||||||
};
|
|
||||||
|
|
||||||
self.p[j]..end
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, N: Scalar, R: Dim, C: Dim> CsStorageIter<'a, N, R, C> for CsVecStorage<N, R, C>
|
impl<'a, N: Scalar, R: Dim, C: Dim> CsStorageIter<'a, N, R, C> for CsVecStorage<N, R, C>
|
||||||
where
|
where
|
||||||
|
@ -92,7 +79,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn nvalues(&self) -> usize {
|
fn len(&self) -> usize {
|
||||||
self.vals.len()
|
self.vals.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,6 +102,17 @@ where
|
||||||
fn get_value(&self, i: usize) -> &N {
|
fn get_value(&self, i: usize) -> &N {
|
||||||
&self.vals[i]
|
&self.vals[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn column_range(&self, j: usize) -> Range<usize> {
|
||||||
|
let end = if j + 1 == self.p.len() {
|
||||||
|
self.len()
|
||||||
|
} else {
|
||||||
|
self.p[j + 1]
|
||||||
|
};
|
||||||
|
|
||||||
|
self.p[j]..end
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -154,7 +152,7 @@ where
|
||||||
CsMatrix {
|
CsMatrix {
|
||||||
data: CsVecStorage {
|
data: CsVecStorage {
|
||||||
shape: (nrows, ncols),
|
shape: (nrows, ncols),
|
||||||
p: unsafe { VectorN::new_uninitialized_generic(ncols, U1) },
|
p: VectorN::zeros_generic(ncols, U1),
|
||||||
i,
|
i,
|
||||||
vals,
|
vals,
|
||||||
},
|
},
|
||||||
|
@ -180,8 +178,8 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
|
impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
|
||||||
pub fn nvalues(&self) -> usize {
|
pub fn len(&self) -> usize {
|
||||||
self.data.nvalues()
|
self.data.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn transpose(&self) -> CsMatrix<N, C, R>
|
pub fn transpose(&self) -> CsMatrix<N, C, R>
|
||||||
|
@ -190,7 +188,7 @@ impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
|
||||||
{
|
{
|
||||||
let (nrows, ncols) = self.data.shape();
|
let (nrows, ncols) = self.data.shape();
|
||||||
|
|
||||||
let nvals = self.nvalues();
|
let nvals = self.len();
|
||||||
let mut res = CsMatrix::new_uninitialized_generic(ncols, nrows, nvals);
|
let mut res = CsMatrix::new_uninitialized_generic(ncols, nrows, nvals);
|
||||||
let mut workspace = Vector::zeros_generic(nrows, U1);
|
let mut workspace = Vector::zeros_generic(nrows, U1);
|
||||||
|
|
||||||
|
@ -250,27 +248,27 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
|
||||||
&self,
|
&self,
|
||||||
b: &Matrix<N, R2, C2, S2>,
|
b: &Matrix<N, R2, C2, S2>,
|
||||||
) -> Option<MatrixMN<N, R2, C2>>
|
) -> Option<MatrixMN<N, R2, C2>>
|
||||||
where
|
where
|
||||||
S2: Storage<N, R2, C2>,
|
S2: Storage<N, R2, C2>,
|
||||||
DefaultAllocator: Allocator<N, R2, C2>,
|
DefaultAllocator: Allocator<N, R2, C2>,
|
||||||
ShapeConstraint: SameNumberOfRows<D, R2>,
|
ShapeConstraint: SameNumberOfRows<D, R2>,
|
||||||
{
|
{
|
||||||
let mut b = b.clone_owned();
|
let mut b = b.clone_owned();
|
||||||
if self.solve_lower_triangular_mut(&mut b) {
|
if self.solve_lower_triangular_mut(&mut b) {
|
||||||
Some(b)
|
Some(b)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tr_solve_lower_triangular<R2: Dim, C2: Dim, S2>(
|
pub fn tr_solve_lower_triangular<R2: Dim, C2: Dim, S2>(
|
||||||
&self,
|
&self,
|
||||||
b: &Matrix<N, R2, C2, S2>,
|
b: &Matrix<N, R2, C2, S2>,
|
||||||
) -> Option<MatrixMN<N, R2, C2>>
|
) -> Option<MatrixMN<N, R2, C2>>
|
||||||
where
|
where
|
||||||
S2: Storage<N, R2, C2>,
|
S2: Storage<N, R2, C2>,
|
||||||
DefaultAllocator: Allocator<N, R2, C2>,
|
DefaultAllocator: Allocator<N, R2, C2>,
|
||||||
ShapeConstraint: SameNumberOfRows<D, R2>,
|
ShapeConstraint: SameNumberOfRows<D, R2>,
|
||||||
{
|
{
|
||||||
let mut b = b.clone_owned();
|
let mut b = b.clone_owned();
|
||||||
if self.tr_solve_lower_triangular_mut(&mut b) {
|
if self.tr_solve_lower_triangular_mut(&mut b) {
|
||||||
|
@ -284,9 +282,9 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
|
||||||
&self,
|
&self,
|
||||||
b: &mut Matrix<N, R2, C2, S2>,
|
b: &mut Matrix<N, R2, C2, S2>,
|
||||||
) -> bool
|
) -> bool
|
||||||
where
|
where
|
||||||
S2: StorageMut<N, R2, C2>,
|
S2: StorageMut<N, R2, C2>,
|
||||||
ShapeConstraint: SameNumberOfRows<D, R2>,
|
ShapeConstraint: SameNumberOfRows<D, R2>,
|
||||||
{
|
{
|
||||||
let (nrows, ncols) = self.data.shape();
|
let (nrows, ncols) = self.data.shape();
|
||||||
assert_eq!(nrows.value(), ncols.value(), "The matrix must be square.");
|
assert_eq!(nrows.value(), ncols.value(), "The matrix must be square.");
|
||||||
|
@ -324,14 +322,13 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub fn tr_solve_lower_triangular_mut<R2: Dim, C2: Dim, S2>(
|
pub fn tr_solve_lower_triangular_mut<R2: Dim, C2: Dim, S2>(
|
||||||
&self,
|
&self,
|
||||||
b: &mut Matrix<N, R2, C2, S2>,
|
b: &mut Matrix<N, R2, C2, S2>,
|
||||||
) -> bool
|
) -> bool
|
||||||
where
|
where
|
||||||
S2: StorageMut<N, R2, C2>,
|
S2: StorageMut<N, R2, C2>,
|
||||||
ShapeConstraint: SameNumberOfRows<D, R2>,
|
ShapeConstraint: SameNumberOfRows<D, R2>,
|
||||||
{
|
{
|
||||||
let (nrows, ncols) = self.data.shape();
|
let (nrows, ncols) = self.data.shape();
|
||||||
assert_eq!(nrows.value(), ncols.value(), "The matrix must be square.");
|
assert_eq!(nrows.value(), ncols.value(), "The matrix must be square.");
|
||||||
|
@ -369,6 +366,106 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
|
||||||
|
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn solve_lower_triangular_cs<D2: Dim, S2>(
|
||||||
|
&self,
|
||||||
|
b: &CsVector<N, D2, S2>,
|
||||||
|
) -> Option<CsVector<N, D2>>
|
||||||
|
where
|
||||||
|
S2: CsStorage<N, D2>,
|
||||||
|
DefaultAllocator: Allocator<bool, D> + Allocator<N, D2> + Allocator<usize, D2>,
|
||||||
|
ShapeConstraint: SameNumberOfRows<D, D2>,
|
||||||
|
{
|
||||||
|
let mut reach = Vec::new();
|
||||||
|
self.lower_triangular_reach(b, &mut reach);
|
||||||
|
let mut workspace = unsafe { VectorN::new_uninitialized_generic(b.data.shape().0, U1) };
|
||||||
|
|
||||||
|
for i in reach.iter().cloned() {
|
||||||
|
workspace[i] = N::zero();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i, val) in b.data.column_entries(0) {
|
||||||
|
workspace[i] = val;
|
||||||
|
}
|
||||||
|
|
||||||
|
for j in reach.iter().cloned().rev() {
|
||||||
|
let mut column = self.data.column_entries(j);
|
||||||
|
let mut diag_found = false;
|
||||||
|
|
||||||
|
while let Some((i, val)) = column.next() {
|
||||||
|
if i == j {
|
||||||
|
if val.is_zero() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
workspace[j] /= val;
|
||||||
|
diag_found = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !diag_found {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i, val) in column {
|
||||||
|
workspace[i] -= workspace[j] * val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy the result into a sparse vector.
|
||||||
|
let mut result = CsVector::new_uninitialized_generic(b.data.shape().0, U1, reach.len());
|
||||||
|
|
||||||
|
for (i, val) in reach.iter().zip(result.data.vals.iter_mut()) {
|
||||||
|
*val = workspace[*i];
|
||||||
|
}
|
||||||
|
|
||||||
|
result.data.i = reach;
|
||||||
|
Some(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn lower_triangular_reach<D2: Dim, S2>(&self, b: &CsVector<N, D2, S2>, xi: &mut Vec<usize>)
|
||||||
|
where
|
||||||
|
S2: CsStorage<N, D2>,
|
||||||
|
DefaultAllocator: Allocator<bool, D>,
|
||||||
|
{
|
||||||
|
let mut visited = VectorN::repeat_generic(self.data.shape().1, U1, false);
|
||||||
|
let mut stack = Vec::new();
|
||||||
|
|
||||||
|
for i in b.data.column_range(0) {
|
||||||
|
let row_index = b.data.row_index(i);
|
||||||
|
|
||||||
|
if !visited[row_index] {
|
||||||
|
let rng = self.data.column_range(row_index);
|
||||||
|
stack.push((row_index, rng));
|
||||||
|
self.lower_triangular_dfs(visited.as_mut_slice(), &mut stack, xi);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn lower_triangular_dfs(
|
||||||
|
&self,
|
||||||
|
visited: &mut [bool],
|
||||||
|
stack: &mut Vec<(usize, Range<usize>)>,
|
||||||
|
xi: &mut Vec<usize>,
|
||||||
|
) {
|
||||||
|
'recursion: while let Some((j, rng)) = stack.pop() {
|
||||||
|
visited[j] = true;
|
||||||
|
|
||||||
|
for i in rng.clone() {
|
||||||
|
let row_id = self.data.row_index(i);
|
||||||
|
if row_id > j && !visited[row_id] {
|
||||||
|
stack.push((j, (i + 1)..rng.end));
|
||||||
|
|
||||||
|
let row_id = self.data.row_index(i);
|
||||||
|
stack.push((row_id, self.data.column_range(row_id)));
|
||||||
|
continue 'recursion;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
xi.push(j)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -381,8 +478,8 @@ impl<N: Scalar, R, S> CsVector<N, R, S> {
|
||||||
self.data.set_column_len(0, nnzero);
|
self.data.set_column_len(0, nnzero);
|
||||||
|
|
||||||
// Fill with the axpy.
|
// Fill with the axpy.
|
||||||
let mut i = self.nvalues();
|
let mut i = self.len();
|
||||||
let mut j = x.nvalues();
|
let mut j = x.len();
|
||||||
let mut k = nnzero - 1;
|
let mut k = nnzero - 1;
|
||||||
let mut rid1 = self.data.row_index(0, i - 1);
|
let mut rid1 = self.data.row_index(0, i - 1);
|
||||||
let mut rid2 = x.data.row_index(0, j - 1);
|
let mut rid2 = x.data.row_index(0, j - 1);
|
||||||
|
@ -416,7 +513,7 @@ impl<N: Scalar + Zero + ClosedAdd + ClosedMul, D: Dim, S: StorageMut<N, D>> Vect
|
||||||
ShapeConstraint: DimEq<D, D2>,
|
ShapeConstraint: DimEq<D, D2>,
|
||||||
{
|
{
|
||||||
if beta.is_zero() {
|
if beta.is_zero() {
|
||||||
for i in 0..x.nvalues() {
|
for i in 0..x.len() {
|
||||||
unsafe {
|
unsafe {
|
||||||
let k = x.data.row_index_unchecked(i);
|
let k = x.data.row_index_unchecked(i);
|
||||||
let y = self.vget_unchecked_mut(k);
|
let y = self.vget_unchecked_mut(k);
|
||||||
|
@ -427,7 +524,7 @@ impl<N: Scalar + Zero + ClosedAdd + ClosedMul, D: Dim, S: StorageMut<N, D>> Vect
|
||||||
// Needed to be sure even components not present on `x` are multiplied.
|
// Needed to be sure even components not present on `x` are multiplied.
|
||||||
*self *= beta;
|
*self *= beta;
|
||||||
|
|
||||||
for i in 0..x.nvalues() {
|
for i in 0..x.len() {
|
||||||
unsafe {
|
unsafe {
|
||||||
let k = x.data.row_index_unchecked(i);
|
let k = x.data.row_index_unchecked(i);
|
||||||
let y = self.vget_unchecked_mut(k);
|
let y = self.vget_unchecked_mut(k);
|
||||||
|
@ -479,8 +576,7 @@ where
|
||||||
"Mismatched dimensions for matrix multiplication."
|
"Mismatched dimensions for matrix multiplication."
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut res =
|
let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len());
|
||||||
CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.nvalues() + rhs.nvalues());
|
|
||||||
let mut timestamps = VectorN::zeros_generic(nrows1, U1);
|
let mut timestamps = VectorN::zeros_generic(nrows1, U1);
|
||||||
let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows1, U1) };
|
let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows1, U1) };
|
||||||
let mut nz = 0;
|
let mut nz = 0;
|
||||||
|
@ -540,8 +636,7 @@ where
|
||||||
"Mismatched dimensions for matrix sum."
|
"Mismatched dimensions for matrix sum."
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut res =
|
let mut res = CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.len() + rhs.len());
|
||||||
CsMatrix::new_uninitialized_generic(nrows1, ncols2, self.nvalues() + rhs.nvalues());
|
|
||||||
let mut timestamps = VectorN::zeros_generic(nrows1, U1);
|
let mut timestamps = VectorN::zeros_generic(nrows1, U1);
|
||||||
let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows1, U1) };
|
let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows1, U1) };
|
||||||
let mut nz = 0;
|
let mut nz = 0;
|
||||||
|
@ -582,9 +677,10 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
use std::fmt::Debug;
|
||||||
impl<'a, N: Scalar + Zero, R: Dim, C: Dim, S> From<CsMatrix<N, R, C, S>> for MatrixMN<N, R, C>
|
impl<'a, N: Scalar + Zero, R: Dim, C: Dim, S> From<CsMatrix<N, R, C, S>> for MatrixMN<N, R, C>
|
||||||
where
|
where
|
||||||
S: CsStorage<N, R, C>,
|
S: CsStorage<N, R, C> + Debug,
|
||||||
DefaultAllocator: Allocator<N, R, C>,
|
DefaultAllocator: Allocator<N, R, C>,
|
||||||
{
|
{
|
||||||
fn from(m: CsMatrix<N, R, C, S>) -> Self {
|
fn from(m: CsMatrix<N, R, C, S>) -> Self {
|
||||||
|
@ -608,8 +704,8 @@ where
|
||||||
{
|
{
|
||||||
fn from(m: Matrix<N, R, C, S>) -> Self {
|
fn from(m: Matrix<N, R, C, S>) -> Self {
|
||||||
let (nrows, ncols) = m.data.shape();
|
let (nrows, ncols) = m.data.shape();
|
||||||
let nvalues = m.iter().filter(|e| !e.is_zero()).count();
|
let len = m.iter().filter(|e| !e.is_zero()).count();
|
||||||
let mut res = CsMatrix::new_uninitialized_generic(nrows, ncols, nvalues);
|
let mut res = CsMatrix::new_uninitialized_generic(nrows, ncols, len);
|
||||||
let mut nz = 0;
|
let mut nz = 0;
|
||||||
|
|
||||||
for j in 0..ncols.value() {
|
for j in 0..ncols.value() {
|
||||||
|
|
|
@ -0,0 +1,106 @@
|
||||||
|
#![cfg_attr(rustfmt, rustfmt_skip)]
|
||||||
|
|
||||||
|
use na::{CsMatrix, CsVector, Matrix5, Vector5};
|
||||||
|
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cs_lower_triangular_solve() {
|
||||||
|
let a = Matrix5::new(
|
||||||
|
4.0, 1.0, 4.0, 0.0, 9.0,
|
||||||
|
5.0, 6.0, 0.0, 8.0, 10.0,
|
||||||
|
9.0, 10.0, 11.0, 12.0, 0.0,
|
||||||
|
0.0, -8.0, 3.0, 5.0, 9.0,
|
||||||
|
0.0, 0.0, 1.0, 0.0, -10.0
|
||||||
|
);
|
||||||
|
let b = Vector5::new(1.0, 2.0, 3.0, 4.0, 5.0);
|
||||||
|
|
||||||
|
let cs_a: CsMatrix<_, _, _> = a.into();
|
||||||
|
|
||||||
|
assert_eq!(cs_a.solve_lower_triangular(&b), a.solve_lower_triangular(&b));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cs_tr_lower_triangular_solve() {
|
||||||
|
let a = Matrix5::new(
|
||||||
|
4.0, 1.0, 4.0, 0.0, 9.0,
|
||||||
|
5.0, 6.0, 0.0, 8.0, 10.0,
|
||||||
|
9.0, 10.0, 11.0, 12.0, 0.0,
|
||||||
|
0.0, -8.0, 3.0, 5.0, 9.0,
|
||||||
|
0.0, 0.0, 1.0, 0.0, -10.0
|
||||||
|
);
|
||||||
|
let b = Vector5::new(1.0, 2.0, 3.0, 4.0, 5.0);
|
||||||
|
|
||||||
|
let cs_a: CsMatrix<_, _, _> = a.into();
|
||||||
|
|
||||||
|
assert!(cs_a.tr_solve_lower_triangular(&b).is_some());
|
||||||
|
assert_eq!(cs_a.tr_solve_lower_triangular(&b), a.tr_solve_lower_triangular(&b));
|
||||||
|
|
||||||
|
// Singular case.
|
||||||
|
let a = Matrix5::new(
|
||||||
|
4.0, 1.0, 4.0, 0.0, 9.0,
|
||||||
|
5.0, 6.0, 0.0, 8.0, 10.0,
|
||||||
|
9.0, 10.0, 0.0, 12.0, 0.0,
|
||||||
|
0.0, -8.0, 3.0, 5.0, 9.0,
|
||||||
|
0.0, 0.0, 1.0, 0.0, -10.0
|
||||||
|
);
|
||||||
|
let cs_a: CsMatrix<_, _, _> = a.into();
|
||||||
|
|
||||||
|
assert!(cs_a.tr_solve_lower_triangular(&b).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cs_lower_triangular_solve_cs() {
|
||||||
|
let a = Matrix5::new(
|
||||||
|
4.0, 1.0, 4.0, 0.0, 9.0,
|
||||||
|
5.0, 6.0, 0.0, 8.0, 10.0,
|
||||||
|
9.0, 10.0, 11.0, 12.0, 0.0,
|
||||||
|
0.0, -8.0, 3.0, 5.0, 9.0,
|
||||||
|
0.0, 0.0, 1.0, 0.0, -10.0
|
||||||
|
);
|
||||||
|
let b1 = Vector5::zeros();
|
||||||
|
let b2 = Vector5::new(1.0, 2.0, 3.0, 4.0, 5.0);
|
||||||
|
let b3 = Vector5::new(1.0, 0.0, 0.0, 4.0, 0.0);
|
||||||
|
let b4 = Vector5::new(0.0, 1.0, 0.0, 4.0, 5.0);
|
||||||
|
let b5 = Vector5::x();
|
||||||
|
let b6 = Vector5::y();
|
||||||
|
let b7 = Vector5::z();
|
||||||
|
let b8 = Vector5::w();
|
||||||
|
let b9 = Vector5::a();
|
||||||
|
|
||||||
|
let cs_a: CsMatrix<_, _, _> = a.into();
|
||||||
|
let cs_b1: CsVector<_, _> = Vector5::zeros().into();
|
||||||
|
let cs_b2: CsVector<_, _> = Vector5::new(1.0, 2.0, 3.0, 4.0, 5.0).into();
|
||||||
|
let cs_b3: CsVector<_, _> = Vector5::new(1.0, 0.0, 0.0, 4.0, 0.0).into();
|
||||||
|
let cs_b4: CsVector<_, _> = Vector5::new(0.0, 1.0, 0.0, 4.0, 5.0).into();
|
||||||
|
let cs_b5: CsVector<_, _> = Vector5::x().into();
|
||||||
|
let cs_b6: CsVector<_, _> = Vector5::y().into();
|
||||||
|
let cs_b7: CsVector<_, _> = Vector5::z().into();
|
||||||
|
let cs_b8: CsVector<_, _> = Vector5::w().into();
|
||||||
|
let cs_b9: CsVector<_, _> = Vector5::a().into();
|
||||||
|
|
||||||
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b1).map(|v| v.into()), a.solve_lower_triangular(&b1));
|
||||||
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b5).map(|v| v.into()), a.solve_lower_triangular(&b5));
|
||||||
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b6).map(|v| v.into()), a.solve_lower_triangular(&b6));
|
||||||
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b7).map(|v| v.into()), a.solve_lower_triangular(&b7));
|
||||||
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b8).map(|v| v.into()), a.solve_lower_triangular(&b8));
|
||||||
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b9).map(|v| v.into()), a.solve_lower_triangular(&b9));
|
||||||
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b2).map(|v| v.into()), a.solve_lower_triangular(&b2));
|
||||||
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b3).map(|v| v.into()), a.solve_lower_triangular(&b3));
|
||||||
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b4).map(|v| v.into()), a.solve_lower_triangular(&b4));
|
||||||
|
|
||||||
|
|
||||||
|
// Singular case.
|
||||||
|
let a = Matrix5::new(
|
||||||
|
4.0, 1.0, 4.0, 0.0, 9.0,
|
||||||
|
5.0, 0.0, 0.0, 8.0, 10.0,
|
||||||
|
9.0, 10.0, 0.0, 12.0, 0.0,
|
||||||
|
0.0, -8.0, 3.0, 5.0, 9.0,
|
||||||
|
0.0, 0.0, 1.0, 0.0, -10.0
|
||||||
|
);
|
||||||
|
let cs_a: CsMatrix<_, _, _> = a.into();
|
||||||
|
|
||||||
|
assert!(cs_a.solve_lower_triangular_cs(&cs_b2).is_none());
|
||||||
|
assert!(cs_a.solve_lower_triangular_cs(&cs_b3).is_none());
|
||||||
|
assert!(cs_a.solve_lower_triangular_cs(&cs_b4).is_none());
|
||||||
|
}
|
|
@ -2,3 +2,4 @@ mod cs_construction;
|
||||||
mod cs_conversion;
|
mod cs_conversion;
|
||||||
mod cs_matrix;
|
mod cs_matrix;
|
||||||
mod cs_ops;
|
mod cs_ops;
|
||||||
|
mod cs_linalg;
|
Loading…
Reference in New Issue