Clean up CscCholesky

This commit is contained in:
Andreas Longva 2021-01-19 15:20:01 +01:00
parent 4b395523dd
commit cd9c3baead

View File

@ -4,10 +4,9 @@
use crate::pattern::SparsityPattern; use crate::pattern::SparsityPattern;
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use core::{mem, iter}; use core::{mem, iter};
use nalgebra::{U1, VectorN, Dynamic, Scalar, RealField}; use nalgebra::{Scalar, RealField};
use num_traits::Zero;
use std::sync::Arc; use std::sync::Arc;
use std::ops::Add; use std::fmt::{Display, Formatter};
pub struct CscSymbolicCholesky { pub struct CscSymbolicCholesky {
// Pattern of the original matrix that was decomposed // Pattern of the original matrix that was decomposed
@ -21,37 +20,11 @@ impl CscSymbolicCholesky {
pub fn factor(pattern: &Arc<SparsityPattern>) -> Self { pub fn factor(pattern: &Arc<SparsityPattern>) -> Self {
assert_eq!(pattern.major_dim(), pattern.minor_dim(), assert_eq!(pattern.major_dim(), pattern.minor_dim(),
"Major and minor dimensions must be the same (square matrix)."); "Major and minor dimensions must be the same (square matrix).");
let (l_pattern, u_pattern) = nonzero_pattern(&*pattern);
// TODO: Temporary stopgap solution to make things work until we can refactor
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
struct DummyVal;
impl Zero for DummyVal {
fn zero() -> Self {
DummyVal
}
fn is_zero(&self) -> bool {
true
}
}
impl Add<DummyVal> for DummyVal {
type Output = Self;
fn add(self, rhs: DummyVal) -> Self::Output {
rhs
}
}
let dummy_vals = vec![DummyVal; pattern.nnz()];
let dummy_csc = CscMatrix::try_from_pattern_and_values(Arc::clone(pattern), dummy_vals)
.unwrap();
let (l, u) = nonzero_pattern(&dummy_csc);
// TODO: Don't clone unnecessarily
Self { Self {
m_pattern: Arc::clone(pattern), m_pattern: Arc::clone(pattern),
l_pattern: l.pattern().as_ref().clone(), l_pattern,
u_pattern: u.pattern().as_ref().clone() u_pattern,
} }
} }
@ -70,10 +43,20 @@ pub struct CscCholesky<T> {
} }
#[derive(Debug, PartialEq, Eq, Clone)] #[derive(Debug, PartialEq, Eq, Clone)]
#[non_exhaustive]
pub enum CholeskyError { pub enum CholeskyError {
/// The matrix is not positive definite.
NotPositiveDefinite,
} }
impl Display for CholeskyError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "Matrix is not positive definite")
}
}
impl std::error::Error for CholeskyError {}
impl<T: RealField> CscCholesky<T> { impl<T: RealField> CscCholesky<T> {
pub fn factor(matrix: &CscMatrix<T>) -> Result<Self, CholeskyError> { pub fn factor(matrix: &CscMatrix<T>) -> Result<Self, CholeskyError> {
@ -181,9 +164,7 @@ impl<T: RealField> CscCholesky<T> {
*self.work_x.get_unchecked_mut(p) = T::zero(); *self.work_x.get_unchecked_mut(p) = T::zero();
} }
} else { } else {
// self.ok = false; return Err(CholeskyError::NotPositiveDefinite);
// TODO: Return indefinite error (i.e. encountered non-positive diagonal
unimplemented!()
} }
} }
} }
@ -196,8 +177,8 @@ impl<T: RealField> CscCholesky<T> {
fn reach<T>( fn reach(
m: &CscMatrix<T>, pattern: &SparsityPattern,
j: usize, j: usize,
max_j: usize, max_j: usize,
tree: &[usize], tree: &[usize],
@ -211,7 +192,7 @@ fn reach<T>(
let mut tmp = Vec::new(); let mut tmp = Vec::new();
let mut res = Vec::new(); let mut res = Vec::new();
for &irow in m.col(j).row_indices() { for &irow in pattern.lane(j) {
let mut curr = irow; let mut curr = irow;
while curr != usize::max_value() && curr <= max_j && !marks[curr] { while curr != usize::max_value() && curr <= max_j && !marks[curr] {
marks[curr] = true; marks[curr] = true;
@ -223,57 +204,45 @@ fn reach<T>(
mem::swap(&mut tmp, &mut res); mem::swap(&mut tmp, &mut res);
} }
// TODO: Is this right?
res.sort_unstable(); res.sort_unstable();
out.append(&mut res); out.append(&mut res);
} }
fn nonzero_pattern<T: Scalar + Zero>( fn nonzero_pattern(m: &SparsityPattern) -> (SparsityPattern, SparsityPattern) {
m: &CscMatrix<T>
) -> (CscMatrix<T>, CscMatrix<T>) {
// TODO: In order to stay as faithful as possible to the original implementation,
// we here return full matrices, whereas we actually only need to construct sparsity patterns
let etree = elimination_tree(m); let etree = elimination_tree(m);
let (nrows, ncols) = (m.nrows(), m.ncols()); // Note: We assume CSC, therefore rows == minor and cols == major
let (nrows, ncols) = (m.minor_dim(), m.major_dim());
let mut rows = Vec::with_capacity(m.nnz()); let mut rows = Vec::with_capacity(m.nnz());
// TODO: Use a Vec here instead let mut col_offsets = Vec::with_capacity(ncols + 1);
let mut cols = unsafe { VectorN::new_uninitialized_generic(Dynamic::new(nrows), U1) };
let mut marks = Vec::new(); let mut marks = Vec::new();
// NOTE: the following will actually compute the non-zero pattern of // NOTE: the following will actually compute the non-zero pattern of
// the transpose of l. // the transpose of l.
col_offsets.push(0);
for i in 0..nrows { for i in 0..nrows {
cols[i] = rows.len();
reach(m, i, i, &etree, &mut marks, &mut rows); reach(m, i, i, &etree, &mut marks, &mut rows);
col_offsets.push(rows.len());
} }
// TODO: Get rid of this in particular let u_pattern = SparsityPattern::try_from_offsets_and_indices(nrows, ncols, col_offsets, rows)
let mut vals = Vec::with_capacity(rows.len()); .unwrap();
unsafe {
vals.set_len(rows.len());
}
vals.shrink_to_fit();
// TODO: Remove this unnecessary conversion by using Vec throughout // TODO: Avoid this transpose?
let mut cols: Vec<_> = cols.iter().cloned().collect(); let l_pattern = u_pattern.transpose();
cols.push(rows.len());
let u = CscMatrix::try_from_csc_data(nrows, ncols, cols, rows, vals).unwrap(); (l_pattern, u_pattern)
// TODO: Avoid this transpose
let l = u.transpose();
(l, u)
} }
fn elimination_tree<T>(m: &CscMatrix<T>) -> Vec<usize> { fn elimination_tree(pattern: &SparsityPattern) -> Vec<usize> {
let nrows = m.nrows(); // Note: The pattern is assumed to of a CSC matrix, so the number of rows is
// given by the minor dimension
let nrows = pattern.minor_dim();
let mut forest: Vec<_> = iter::repeat(usize::max_value()).take(nrows).collect(); let mut forest: Vec<_> = iter::repeat(usize::max_value()).take(nrows).collect();
let mut ancestor: Vec<_> = iter::repeat(usize::max_value()).take(nrows).collect(); let mut ancestor: Vec<_> = iter::repeat(usize::max_value()).take(nrows).collect();
for k in 0..nrows { for k in 0..nrows {
for &irow in m.col(k).row_indices() { for &irow in pattern.lane(k) {
let mut i = irow; let mut i = irow;
while i < k { while i < k {