Clean up CscCholesky
This commit is contained in:
parent
4b395523dd
commit
cd9c3baead
|
@ -4,10 +4,9 @@
|
||||||
use crate::pattern::SparsityPattern;
|
use crate::pattern::SparsityPattern;
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
use core::{mem, iter};
|
use core::{mem, iter};
|
||||||
use nalgebra::{U1, VectorN, Dynamic, Scalar, RealField};
|
use nalgebra::{Scalar, RealField};
|
||||||
use num_traits::Zero;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::ops::Add;
|
use std::fmt::{Display, Formatter};
|
||||||
|
|
||||||
pub struct CscSymbolicCholesky {
|
pub struct CscSymbolicCholesky {
|
||||||
// Pattern of the original matrix that was decomposed
|
// Pattern of the original matrix that was decomposed
|
||||||
|
@ -21,37 +20,11 @@ impl CscSymbolicCholesky {
|
||||||
pub fn factor(pattern: &Arc<SparsityPattern>) -> Self {
|
pub fn factor(pattern: &Arc<SparsityPattern>) -> Self {
|
||||||
assert_eq!(pattern.major_dim(), pattern.minor_dim(),
|
assert_eq!(pattern.major_dim(), pattern.minor_dim(),
|
||||||
"Major and minor dimensions must be the same (square matrix).");
|
"Major and minor dimensions must be the same (square matrix).");
|
||||||
|
let (l_pattern, u_pattern) = nonzero_pattern(&*pattern);
|
||||||
// TODO: Temporary stopgap solution to make things work until we can refactor
|
|
||||||
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
|
|
||||||
struct DummyVal;
|
|
||||||
impl Zero for DummyVal {
|
|
||||||
fn zero() -> Self {
|
|
||||||
DummyVal
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_zero(&self) -> bool {
|
|
||||||
true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Add<DummyVal> for DummyVal {
|
|
||||||
type Output = Self;
|
|
||||||
|
|
||||||
fn add(self, rhs: DummyVal) -> Self::Output {
|
|
||||||
rhs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let dummy_vals = vec![DummyVal; pattern.nnz()];
|
|
||||||
let dummy_csc = CscMatrix::try_from_pattern_and_values(Arc::clone(pattern), dummy_vals)
|
|
||||||
.unwrap();
|
|
||||||
let (l, u) = nonzero_pattern(&dummy_csc);
|
|
||||||
// TODO: Don't clone unnecessarily
|
|
||||||
Self {
|
Self {
|
||||||
m_pattern: Arc::clone(pattern),
|
m_pattern: Arc::clone(pattern),
|
||||||
l_pattern: l.pattern().as_ref().clone(),
|
l_pattern,
|
||||||
u_pattern: u.pattern().as_ref().clone()
|
u_pattern,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,10 +43,20 @@ pub struct CscCholesky<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||||
|
#[non_exhaustive]
|
||||||
pub enum CholeskyError {
|
pub enum CholeskyError {
|
||||||
|
/// The matrix is not positive definite.
|
||||||
|
NotPositiveDefinite,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Display for CholeskyError {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "Matrix is not positive definite")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for CholeskyError {}
|
||||||
|
|
||||||
impl<T: RealField> CscCholesky<T> {
|
impl<T: RealField> CscCholesky<T> {
|
||||||
|
|
||||||
pub fn factor(matrix: &CscMatrix<T>) -> Result<Self, CholeskyError> {
|
pub fn factor(matrix: &CscMatrix<T>) -> Result<Self, CholeskyError> {
|
||||||
|
@ -181,9 +164,7 @@ impl<T: RealField> CscCholesky<T> {
|
||||||
*self.work_x.get_unchecked_mut(p) = T::zero();
|
*self.work_x.get_unchecked_mut(p) = T::zero();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// self.ok = false;
|
return Err(CholeskyError::NotPositiveDefinite);
|
||||||
// TODO: Return indefinite error (i.e. encountered non-positive diagonal
|
|
||||||
unimplemented!()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -196,8 +177,8 @@ impl<T: RealField> CscCholesky<T> {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
fn reach<T>(
|
fn reach(
|
||||||
m: &CscMatrix<T>,
|
pattern: &SparsityPattern,
|
||||||
j: usize,
|
j: usize,
|
||||||
max_j: usize,
|
max_j: usize,
|
||||||
tree: &[usize],
|
tree: &[usize],
|
||||||
|
@ -211,7 +192,7 @@ fn reach<T>(
|
||||||
let mut tmp = Vec::new();
|
let mut tmp = Vec::new();
|
||||||
let mut res = Vec::new();
|
let mut res = Vec::new();
|
||||||
|
|
||||||
for &irow in m.col(j).row_indices() {
|
for &irow in pattern.lane(j) {
|
||||||
let mut curr = irow;
|
let mut curr = irow;
|
||||||
while curr != usize::max_value() && curr <= max_j && !marks[curr] {
|
while curr != usize::max_value() && curr <= max_j && !marks[curr] {
|
||||||
marks[curr] = true;
|
marks[curr] = true;
|
||||||
|
@ -223,57 +204,45 @@ fn reach<T>(
|
||||||
mem::swap(&mut tmp, &mut res);
|
mem::swap(&mut tmp, &mut res);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Is this right?
|
|
||||||
res.sort_unstable();
|
res.sort_unstable();
|
||||||
|
|
||||||
out.append(&mut res);
|
out.append(&mut res);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn nonzero_pattern<T: Scalar + Zero>(
|
fn nonzero_pattern(m: &SparsityPattern) -> (SparsityPattern, SparsityPattern) {
|
||||||
m: &CscMatrix<T>
|
|
||||||
) -> (CscMatrix<T>, CscMatrix<T>) {
|
|
||||||
// TODO: In order to stay as faithful as possible to the original implementation,
|
|
||||||
// we here return full matrices, whereas we actually only need to construct sparsity patterns
|
|
||||||
|
|
||||||
let etree = elimination_tree(m);
|
let etree = elimination_tree(m);
|
||||||
let (nrows, ncols) = (m.nrows(), m.ncols());
|
// Note: We assume CSC, therefore rows == minor and cols == major
|
||||||
|
let (nrows, ncols) = (m.minor_dim(), m.major_dim());
|
||||||
let mut rows = Vec::with_capacity(m.nnz());
|
let mut rows = Vec::with_capacity(m.nnz());
|
||||||
// TODO: Use a Vec here instead
|
let mut col_offsets = Vec::with_capacity(ncols + 1);
|
||||||
let mut cols = unsafe { VectorN::new_uninitialized_generic(Dynamic::new(nrows), U1) };
|
|
||||||
let mut marks = Vec::new();
|
let mut marks = Vec::new();
|
||||||
|
|
||||||
// NOTE: the following will actually compute the non-zero pattern of
|
// NOTE: the following will actually compute the non-zero pattern of
|
||||||
// the transpose of l.
|
// the transpose of l.
|
||||||
|
col_offsets.push(0);
|
||||||
for i in 0..nrows {
|
for i in 0..nrows {
|
||||||
cols[i] = rows.len();
|
|
||||||
reach(m, i, i, &etree, &mut marks, &mut rows);
|
reach(m, i, i, &etree, &mut marks, &mut rows);
|
||||||
|
col_offsets.push(rows.len());
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Get rid of this in particular
|
let u_pattern = SparsityPattern::try_from_offsets_and_indices(nrows, ncols, col_offsets, rows)
|
||||||
let mut vals = Vec::with_capacity(rows.len());
|
.unwrap();
|
||||||
unsafe {
|
|
||||||
vals.set_len(rows.len());
|
|
||||||
}
|
|
||||||
vals.shrink_to_fit();
|
|
||||||
|
|
||||||
// TODO: Remove this unnecessary conversion by using Vec throughout
|
// TODO: Avoid this transpose?
|
||||||
let mut cols: Vec<_> = cols.iter().cloned().collect();
|
let l_pattern = u_pattern.transpose();
|
||||||
cols.push(rows.len());
|
|
||||||
|
|
||||||
let u = CscMatrix::try_from_csc_data(nrows, ncols, cols, rows, vals).unwrap();
|
(l_pattern, u_pattern)
|
||||||
// TODO: Avoid this transpose
|
|
||||||
let l = u.transpose();
|
|
||||||
|
|
||||||
(l, u)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn elimination_tree<T>(m: &CscMatrix<T>) -> Vec<usize> {
|
fn elimination_tree(pattern: &SparsityPattern) -> Vec<usize> {
|
||||||
let nrows = m.nrows();
|
// Note: The pattern is assumed to of a CSC matrix, so the number of rows is
|
||||||
|
// given by the minor dimension
|
||||||
|
let nrows = pattern.minor_dim();
|
||||||
let mut forest: Vec<_> = iter::repeat(usize::max_value()).take(nrows).collect();
|
let mut forest: Vec<_> = iter::repeat(usize::max_value()).take(nrows).collect();
|
||||||
let mut ancestor: Vec<_> = iter::repeat(usize::max_value()).take(nrows).collect();
|
let mut ancestor: Vec<_> = iter::repeat(usize::max_value()).take(nrows).collect();
|
||||||
|
|
||||||
for k in 0..nrows {
|
for k in 0..nrows {
|
||||||
for &irow in m.col(k).row_indices() {
|
for &irow in pattern.lane(k) {
|
||||||
let mut i = irow;
|
let mut i = irow;
|
||||||
|
|
||||||
while i < k {
|
while i < k {
|
||||||
|
|
Loading…
Reference in New Issue