forked from M-Labs/nalgebra
Ensure the output of multiplication and triangular solve are sorted.
This commit is contained in:
parent
c3e8112d5e
commit
748cfeea66
@ -246,6 +246,32 @@ impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
|
|||||||
nrows.value() == ncols.value()
|
nrows.value() == ncols.value()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Should always return `true`.
|
||||||
|
///
|
||||||
|
/// This method is generally used for debugging and should typically not be called in user code.
|
||||||
|
/// This checks that the row inner indices of this matrix are sorted. It takes `O(n)` time,
|
||||||
|
/// where n` is `self.len()`.
|
||||||
|
/// All operations of CSC matrices on nalgebra assume, and will return, sorted indices.
|
||||||
|
/// If at any time this `is_sorted` method returns `false`, then, something went wrong
|
||||||
|
/// and an issue should be open on the nalgebra repository with details on how to reproduce
|
||||||
|
/// this.
|
||||||
|
pub fn is_sorted(&self) -> bool {
|
||||||
|
for j in 0..self.ncols() {
|
||||||
|
let mut curr = None;
|
||||||
|
for idx in self.data.column_row_indices(j) {
|
||||||
|
if let Some(curr) = curr {
|
||||||
|
if idx <= curr {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
curr = Some(idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
pub fn transpose(&self) -> CsMatrix<N, C, R>
|
pub fn transpose(&self) -> CsMatrix<N, C, R>
|
||||||
where
|
where
|
||||||
DefaultAllocator: Allocator<usize, R>,
|
DefaultAllocator: Allocator<usize, R>,
|
||||||
@ -278,3 +304,41 @@ impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
|
|||||||
res
|
res
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<N: Scalar, R: Dim, C: Dim> CsMatrix<N, R, C>
|
||||||
|
where
|
||||||
|
DefaultAllocator: Allocator<usize, C>,
|
||||||
|
{
|
||||||
|
pub(crate) fn sort(&mut self)
|
||||||
|
where
|
||||||
|
DefaultAllocator: Allocator<N, R>,
|
||||||
|
{
|
||||||
|
// Size = R
|
||||||
|
let nrows = self.data.shape().0;
|
||||||
|
let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows, U1) };
|
||||||
|
self.sort_with_workspace(workspace.as_mut_slice());
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn sort_with_workspace(&mut self, workspace: &mut [N]) {
|
||||||
|
assert!(
|
||||||
|
workspace.len() >= self.nrows(),
|
||||||
|
"Workspace must be able to hold at least self.nrows() elements."
|
||||||
|
);
|
||||||
|
|
||||||
|
for j in 0..self.ncols() {
|
||||||
|
// Scatter the row in the workspace.
|
||||||
|
for (irow, val) in self.data.column_entries(j) {
|
||||||
|
workspace[irow] = val;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort the index vector.
|
||||||
|
let range = self.data.column_range(j);
|
||||||
|
self.data.i[range.clone()].sort();
|
||||||
|
|
||||||
|
// Permute the values too.
|
||||||
|
for (i, irow) in range.clone().zip(self.data.i[range].iter().cloned()) {
|
||||||
|
self.data.vals[i] = workspace[irow];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -155,7 +155,7 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Performs the numerical Cholesky decomposition given the set of numerical values.
|
// Performs the numerical Cholesky decomposition given the set of numerical values.
|
||||||
pub fn decompose(&mut self, values: &[N]) -> bool {
|
pub fn decompose_up_looking(&mut self, values: &[N]) -> bool {
|
||||||
assert!(
|
assert!(
|
||||||
values.len() >= self.original_i.len(),
|
values.len() >= self.original_i.len(),
|
||||||
"The set of values is too small."
|
"The set of values is too small."
|
||||||
|
@ -172,7 +172,11 @@ where
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
for p in res.data.p[j]..nz {
|
// Keep the output sorted.
|
||||||
|
let range = res.data.p[j]..nz;
|
||||||
|
res.data.i[range.clone()].sort();
|
||||||
|
|
||||||
|
for p in range {
|
||||||
res.data.vals[p] = workspace[res.data.i[p]]
|
res.data.vals[p] = workspace[res.data.i[p]]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -145,7 +145,10 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
|
|||||||
ShapeConstraint: SameNumberOfRows<D, D2>,
|
ShapeConstraint: SameNumberOfRows<D, D2>,
|
||||||
{
|
{
|
||||||
let mut reach = Vec::new();
|
let mut reach = Vec::new();
|
||||||
|
// We don't compute a postordered reach here because it will be sorted after anyway.
|
||||||
self.lower_triangular_reach(b, &mut reach);
|
self.lower_triangular_reach(b, &mut reach);
|
||||||
|
// We sort the reach so the result matrix has sorted indices.
|
||||||
|
reach.sort();
|
||||||
let mut workspace = unsafe { VectorN::new_uninitialized_generic(b.data.shape().0, U1) };
|
let mut workspace = unsafe { VectorN::new_uninitialized_generic(b.data.shape().0, U1) };
|
||||||
|
|
||||||
for i in reach.iter().cloned() {
|
for i in reach.iter().cloned() {
|
||||||
@ -156,7 +159,7 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
|
|||||||
workspace[i] = val;
|
workspace[i] = val;
|
||||||
}
|
}
|
||||||
|
|
||||||
for j in reach.iter().cloned().rev() {
|
for j in reach.iter().cloned() {
|
||||||
let mut column = self.data.column_entries(j);
|
let mut column = self.data.column_entries(j);
|
||||||
let mut diag_found = false;
|
let mut diag_found = false;
|
||||||
|
|
||||||
@ -192,8 +195,12 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
|
|||||||
Some(result)
|
Some(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn lower_triangular_reach<D2: Dim, S2>(&self, b: &CsVector<N, D2, S2>, xi: &mut Vec<usize>)
|
// Computes the reachable, post-ordered, nodes from `b`.
|
||||||
where
|
fn lower_triangular_reach_postordered<D2: Dim, S2>(
|
||||||
|
&self,
|
||||||
|
b: &CsVector<N, D2, S2>,
|
||||||
|
xi: &mut Vec<usize>,
|
||||||
|
) where
|
||||||
S2: CsStorage<N, D2>,
|
S2: CsStorage<N, D2>,
|
||||||
DefaultAllocator: Allocator<bool, D>,
|
DefaultAllocator: Allocator<bool, D>,
|
||||||
{
|
{
|
||||||
@ -232,4 +239,43 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
|
|||||||
xi.push(j)
|
xi.push(j)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Computes the nodes reachable from `b` in an arbitrary order.
|
||||||
|
fn lower_triangular_reach<D2: Dim, S2>(&self, b: &CsVector<N, D2, S2>, xi: &mut Vec<usize>)
|
||||||
|
where
|
||||||
|
S2: CsStorage<N, D2>,
|
||||||
|
DefaultAllocator: Allocator<bool, D>,
|
||||||
|
{
|
||||||
|
let mut visited = VectorN::repeat_generic(self.data.shape().1, U1, false);
|
||||||
|
let mut stack = Vec::new();
|
||||||
|
|
||||||
|
for irow in b.data.column_row_indices(0) {
|
||||||
|
self.lower_triangular_bfs(irow, visited.as_mut_slice(), &mut stack, xi);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn lower_triangular_bfs(
|
||||||
|
&self,
|
||||||
|
start: usize,
|
||||||
|
visited: &mut [bool],
|
||||||
|
stack: &mut Vec<usize>,
|
||||||
|
xi: &mut Vec<usize>,
|
||||||
|
) {
|
||||||
|
if !visited[start] {
|
||||||
|
stack.clear();
|
||||||
|
stack.push(start);
|
||||||
|
xi.push(start);
|
||||||
|
visited[start] = true;
|
||||||
|
|
||||||
|
while let Some(j) = stack.pop() {
|
||||||
|
for irow in self.data.column_row_indices(j) {
|
||||||
|
if irow > j && !visited[irow] {
|
||||||
|
stack.push(irow);
|
||||||
|
xi.push(irow);
|
||||||
|
visited[irow] = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -35,21 +35,41 @@ fn cs_cholesky() {
|
|||||||
1.0, 1.0, 0.0, 0.0, 2.0
|
1.0, 1.0, 0.0, 0.0, 2.0
|
||||||
);
|
);
|
||||||
a.fill_upper_triangle_with_lower_triangle();
|
a.fill_upper_triangle_with_lower_triangle();
|
||||||
|
// Test ::new, left_looking, and up_looking implementations.
|
||||||
test_cholesky(a);
|
test_cholesky(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
fn test_cholesky(a: Matrix5<f32>) {
|
fn test_cholesky(a: Matrix5<f32>) {
|
||||||
|
// Test ::new
|
||||||
|
test_cholesky_variant(a, 0);
|
||||||
|
// Test up-looking
|
||||||
|
test_cholesky_variant(a, 1);
|
||||||
|
// Test left-looking
|
||||||
|
test_cholesky_variant(a, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn test_cholesky_variant(a: Matrix5<f32>, option: usize) {
|
||||||
let cs_a: CsMatrix<_, _, _> = a.into();
|
let cs_a: CsMatrix<_, _, _> = a.into();
|
||||||
|
|
||||||
let chol_a = Cholesky::new(a).unwrap();
|
let chol_a = Cholesky::new(a).unwrap();
|
||||||
let chol_cs_a = CsCholesky::new(&cs_a);
|
let mut chol_cs_a;
|
||||||
let l = chol_a.l();
|
|
||||||
println!("{:?}", chol_cs_a.l());
|
|
||||||
let cs_l: Matrix5<_> = chol_cs_a.unwrap_l().unwrap().into();
|
|
||||||
|
|
||||||
println!("{}", l);
|
match option {
|
||||||
println!("{}", cs_l);
|
0 => chol_cs_a = CsCholesky::new(&cs_a),
|
||||||
|
1 => {
|
||||||
assert_relative_eq!(l, cs_l);
|
chol_cs_a = CsCholesky::new_symbolic(&cs_a);
|
||||||
|
chol_cs_a.decompose_up_looking(cs_a.data.values());
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
chol_cs_a = CsCholesky::new_symbolic(&cs_a);
|
||||||
|
chol_cs_a.decompose_left_looking(cs_a.data.values());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let l = chol_a.l();
|
||||||
|
let cs_l = chol_cs_a.unwrap_l().unwrap();
|
||||||
|
assert!(cs_l.is_sorted());
|
||||||
|
|
||||||
|
let cs_l_mat: Matrix5<_> = cs_l.into();
|
||||||
|
assert_relative_eq!(l, cs_l_mat);
|
||||||
}
|
}
|
||||||
|
@ -12,7 +12,8 @@ fn cs_from_to_matrix() {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let cs: CsMatrix<_, _, _> = m.into();
|
let cs: CsMatrix<_, _, _> = m.into();
|
||||||
let m2: Matrix4x5<_> = cs.into();
|
assert!(cs.is_sorted());
|
||||||
|
|
||||||
|
let m2: Matrix4x5<_> = cs.into();
|
||||||
assert_eq!(m2, m);
|
assert_eq!(m2, m);
|
||||||
}
|
}
|
||||||
|
@ -12,7 +12,11 @@ fn cs_transpose() {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let cs: CsMatrix<_, _, _> = m.into();
|
let cs: CsMatrix<_, _, _> = m.into();
|
||||||
let cs_transposed: Matrix5x4<_> = cs.transpose().into();
|
assert!(cs.is_sorted());
|
||||||
|
|
||||||
assert_eq!(cs_transposed, m.transpose())
|
let cs_transposed = cs.transpose();
|
||||||
|
assert!(cs_transposed.is_sorted());
|
||||||
|
|
||||||
|
let cs_transposed_mat: Matrix5x4<_> = cs_transposed.into();
|
||||||
|
assert_eq!(cs_transposed_mat, m.transpose())
|
||||||
}
|
}
|
||||||
|
@ -12,6 +12,7 @@ fn axpy_cs() {
|
|||||||
let cs: CsVector<_, _> = v2.into();
|
let cs: CsVector<_, _> = v2.into();
|
||||||
v1.axpy_cs(5.0, &cs, 10.0);
|
v1.axpy_cs(5.0, &cs, 10.0);
|
||||||
|
|
||||||
|
assert!(cs.is_sorted());
|
||||||
assert_eq!(v1, expected)
|
assert_eq!(v1, expected)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -36,6 +37,9 @@ fn cs_mat_mul() {
|
|||||||
|
|
||||||
let mul = &sm1 * &sm2;
|
let mul = &sm1 * &sm2;
|
||||||
|
|
||||||
|
assert!(sm1.is_sorted());
|
||||||
|
assert!(sm2.is_sorted());
|
||||||
|
assert!(mul.is_sorted());
|
||||||
assert_eq!(Matrix3x5::from(mul), m1 * m2);
|
assert_eq!(Matrix3x5::from(mul), m1 * m2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -59,7 +63,10 @@ fn cs_mat_add() {
|
|||||||
let sm1: CsMatrix<_, _, _> = m1.into();
|
let sm1: CsMatrix<_, _, _> = m1.into();
|
||||||
let sm2: CsMatrix<_, _, _> = m2.into();
|
let sm2: CsMatrix<_, _, _> = m2.into();
|
||||||
|
|
||||||
let mul = &sm1 + &sm2;
|
let sum = &sm1 + &sm2;
|
||||||
|
|
||||||
assert_eq!(Matrix4x5::from(mul), m1 + m2);
|
assert!(sm1.is_sorted());
|
||||||
|
assert!(sm2.is_sorted());
|
||||||
|
assert!(sum.is_sorted());
|
||||||
|
assert_eq!(Matrix4x5::from(sum), m1 + m2);
|
||||||
}
|
}
|
||||||
|
@ -79,15 +79,15 @@ fn cs_lower_triangular_solve_cs() {
|
|||||||
let cs_b8: CsVector<_, _> = Vector5::w().into();
|
let cs_b8: CsVector<_, _> = Vector5::w().into();
|
||||||
let cs_b9: CsVector<_, _> = Vector5::a().into();
|
let cs_b9: CsVector<_, _> = Vector5::a().into();
|
||||||
|
|
||||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b1).map(|v| v.into()), a.solve_lower_triangular(&b1));
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b1).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b1));
|
||||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b5).map(|v| v.into()), a.solve_lower_triangular(&b5));
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b5).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b5));
|
||||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b6).map(|v| v.into()), a.solve_lower_triangular(&b6));
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b6).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b6));
|
||||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b7).map(|v| v.into()), a.solve_lower_triangular(&b7));
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b7).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b7));
|
||||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b8).map(|v| v.into()), a.solve_lower_triangular(&b8));
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b8).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b8));
|
||||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b9).map(|v| v.into()), a.solve_lower_triangular(&b9));
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b9).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b9));
|
||||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b2).map(|v| v.into()), a.solve_lower_triangular(&b2));
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b2).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b2));
|
||||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b3).map(|v| v.into()), a.solve_lower_triangular(&b3));
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b3).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b3));
|
||||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b4).map(|v| v.into()), a.solve_lower_triangular(&b4));
|
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b4).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b4));
|
||||||
|
|
||||||
|
|
||||||
// Singular case.
|
// Singular case.
|
||||||
|
Loading…
Reference in New Issue
Block a user