Avoid bound-checking on cholesky decomposition.

This commit is contained in:
sebcrozet 2018-10-30 17:45:59 +01:00
parent 9bf1d0280d
commit 50d0b64924
1 changed files with 50 additions and 39 deletions

View File

@ -21,7 +21,6 @@ where
// equal to `original_i.len()` at the end.
original_p: Vec<usize>,
original_i: Vec<usize>,
original_len: usize, // Number of elements on the numerical value vector of the original matrix.
// Decomposition result.
l: CsMatrix<N, D, D>,
// Used only for the pattern.
@ -63,7 +62,6 @@ where
CsCholesky {
original_p,
original_i: m.data.i.clone(),
original_len: m.data.i.len(),
l,
u,
ok: false,
@ -91,7 +89,7 @@ where
// Performs the numerical Cholesky decomposition given the set of numerical values.
pub fn decompose(&mut self, values: &[N]) -> bool {
assert!(
values.len() >= self.original_len,
values.len() >= self.original_i.len(),
"The set of values is too small."
);
@ -100,20 +98,22 @@ where
// Perform the decomposition.
for k in 0..self.l.nrows() {
unsafe {
// Scatter the k-th column of the original matrix with the values provided.
let column_range = self.original_p[k]..self.original_p[k + 1];
let column_range =
*self.original_p.get_unchecked(k)..*self.original_p.get_unchecked(k + 1);
self.work_x[k] = N::zero();
*self.work_x.vget_unchecked_mut(k) = N::zero();
for p in column_range.clone() {
let irow = self.original_i[p];
let irow = *self.original_i.get_unchecked(p);
if irow <= k {
self.work_x[irow] = values[p];
*self.work_x.vget_unchecked_mut(irow) = *values.get_unchecked(p);
}
}
let mut diag = self.work_x[k];
self.work_x[k] = N::zero();
let mut diag = *self.work_x.vget_unchecked(k);
*self.work_x.vget_unchecked_mut(k) = N::zero();
// Triangular solve.
for irow in self.u.data.column_row_indices(k) {
@ -121,18 +121,28 @@ where
continue;
}
let lki = self.work_x[irow] / self.l.data.vals[self.l.data.p[irow]];
self.work_x[irow] = N::zero();
let lki = *self.work_x.vget_unchecked(irow)
/ *self
.l
.data
.vals
.get_unchecked(*self.l.data.p.vget_unchecked(irow));
*self.work_x.vget_unchecked_mut(irow) = N::zero();
for p in self.l.data.p[irow] + 1..self.work_c[irow] {
self.work_x[self.l.data.i[p]] -= self.l.data.vals[p] * lki;
for p in
*self.l.data.p.vget_unchecked(irow) + 1..*self.work_c.vget_unchecked(irow)
{
*self
.work_x
.vget_unchecked_mut(*self.l.data.i.get_unchecked(p)) -=
*self.l.data.vals.get_unchecked(p) * lki;
}
diag -= lki * lki;
let p = self.work_c[irow];
self.work_c[irow] += 1;
self.l.data.i[p] = k;
self.l.data.vals[p] = lki;
let p = *self.work_c.vget_unchecked(irow);
*self.work_c.vget_unchecked_mut(irow) += 1;
*self.l.data.i.get_unchecked_mut(p) = k;
*self.l.data.vals.get_unchecked_mut(p) = lki;
}
if diag <= N::zero() {
@ -141,10 +151,11 @@ where
}
// Deal with the diagonal element.
let p = self.work_c[k];
self.work_c[k] += 1;
self.l.data.i[p] = k;
self.l.data.vals[p] = diag.sqrt();
let p = *self.work_c.vget_unchecked(k);
*self.work_c.vget_unchecked_mut(k) += 1;
*self.l.data.i.get_unchecked_mut(p) = k;
*self.l.data.vals.get_unchecked_mut(p) = diag.sqrt();
}
}
self.ok = true;