Ensure the output of multiplication and triangular solve are sorted.

This commit is contained in:
sebcrozet 2018-11-05 16:38:43 +01:00
parent c3e8112d5e
commit 748cfeea66
9 changed files with 173 additions and 27 deletions

View File

@ -246,6 +246,32 @@ impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
nrows.value() == ncols.value() nrows.value() == ncols.value()
} }
/// Checks that, within every column, the row indices are strictly increasing.
///
/// Should always return `true`.
///
/// This method is a debugging aid and should typically not be called in user code.
/// It takes `O(n)` time, where `n` is `self.len()`.
///
/// All operations on CSC matrices in nalgebra assume, and will return, sorted indices.
/// If this method ever returns `false`, something went wrong and an issue should be
/// opened on the nalgebra repository with details on how to reproduce it.
///
/// Note that the comparison is strict: a column that stores the same row index twice
/// is reported as not sorted.
pub fn is_sorted(&self) -> bool {
    for col in 0..self.ncols() {
        // Row index seen last in this column, `None` until the first entry is visited.
        let mut previous = None;

        for row in self.data.column_row_indices(col) {
            match previous {
                // An index that does not strictly exceed its predecessor breaks the order.
                Some(last) if row <= last => return false,
                _ => previous = Some(row),
            }
        }
    }

    true
}
pub fn transpose(&self) -> CsMatrix<N, C, R> pub fn transpose(&self) -> CsMatrix<N, C, R>
where where
DefaultAllocator: Allocator<usize, R>, DefaultAllocator: Allocator<usize, R>,
@ -278,3 +304,41 @@ impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
res res
} }
} }
impl<N: Scalar, R: Dim, C: Dim> CsMatrix<N, R, C>
where
    DefaultAllocator: Allocator<usize, C>,
{
    /// Sorts the row indices of every column in increasing order, moving the stored
    /// values along with their indices.
    ///
    /// Allocates a temporary workspace with one slot per row and delegates to
    /// `sort_with_workspace`.
    pub(crate) fn sort(&mut self)
    where
        DefaultAllocator: Allocator<N, R>,
    {
        // Size = R
        let nrows = self.data.shape().0;
        // SAFETY: the workspace starts uninitialized. `sort_with_workspace` scatters the
        // entries of a column into it before gathering them back for that same column,
        // so a slot is only read after it has been written.
        // NOTE(review): this relies on `column_entries(j)` yielding exactly the row
        // indices stored in `i[column_range(j)]` -- confirm against the storage impl.
        let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows, U1) };
        self.sort_with_workspace(workspace.as_mut_slice());
    }

    /// Same as `sort`, but uses a caller-provided `workspace` instead of allocating one.
    ///
    /// The initial contents of `workspace` are irrelevant; they are overwritten.
    ///
    /// # Panics
    ///
    /// Panics if `workspace` holds fewer than `self.nrows()` elements.
    pub(crate) fn sort_with_workspace(&mut self, workspace: &mut [N]) {
        assert!(
            workspace.len() >= self.nrows(),
            "Workspace must be able to hold at least self.nrows() elements."
        );

        for j in 0..self.ncols() {
            // Scatter the column's values into the workspace, addressed by row index.
            for (irow, val) in self.data.column_entries(j) {
                workspace[irow] = val;
            }

            // Sort the index vector. Equal `usize` keys are indistinguishable, so the
            // unstable sort yields the same result as a stable one without needing a
            // temporary buffer.
            let range = self.data.column_range(j);
            self.data.i[range.clone()].sort_unstable();

            // Gather the values back so that each one follows its (now sorted) row index.
            for p in range {
                self.data.vals[p] = workspace[self.data.i[p]];
            }
        }
    }
}

View File

@ -155,7 +155,7 @@ where
} }
// Performs the numerical Cholesky decomposition given the set of numerical values. // Performs the numerical Cholesky decomposition given the set of numerical values.
pub fn decompose(&mut self, values: &[N]) -> bool { pub fn decompose_up_looking(&mut self, values: &[N]) -> bool {
assert!( assert!(
values.len() >= self.original_i.len(), values.len() >= self.original_i.len(),
"The set of values is too small." "The set of values is too small."

View File

@ -172,7 +172,11 @@ where
); );
} }
for p in res.data.p[j]..nz { // Keep the output sorted.
let range = res.data.p[j]..nz;
res.data.i[range.clone()].sort();
for p in range {
res.data.vals[p] = workspace[res.data.i[p]] res.data.vals[p] = workspace[res.data.i[p]]
} }
} }

View File

