Avoid bounds checking in the Cholesky decomposition.

This commit is contained in:
sebcrozet 2018-10-30 17:45:59 +01:00
parent 9bf1d0280d
commit 50d0b64924
1 changed file with 50 additions and 39 deletions

View File

@ -21,7 +21,6 @@ where
// equal to `original_i.len()` at the end. // equal to `original_i.len()` at the end.
original_p: Vec<usize>, original_p: Vec<usize>,
original_i: Vec<usize>, original_i: Vec<usize>,
original_len: usize, // Number of elements on the numerical value vector of the original matrix.
// Decomposition result. // Decomposition result.
l: CsMatrix<N, D, D>, l: CsMatrix<N, D, D>,
// Used only for the pattern. // Used only for the pattern.
@ -63,7 +62,6 @@ where
CsCholesky { CsCholesky {
original_p, original_p,
original_i: m.data.i.clone(), original_i: m.data.i.clone(),
original_len: m.data.i.len(),
l, l,
u, u,
ok: false, ok: false,
@ -91,7 +89,7 @@ where
// Performs the numerical Cholesky decomposition given the set of numerical values. // Performs the numerical Cholesky decomposition given the set of numerical values.
pub fn decompose(&mut self, values: &[N]) -> bool { pub fn decompose(&mut self, values: &[N]) -> bool {
assert!( assert!(
values.len() >= self.original_len, values.len() >= self.original_i.len(),
"The set of values is too small." "The set of values is too small."
); );
@ -100,20 +98,22 @@ where
// Perform the decomposition. // Perform the decomposition.
for k in 0..self.l.nrows() { for k in 0..self.l.nrows() {
unsafe {
// Scatter the k-th column of the original matrix with the values provided. // Scatter the k-th column of the original matrix with the values provided.
let column_range = self.original_p[k]..self.original_p[k + 1]; let column_range =
*self.original_p.get_unchecked(k)..*self.original_p.get_unchecked(k + 1);
self.work_x[k] = N::zero(); *self.work_x.vget_unchecked_mut(k) = N::zero();
for p in column_range.clone() { for p in column_range.clone() {
let irow = self.original_i[p]; let irow = *self.original_i.get_unchecked(p);
if irow <= k { if irow <= k {
self.work_x[irow] = values[p]; *self.work_x.vget_unchecked_mut(irow) = *values.get_unchecked(p);
} }
} }
let mut diag = self.work_x[k]; let mut diag = *self.work_x.vget_unchecked(k);
self.work_x[k] = N::zero(); *self.work_x.vget_unchecked_mut(k) = N::zero();
// Triangular solve. // Triangular solve.
for irow in self.u.data.column_row_indices(k) { for irow in self.u.data.column_row_indices(k) {
@ -121,18 +121,28 @@ where
continue; continue;
} }
let lki = self.work_x[irow] / self.l.data.vals[self.l.data.p[irow]]; let lki = *self.work_x.vget_unchecked(irow)
self.work_x[irow] = N::zero(); / *self
.l
.data
.vals
.get_unchecked(*self.l.data.p.vget_unchecked(irow));
*self.work_x.vget_unchecked_mut(irow) = N::zero();
for p in self.l.data.p[irow] + 1..self.work_c[irow] { for p in
self.work_x[self.l.data.i[p]] -= self.l.data.vals[p] * lki; *self.l.data.p.vget_unchecked(irow) + 1..*self.work_c.vget_unchecked(irow)
{
*self
.work_x
.vget_unchecked_mut(*self.l.data.i.get_unchecked(p)) -=
*self.l.data.vals.get_unchecked(p) * lki;
} }
diag -= lki * lki; diag -= lki * lki;
let p = self.work_c[irow]; let p = *self.work_c.vget_unchecked(irow);
self.work_c[irow] += 1; *self.work_c.vget_unchecked_mut(irow) += 1;
self.l.data.i[p] = k; *self.l.data.i.get_unchecked_mut(p) = k;
self.l.data.vals[p] = lki; *self.l.data.vals.get_unchecked_mut(p) = lki;
} }
if diag <= N::zero() { if diag <= N::zero() {
@ -141,10 +151,11 @@ where
} }
// Deal with the diagonal element. // Deal with the diagonal element.
let p = self.work_c[k]; let p = *self.work_c.vget_unchecked(k);
self.work_c[k] += 1; *self.work_c.vget_unchecked_mut(k) += 1;
self.l.data.i[p] = k; *self.l.data.i.get_unchecked_mut(p) = k;
self.l.data.vals[p] = diag.sqrt(); *self.l.data.vals.get_unchecked_mut(p) = diag.sqrt();
}
} }
self.ok = true; self.ok = true;