Clean up CscCholesky
This commit is contained in:
parent
4b395523dd
commit
cd9c3baead
|
@ -4,10 +4,9 @@
|
|||
use crate::pattern::SparsityPattern;
|
||||
use crate::csc::CscMatrix;
|
||||
use core::{mem, iter};
|
||||
use nalgebra::{U1, VectorN, Dynamic, Scalar, RealField};
|
||||
use num_traits::Zero;
|
||||
use nalgebra::{Scalar, RealField};
|
||||
use std::sync::Arc;
|
||||
use std::ops::Add;
|
||||
use std::fmt::{Display, Formatter};
|
||||
|
||||
pub struct CscSymbolicCholesky {
|
||||
// Pattern of the original matrix that was decomposed
|
||||
|
@ -21,37 +20,11 @@ impl CscSymbolicCholesky {
|
|||
pub fn factor(pattern: &Arc<SparsityPattern>) -> Self {
|
||||
assert_eq!(pattern.major_dim(), pattern.minor_dim(),
|
||||
"Major and minor dimensions must be the same (square matrix).");
|
||||
|
||||
// TODO: Temporary stopgap solution to make things work until we can refactor
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
|
||||
struct DummyVal;
|
||||
impl Zero for DummyVal {
|
||||
fn zero() -> Self {
|
||||
DummyVal
|
||||
}
|
||||
|
||||
fn is_zero(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<DummyVal> for DummyVal {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: DummyVal) -> Self::Output {
|
||||
rhs
|
||||
}
|
||||
}
|
||||
|
||||
let dummy_vals = vec![DummyVal; pattern.nnz()];
|
||||
let dummy_csc = CscMatrix::try_from_pattern_and_values(Arc::clone(pattern), dummy_vals)
|
||||
.unwrap();
|
||||
let (l, u) = nonzero_pattern(&dummy_csc);
|
||||
// TODO: Don't clone unnecessarily
|
||||
let (l_pattern, u_pattern) = nonzero_pattern(&*pattern);
|
||||
Self {
|
||||
m_pattern: Arc::clone(pattern),
|
||||
l_pattern: l.pattern().as_ref().clone(),
|
||||
u_pattern: u.pattern().as_ref().clone()
|
||||
l_pattern,
|
||||
u_pattern,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -70,10 +43,20 @@ pub struct CscCholesky<T> {
|
|||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
#[non_exhaustive]
|
||||
pub enum CholeskyError {
|
||||
|
||||
/// The matrix is not positive definite.
|
||||
NotPositiveDefinite,
|
||||
}
|
||||
|
||||
impl Display for CholeskyError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Matrix is not positive definite")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for CholeskyError {}
|
||||
|
||||
impl<T: RealField> CscCholesky<T> {
|
||||
|
||||
pub fn factor(matrix: &CscMatrix<T>) -> Result<Self, CholeskyError> {
|
||||
|
@ -181,9 +164,7 @@ impl<T: RealField> CscCholesky<T> {
|
|||
*self.work_x.get_unchecked_mut(p) = T::zero();
|
||||
}
|
||||
} else {
|
||||
// self.ok = false;
|
||||
// TODO: Return indefinite error (i.e. encountered non-positive diagonal
|
||||
unimplemented!()
|
||||
return Err(CholeskyError::NotPositiveDefinite);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -196,8 +177,8 @@ impl<T: RealField> CscCholesky<T> {
|
|||
|
||||
|
||||
|
||||
fn reach<T>(
|
||||
m: &CscMatrix<T>,
|
||||
fn reach(
|
||||
pattern: &SparsityPattern,
|
||||
j: usize,
|
||||
max_j: usize,
|
||||
tree: &[usize],
|
||||
|
@ -211,7 +192,7 @@ fn reach<T>(
|
|||
let mut tmp = Vec::new();
|
||||
let mut res = Vec::new();
|
||||
|
||||
for &irow in m.col(j).row_indices() {
|
||||
for &irow in pattern.lane(j) {
|
||||
let mut curr = irow;
|
||||
while curr != usize::max_value() && curr <= max_j && !marks[curr] {
|
||||
marks[curr] = true;
|
||||
|
@ -223,57 +204,45 @@ fn reach<T>(
|
|||
mem::swap(&mut tmp, &mut res);
|
||||
}
|
||||
|
||||
// TODO: Is this right?
|
||||
res.sort_unstable();
|
||||
|
||||
out.append(&mut res);
|
||||
}
|
||||
|
||||
fn nonzero_pattern<T: Scalar + Zero>(
|
||||
m: &CscMatrix<T>
|
||||
) -> (CscMatrix<T>, CscMatrix<T>) {
|
||||
// TODO: In order to stay as faithful as possible to the original implementation,
|
||||
// we here return full matrices, whereas we actually only need to construct sparsity patterns
|
||||
|
||||
fn nonzero_pattern(m: &SparsityPattern) -> (SparsityPattern, SparsityPattern) {
|
||||
let etree = elimination_tree(m);
|
||||
let (nrows, ncols) = (m.nrows(), m.ncols());
|
||||
// Note: We assume CSC, therefore rows == minor and cols == major
|
||||
let (nrows, ncols) = (m.minor_dim(), m.major_dim());
|
||||
let mut rows = Vec::with_capacity(m.nnz());
|
||||
// TODO: Use a Vec here instead
|
||||
let mut cols = unsafe { VectorN::new_uninitialized_generic(Dynamic::new(nrows), U1) };
|
||||
let mut col_offsets = Vec::with_capacity(ncols + 1);
|
||||
let mut marks = Vec::new();
|
||||
|
||||
// NOTE: the following will actually compute the non-zero pattern of
|
||||
// the transpose of l.
|
||||
col_offsets.push(0);
|
||||
for i in 0..nrows {
|
||||
cols[i] = rows.len();
|
||||
reach(m, i, i, &etree, &mut marks, &mut rows);
|
||||
col_offsets.push(rows.len());
|
||||
}
|
||||
|
||||
// TODO: Get rid of this in particular
|
||||
let mut vals = Vec::with_capacity(rows.len());
|
||||
unsafe {
|
||||
vals.set_len(rows.len());
|
||||
}
|
||||
vals.shrink_to_fit();
|
||||
let u_pattern = SparsityPattern::try_from_offsets_and_indices(nrows, ncols, col_offsets, rows)
|
||||
.unwrap();
|
||||
|
||||
// TODO: Remove this unnecessary conversion by using Vec throughout
|
||||
let mut cols: Vec<_> = cols.iter().cloned().collect();
|
||||
cols.push(rows.len());
|
||||
// TODO: Avoid this transpose?
|
||||
let l_pattern = u_pattern.transpose();
|
||||
|
||||
let u = CscMatrix::try_from_csc_data(nrows, ncols, cols, rows, vals).unwrap();
|
||||
// TODO: Avoid this transpose
|
||||
let l = u.transpose();
|
||||
|
||||
(l, u)
|
||||
(l_pattern, u_pattern)
|
||||
}
|
||||
|
||||
fn elimination_tree<T>(m: &CscMatrix<T>) -> Vec<usize> {
|
||||
let nrows = m.nrows();
|
||||
fn elimination_tree(pattern: &SparsityPattern) -> Vec<usize> {
|
||||
// Note: The pattern is assumed to of a CSC matrix, so the number of rows is
|
||||
// given by the minor dimension
|
||||
let nrows = pattern.minor_dim();
|
||||
let mut forest: Vec<_> = iter::repeat(usize::max_value()).take(nrows).collect();
|
||||
let mut ancestor: Vec<_> = iter::repeat(usize::max_value()).take(nrows).collect();
|
||||
|
||||
for k in 0..nrows {
|
||||
for &irow in m.col(k).row_indices() {
|
||||
for &irow in pattern.lane(k) {
|
||||
let mut i = irow;
|
||||
|
||||
while i < k {
|
||||
|
|
Loading…
Reference in New Issue