@ -145,7 +145,10 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
ShapeConstraint: SameNumberOfRows<D, D2>, ShapeConstraint: SameNumberOfRows<D, D2>,
{ {
let mut reach = Vec::new(); let mut reach = Vec::new();
// We don't compute a postordered reach here because it will be sorted after anyway.
self.lower_triangular_reach(b, &mut reach); self.lower_triangular_reach(b, &mut reach);
// We sort the reach so the result matrix has sorted indices.
reach.sort();
let mut workspace = unsafe { VectorN::new_uninitialized_generic(b.data.shape().0, U1) }; let mut workspace = unsafe { VectorN::new_uninitialized_generic(b.data.shape().0, U1) };
for i in reach.iter().cloned() { for i in reach.iter().cloned() {
@ -156,7 +159,7 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
workspace[i] = val; workspace[i] = val;
} }
for j in reach.iter().cloned().rev() { for j in reach.iter().cloned() {
let mut column = self.data.column_entries(j); let mut column = self.data.column_entries(j);
let mut diag_found = false; let mut diag_found = false;
@ -192,8 +195,12 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
Some(result) Some(result)
} }
fn lower_triangular_reach<D2: Dim, S2>(&self, b: &CsVector<N, D2, S2>, xi: &mut Vec<usize>) // Computes the reachable, post-ordered, nodes from `b`.
where fn lower_triangular_reach_postordered<D2: Dim, S2>(
&self,
b: &CsVector<N, D2, S2>,
xi: &mut Vec<usize>,
) where
S2: CsStorage<N, D2>, S2: CsStorage<N, D2>,
DefaultAllocator: Allocator<bool, D>, DefaultAllocator: Allocator<bool, D>,
{ {
@ -232,4 +239,43 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
xi.push(j) xi.push(j)
} }
} }
// Computes the nodes reachable from `b` in an arbitrary order.
//
// Every row index stored in the first (index 0) column of `b` is used as a starting
// node. Newly reached nodes are appended to `xi`; a node is appended at most once
// per call because the visited markers are shared between all starting nodes.
fn lower_triangular_reach<D2: Dim, S2>(&self, b: &CsVector<N, D2, S2>, xi: &mut Vec<usize>)
where
    S2: CsStorage<N, D2>,
    DefaultAllocator: Allocator<bool, D>,
{
    // One "already seen" flag per column of `self`, all cleared initially.
    let mut seen = VectorN::repeat_generic(self.data.shape().1, U1, false);
    let seen_flags = seen.as_mut_slice();
    // Scratch stack reused by every traversal to avoid reallocating.
    let mut pending = Vec::new();

    for seed in b.data.column_row_indices(0) {
        self.lower_triangular_bfs(seed, seen_flags, &mut pending, xi);
    }
}
// Appends to `xi` every not-yet-visited node reachable from `start`, marking each
// one in `visited` as it is discovered.
//
// Only entries strictly below the diagonal (row index greater than the column index)
// are followed. `stack` is scratch storage: it is cleared on entry and empty again on
// return. Despite the name, pending nodes are taken from the end of `stack`, so the
// visiting order is last-in-first-out.
fn lower_triangular_bfs(
    &self,
    start: usize,
    visited: &mut [bool],
    stack: &mut Vec<usize>,
    xi: &mut Vec<usize>,
) {
    // Nothing to do if an earlier traversal already reached this node.
    if visited[start] {
        return;
    }

    visited[start] = true;
    xi.push(start);
    stack.clear();
    stack.push(start);

    while let Some(col) = stack.pop() {
        for row in self.data.column_row_indices(col) {
            // Skip the diagonal, the upper part, and nodes that were already reached.
            if row <= col || visited[row] {
                continue;
            }

            visited[row] = true;
            xi.push(row);
            stack.push(row);
        }
    }
}
} }

View File

@ -35,21 +35,41 @@ fn cs_cholesky() {
1.0, 1.0, 0.0, 0.0, 2.0 1.0, 1.0, 0.0, 0.0, 2.0
); );
a.fill_upper_triangle_with_lower_triangle(); a.fill_upper_triangle_with_lower_triangle();
// Test ::new, left_looking, and up_looking implementations.
test_cholesky(a); test_cholesky(a);
} }
fn test_cholesky(a: Matrix5<f32>) { fn test_cholesky(a: Matrix5<f32>) {
// Test ::new
test_cholesky_variant(a, 0);
// Test up-looking
test_cholesky_variant(a, 1);
// Test left-looking
test_cholesky_variant(a, 2);
}
fn test_cholesky_variant(a: Matrix5<f32>, option: usize) {
let cs_a: CsMatrix<_, _, _> = a.into(); let cs_a: CsMatrix<_, _, _> = a.into();
let chol_a = Cholesky::new(a).unwrap(); let chol_a = Cholesky::new(a).unwrap();
let chol_cs_a = CsCholesky::new(&cs_a); let mut chol_cs_a;
let l = chol_a.l();
println!("{:?}", chol_cs_a.l());
let cs_l: Matrix5<_> = chol_cs_a.unwrap_l().unwrap().into();
println!("{}", l); match option {
println!("{}", cs_l); 0 => chol_cs_a = CsCholesky::new(&cs_a),
1 => {
assert_relative_eq!(l, cs_l); chol_cs_a = CsCholesky::new_symbolic(&cs_a);
chol_cs_a.decompose_up_looking(cs_a.data.values());
}
_ => {
chol_cs_a = CsCholesky::new_symbolic(&cs_a);
chol_cs_a.decompose_left_looking(cs_a.data.values());
}
};
let l = chol_a.l();
let cs_l = chol_cs_a.unwrap_l().unwrap();
assert!(cs_l.is_sorted());
let cs_l_mat: Matrix5<_> = cs_l.into();
assert_relative_eq!(l, cs_l_mat);
} }

