Add lower triangular solve with dense right-hand-side.

This commit is contained in:
sebcrozet 2018-10-22 17:55:13 +02:00
parent dc8edeceb2
commit e4e5659405

View File

@ -6,9 +6,9 @@ use std::ops::{Add, Mul, Range};
use std::slice; use std::slice;
use allocator::Allocator; use allocator::Allocator;
use constraint::{AreMultipliable, DimEq, ShapeConstraint}; use constraint::{AreMultipliable, DimEq, ShapeConstraint, SameNumberOfRows};
use storage::{Storage, StorageMut}; use storage::{Storage, StorageMut};
use {DefaultAllocator, Dim, Matrix, MatrixMN, Scalar, Vector, VectorN, U1}; use {Real, DefaultAllocator, Dim, Matrix, MatrixMN, Scalar, Vector, VectorN, U1};
// FIXME: this structure exists for now only because impl trait // FIXME: this structure exists for now only because impl trait
// cannot be used for trait method return types. // cannot be used for trait method return types.
@ -245,6 +245,132 @@ impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
} }
} }
impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
pub fn solve_lower_triangular<R2: Dim, C2: Dim, S2>(
&self,
b: &Matrix<N, R2, C2, S2>,
) -> Option<MatrixMN<N, R2, C2>>
where
S2: Storage<N, R2, C2>,
DefaultAllocator: Allocator<N, R2, C2>,
ShapeConstraint: SameNumberOfRows<D, R2>,
{
let mut b = b.clone_owned();
if self.solve_lower_triangular_mut(&mut b) {
Some(b)
} else {
None
}
}
pub fn tr_solve_lower_triangular<R2: Dim, C2: Dim, S2>(
&self,
b: &Matrix<N, R2, C2, S2>,
) -> Option<MatrixMN<N, R2, C2>>
where
S2: Storage<N, R2, C2>,
DefaultAllocator: Allocator<N, R2, C2>,
ShapeConstraint: SameNumberOfRows<D, R2>,
{
let mut b = b.clone_owned();
if self.tr_solve_lower_triangular_mut(&mut b) {
Some(b)
} else {
None
}
}
pub fn solve_lower_triangular_mut<R2: Dim, C2: Dim, S2>(
&self,
b: &mut Matrix<N, R2, C2, S2>,
) -> bool
where
S2: StorageMut<N, R2, C2>,
ShapeConstraint: SameNumberOfRows<D, R2>,
{
let (nrows, ncols) = self.data.shape();
assert_eq!(nrows.value(), ncols.value(), "The matrix must be square.");
assert_eq!(nrows.value(), b.len(), "Mismatched matrix dimensions.");
for j2 in 0..b.ncols() {
let mut b = b.column_mut(j2);
for j in 0..ncols.value() {
let mut column = self.data.column_entries(j);
let mut diag_found = false;
while let Some((i, val)) = column.next() {
if i == j {
if val.is_zero() {
return false;
}
b[j] /= val;
diag_found = true;
break;
}
}
if !diag_found {
return false;
}
for (i, val) in column {
b[i] -= b[j] * val;
}
}
}
true
}
pub fn tr_solve_lower_triangular_mut<R2: Dim, C2: Dim, S2>(
&self,
b: &mut Matrix<N, R2, C2, S2>,
) -> bool
where
S2: StorageMut<N, R2, C2>,
ShapeConstraint: SameNumberOfRows<D, R2>,
{
let (nrows, ncols) = self.data.shape();
assert_eq!(nrows.value(), ncols.value(), "The matrix must be square.");
assert_eq!(nrows.value(), b.len(), "Mismatched matrix dimensions.");
for j2 in 0..b.ncols() {
let mut b = b.column_mut(j2);
for j in (0..ncols.value()).rev() {
let mut column = self.data.column_entries(j);
let mut diag = None;
while let Some((i, val)) = column.next() {
if i == j {
if val.is_zero() {
return false;
}
diag = Some(val);
break;
}
}
if let Some(diag) = diag {
for (i, val) in column {
b[j] -= val * b[i];
}
b[j] /= diag;
} else {
return false;
}
}
}
true
}
}
/* /*
impl<N: Scalar, R, S> CsVector<N, R, S> { impl<N: Scalar, R, S> CsVector<N, R, S> {
pub fn axpy(&mut self, alpha: N, x: CsVector<N, R, S>, beta: N) { pub fn axpy(&mut self, alpha: N, x: CsVector<N, R, S>, beta: N) {