forked from M-Labs/nalgebra
Ensure the output of multiplication and triangular solve are sorted.
This commit is contained in:
parent
c3e8112d5e
commit
748cfeea66
@ -246,6 +246,32 @@ impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
|
||||
nrows.value() == ncols.value()
|
||||
}
|
||||
|
||||
/// Should always return `true`.
|
||||
///
|
||||
/// This method is generally used for debugging and should typically not be called in user code.
|
||||
/// This checks that the row inner indices of this matrix are sorted. It takes `O(n)` time,
|
||||
/// where n` is `self.len()`.
|
||||
/// All operations of CSC matrices on nalgebra assume, and will return, sorted indices.
|
||||
/// If at any time this `is_sorted` method returns `false`, then, something went wrong
|
||||
/// and an issue should be open on the nalgebra repository with details on how to reproduce
|
||||
/// this.
|
||||
pub fn is_sorted(&self) -> bool {
|
||||
for j in 0..self.ncols() {
|
||||
let mut curr = None;
|
||||
for idx in self.data.column_row_indices(j) {
|
||||
if let Some(curr) = curr {
|
||||
if idx <= curr {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
curr = Some(idx);
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
pub fn transpose(&self) -> CsMatrix<N, C, R>
|
||||
where
|
||||
DefaultAllocator: Allocator<usize, R>,
|
||||
@ -278,3 +304,41 @@ impl<N: Scalar, R: Dim, C: Dim, S: CsStorage<N, R, C>> CsMatrix<N, R, C, S> {
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: Scalar, R: Dim, C: Dim> CsMatrix<N, R, C>
|
||||
where
|
||||
DefaultAllocator: Allocator<usize, C>,
|
||||
{
|
||||
pub(crate) fn sort(&mut self)
|
||||
where
|
||||
DefaultAllocator: Allocator<N, R>,
|
||||
{
|
||||
// Size = R
|
||||
let nrows = self.data.shape().0;
|
||||
let mut workspace = unsafe { VectorN::new_uninitialized_generic(nrows, U1) };
|
||||
self.sort_with_workspace(workspace.as_mut_slice());
|
||||
}
|
||||
|
||||
pub(crate) fn sort_with_workspace(&mut self, workspace: &mut [N]) {
|
||||
assert!(
|
||||
workspace.len() >= self.nrows(),
|
||||
"Workspace must be able to hold at least self.nrows() elements."
|
||||
);
|
||||
|
||||
for j in 0..self.ncols() {
|
||||
// Scatter the row in the workspace.
|
||||
for (irow, val) in self.data.column_entries(j) {
|
||||
workspace[irow] = val;
|
||||
}
|
||||
|
||||
// Sort the index vector.
|
||||
let range = self.data.column_range(j);
|
||||
self.data.i[range.clone()].sort();
|
||||
|
||||
// Permute the values too.
|
||||
for (i, irow) in range.clone().zip(self.data.i[range].iter().cloned()) {
|
||||
self.data.vals[i] = workspace[irow];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -155,7 +155,7 @@ where
|
||||
}
|
||||
|
||||
// Performs the numerical Cholesky decomposition given the set of numerical values.
|
||||
pub fn decompose(&mut self, values: &[N]) -> bool {
|
||||
pub fn decompose_up_looking(&mut self, values: &[N]) -> bool {
|
||||
assert!(
|
||||
values.len() >= self.original_i.len(),
|
||||
"The set of values is too small."
|
||||
|
@ -172,7 +172,11 @@ where
|
||||
);
|
||||
}
|
||||
|
||||
for p in res.data.p[j]..nz {
|
||||
// Keep the output sorted.
|
||||
let range = res.data.p[j]..nz;
|
||||
res.data.i[range.clone()].sort();
|
||||
|
||||
for p in range {
|
||||
res.data.vals[p] = workspace[res.data.i[p]]
|
||||
}
|
||||
}
|
||||
|
@ -145,7 +145,10 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
|
||||
ShapeConstraint: SameNumberOfRows<D, D2>,
|
||||
{
|
||||
let mut reach = Vec::new();
|
||||
// We don't compute a postordered reach here because it will be sorted after anyway.
|
||||
self.lower_triangular_reach(b, &mut reach);
|
||||
// We sort the reach so the result matrix has sorted indices.
|
||||
reach.sort();
|
||||
let mut workspace = unsafe { VectorN::new_uninitialized_generic(b.data.shape().0, U1) };
|
||||
|
||||
for i in reach.iter().cloned() {
|
||||
@ -156,7 +159,7 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
|
||||
workspace[i] = val;
|
||||
}
|
||||
|
||||
for j in reach.iter().cloned().rev() {
|
||||
for j in reach.iter().cloned() {
|
||||
let mut column = self.data.column_entries(j);
|
||||
let mut diag_found = false;
|
||||
|
||||
@ -192,8 +195,12 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
|
||||
Some(result)
|
||||
}
|
||||
|
||||
fn lower_triangular_reach<D2: Dim, S2>(&self, b: &CsVector<N, D2, S2>, xi: &mut Vec<usize>)
|
||||
where
|
||||
// Computes the reachable, post-ordered, nodes from `b`.
|
||||
fn lower_triangular_reach_postordered<D2: Dim, S2>(
|
||||
&self,
|
||||
b: &CsVector<N, D2, S2>,
|
||||
xi: &mut Vec<usize>,
|
||||
) where
|
||||
S2: CsStorage<N, D2>,
|
||||
DefaultAllocator: Allocator<bool, D>,
|
||||
{
|
||||
@ -232,4 +239,43 @@ impl<N: Real, D: Dim, S: CsStorage<N, D, D>> CsMatrix<N, D, D, S> {
|
||||
xi.push(j)
|
||||
}
|
||||
}
|
||||
|
||||
// Computes the nodes reachable from `b` in an arbitrary order.
|
||||
fn lower_triangular_reach<D2: Dim, S2>(&self, b: &CsVector<N, D2, S2>, xi: &mut Vec<usize>)
|
||||
where
|
||||
S2: CsStorage<N, D2>,
|
||||
DefaultAllocator: Allocator<bool, D>,
|
||||
{
|
||||
let mut visited = VectorN::repeat_generic(self.data.shape().1, U1, false);
|
||||
let mut stack = Vec::new();
|
||||
|
||||
for irow in b.data.column_row_indices(0) {
|
||||
self.lower_triangular_bfs(irow, visited.as_mut_slice(), &mut stack, xi);
|
||||
}
|
||||
}
|
||||
|
||||
fn lower_triangular_bfs(
|
||||
&self,
|
||||
start: usize,
|
||||
visited: &mut [bool],
|
||||
stack: &mut Vec<usize>,
|
||||
xi: &mut Vec<usize>,
|
||||
) {
|
||||
if !visited[start] {
|
||||
stack.clear();
|
||||
stack.push(start);
|
||||
xi.push(start);
|
||||
visited[start] = true;
|
||||
|
||||
while let Some(j) = stack.pop() {
|
||||
for irow in self.data.column_row_indices(j) {
|
||||
if irow > j && !visited[irow] {
|
||||
stack.push(irow);
|
||||
xi.push(irow);
|
||||
visited[irow] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -35,21 +35,41 @@ fn cs_cholesky() {
|
||||
1.0, 1.0, 0.0, 0.0, 2.0
|
||||
);
|
||||
a.fill_upper_triangle_with_lower_triangle();
|
||||
// Test ::new, left_looking, and up_looking implementations.
|
||||
test_cholesky(a);
|
||||
}
|
||||
|
||||
|
||||
fn test_cholesky(a: Matrix5<f32>) {
|
||||
// Test ::new
|
||||
test_cholesky_variant(a, 0);
|
||||
// Test up-looking
|
||||
test_cholesky_variant(a, 1);
|
||||
// Test left-looking
|
||||
test_cholesky_variant(a, 2);
|
||||
}
|
||||
|
||||
fn test_cholesky_variant(a: Matrix5<f32>, option: usize) {
|
||||
let cs_a: CsMatrix<_, _, _> = a.into();
|
||||
|
||||
let chol_a = Cholesky::new(a).unwrap();
|
||||
let chol_cs_a = CsCholesky::new(&cs_a);
|
||||
let mut chol_cs_a;
|
||||
|
||||
match option {
|
||||
0 => chol_cs_a = CsCholesky::new(&cs_a),
|
||||
1 => {
|
||||
chol_cs_a = CsCholesky::new_symbolic(&cs_a);
|
||||
chol_cs_a.decompose_up_looking(cs_a.data.values());
|
||||
}
|
||||
_ => {
|
||||
chol_cs_a = CsCholesky::new_symbolic(&cs_a);
|
||||
chol_cs_a.decompose_left_looking(cs_a.data.values());
|
||||
}
|
||||
};
|
||||
|
||||
let l = chol_a.l();
|
||||
println!("{:?}", chol_cs_a.l());
|
||||
let cs_l: Matrix5<_> = chol_cs_a.unwrap_l().unwrap().into();
|
||||
let cs_l = chol_cs_a.unwrap_l().unwrap();
|
||||
assert!(cs_l.is_sorted());
|
||||
|
||||
println!("{}", l);
|
||||
println!("{}", cs_l);
|
||||
|
||||
assert_relative_eq!(l, cs_l);
|
||||
let cs_l_mat: Matrix5<_> = cs_l.into();
|
||||
assert_relative_eq!(l, cs_l_mat);
|
||||
}
|
||||
|
@ -12,7 +12,8 @@ fn cs_from_to_matrix() {
|
||||
);
|
||||
|
||||
let cs: CsMatrix<_, _, _> = m.into();
|
||||
let m2: Matrix4x5<_> = cs.into();
|
||||
assert!(cs.is_sorted());
|
||||
|
||||
let m2: Matrix4x5<_> = cs.into();
|
||||
assert_eq!(m2, m);
|
||||
}
|
||||
|
@ -12,7 +12,11 @@ fn cs_transpose() {
|
||||
);
|
||||
|
||||
let cs: CsMatrix<_, _, _> = m.into();
|
||||
let cs_transposed: Matrix5x4<_> = cs.transpose().into();
|
||||
assert!(cs.is_sorted());
|
||||
|
||||
assert_eq!(cs_transposed, m.transpose())
|
||||
let cs_transposed = cs.transpose();
|
||||
assert!(cs_transposed.is_sorted());
|
||||
|
||||
let cs_transposed_mat: Matrix5x4<_> = cs_transposed.into();
|
||||
assert_eq!(cs_transposed_mat, m.transpose())
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ fn axpy_cs() {
|
||||
let cs: CsVector<_, _> = v2.into();
|
||||
v1.axpy_cs(5.0, &cs, 10.0);
|
||||
|
||||
assert!(cs.is_sorted());
|
||||
assert_eq!(v1, expected)
|
||||
}
|
||||
|
||||
@ -36,6 +37,9 @@ fn cs_mat_mul() {
|
||||
|
||||
let mul = &sm1 * &sm2;
|
||||
|
||||
assert!(sm1.is_sorted());
|
||||
assert!(sm2.is_sorted());
|
||||
assert!(mul.is_sorted());
|
||||
assert_eq!(Matrix3x5::from(mul), m1 * m2);
|
||||
}
|
||||
|
||||
@ -59,7 +63,10 @@ fn cs_mat_add() {
|
||||
let sm1: CsMatrix<_, _, _> = m1.into();
|
||||
let sm2: CsMatrix<_, _, _> = m2.into();
|
||||
|
||||
let mul = &sm1 + &sm2;
|
||||
let sum = &sm1 + &sm2;
|
||||
|
||||
assert_eq!(Matrix4x5::from(mul), m1 + m2);
|
||||
assert!(sm1.is_sorted());
|
||||
assert!(sm2.is_sorted());
|
||||
assert!(sum.is_sorted());
|
||||
assert_eq!(Matrix4x5::from(sum), m1 + m2);
|
||||
}
|
||||
|
@ -79,15 +79,15 @@ fn cs_lower_triangular_solve_cs() {
|
||||
let cs_b8: CsVector<_, _> = Vector5::w().into();
|
||||
let cs_b9: CsVector<_, _> = Vector5::a().into();
|
||||
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b1).map(|v| v.into()), a.solve_lower_triangular(&b1));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b5).map(|v| v.into()), a.solve_lower_triangular(&b5));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b6).map(|v| v.into()), a.solve_lower_triangular(&b6));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b7).map(|v| v.into()), a.solve_lower_triangular(&b7));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b8).map(|v| v.into()), a.solve_lower_triangular(&b8));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b9).map(|v| v.into()), a.solve_lower_triangular(&b9));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b2).map(|v| v.into()), a.solve_lower_triangular(&b2));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b3).map(|v| v.into()), a.solve_lower_triangular(&b3));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b4).map(|v| v.into()), a.solve_lower_triangular(&b4));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b1).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b1));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b5).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b5));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b6).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b6));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b7).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b7));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b8).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b8));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b9).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b9));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b2).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b2));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b3).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b3));
|
||||
assert_eq!(cs_a.solve_lower_triangular_cs(&cs_b4).map(|v| { assert!(v.is_sorted()); v.into() }), a.solve_lower_triangular(&b4));
|
||||
|
||||
|
||||
// Singular case.
|
||||
|
Loading…
Reference in New Issue
Block a user