View File

@ -12,7 +12,8 @@ fn cs_from_to_matrix() {
); );
let cs: CsMatrix<_, _, _> = m.into(); let cs: CsMatrix<_, _, _> = m.into();
let m2: Matrix4x5<_> = cs.into(); assert!(cs.is_sorted());
let m2: Matrix4x5<_> = cs.into();
assert_eq!(m2, m); assert_eq!(m2, m);
} }

View File

@ -12,7 +12,11 @@ fn cs_transpose() {
); );
let cs: CsMatrix<_, _, _> = m.into(); let cs: CsMatrix<_, _, _> = m.into();
let cs_transposed: Matrix5x4<_> = cs.transpose().into(); assert!(cs.is_sorted());
assert_eq!(cs_transposed, m.transpose()) let cs_transposed = cs.transpose();
assert!(cs_transposed.is_sorted());
let cs_transposed_mat: Matrix5x4<_> = cs_transposed.into();
assert_eq!(cs_transposed_mat, m.transpose())
} }

View File

@ -12,6 +12,7 @@ fn axpy_cs() {
let cs: CsVector<_, _> = v2.into(); let cs: CsVector<_, _> = v2.into();
v1.axpy_cs(5.0, &cs, 10.0); v1.axpy_cs(5.0, &cs, 10.0);
assert!(cs.is_sorted());
assert_eq!(v1, expected) assert_eq!(v1, expected)
} }
@ -36,6 +37,9 @@ fn cs_mat_mul() {
let mul = &sm1 * &sm2; let mul = &sm1 * &sm2;
assert!(sm1.is_sorted());
assert!(sm2.is_sorted());
assert!(mul.is_sorted());
assert_eq!(Matrix3x5::from(mul), m1 * m2); assert_eq!(Matrix3x5::from(mul), m1 * m2);
} }
@ -59,7 +63,10 @@ fn cs_mat_add() {
let sm1: CsMatrix<_, _, _> = m1.into(); let sm1: CsMatrix<_, _, _> = m1.into();
let sm2: CsMatrix<_, _, _> = m2.into(); let sm2: CsMatrix<_, _, _> = m2.into();
let mul = &sm1 + &sm2; let sum = &sm1 + &sm2;
assert_eq!(Matrix4x5::from(mul), m1 + m2); assert!(sm1.is_sorted());
assert!(sm2.is_sorted());
assert!(sum.is_sorted());
assert_eq!(Matrix4x5::from(sum), m1 + m2);
} }

View File

@ -79,15 +79,15 @@ fn cs_lower_triangular_solve_cs() {
let cs_b8: CsVector<_, _> = Vector5::w().into(); let cs_b8: CsVector<_, _> = Vector5::w().into();
let cs_b9: CsVector<_, _> = Vector5::a().into(); let cs_b9: CsVector<_, _> = Vector5::a().into();
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b1).map(|v| v.into()), a.solve_lower_triangular(&b1)); assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b1).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b1));
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b5).map(|v| v.into()), a.solve_lower_triangular(&b5)); assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b5).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b5));
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b6).map(|v| v.into()), a.solve_lower_triangular(&b6)); assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b6).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b6));
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b7).map(|v| v.into()), a.solve_lower_triangular(&b7)); assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b7).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b7));
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b8).map(|v| v.into()), a.solve_lower_triangular(&b8)); assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b8).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b8));
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b9).map(|v| v.into()), a.solve_lower_triangular(&b9)); assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b9).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b9));
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b2).map(|v| v.into()), a.solve_lower_triangular(&b2)); assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b2).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b2));
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b3).map(|v| v.into()), a.solve_lower_triangular(&b3)); assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b3).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b3));
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b4).map(|v| v.into()), a.solve_lower_triangular(&b4)); assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b4).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b4));
// Singular case. // Singular case.