Avoid bounds checking in the Cholesky decomposition.
This commit is contained in:
parent
9bf1d0280d
commit
50d0b64924
|
@ -21,7 +21,6 @@ where
|
|||
// equal to `original_i.len()` at the end.
|
||||
original_p: Vec<usize>,
|
||||
original_i: Vec<usize>,
|
||||
original_len: usize, // Number of elements on the numerical value vector of the original matrix.
|
||||
// Decomposition result.
|
||||
l: CsMatrix<N, D, D>,
|
||||
// Used only for the pattern.
|
||||
|
@ -63,7 +62,6 @@ where
|
|||
CsCholesky {
|
||||
original_p,
|
||||
original_i: m.data.i.clone(),
|
||||
original_len: m.data.i.len(),
|
||||
l,
|
||||
u,
|
||||
ok: false,
|
||||
|
@ -91,7 +89,7 @@ where
|
|||
// Performs the numerical Cholesky decomposition given the set of numerical values.
|
||||
pub fn decompose(&mut self, values: &[N]) -> bool {
|
||||
assert!(
|
||||
values.len() >= self.original_len,
|
||||
values.len() >= self.original_i.len(),
|
||||
"The set of values is too small."
|
||||
);
|
||||
|
||||
|
@ -100,51 +98,64 @@ where
|
|||
|
||||
// Perform the decomposition.
|
||||
for k in 0..self.l.nrows() {
|
||||
// Scatter the k-th column of the original matrix with the values provided.
|
||||
let column_range = self.original_p[k]..self.original_p[k + 1];
|
||||
unsafe {
|
||||
// Scatter the k-th column of the original matrix with the values provided.
|
||||
let column_range =
|
||||
*self.original_p.get_unchecked(k)..*self.original_p.get_unchecked(k + 1);
|
||||
|
||||
self.work_x[k] = N::zero();
|
||||
for p in column_range.clone() {
|
||||
let irow = self.original_i[p];
|
||||
*self.work_x.vget_unchecked_mut(k) = N::zero();
|
||||
for p in column_range.clone() {
|
||||
let irow = *self.original_i.get_unchecked(p);
|
||||
|
||||
if irow <= k {
|
||||
self.work_x[irow] = values[p];
|
||||
}
|
||||
}
|
||||
|
||||
let mut diag = self.work_x[k];
|
||||
self.work_x[k] = N::zero();
|
||||
|
||||
// Triangular solve.
|
||||
for irow in self.u.data.column_row_indices(k) {
|
||||
if irow >= k {
|
||||
continue;
|
||||
if irow <= k {
|
||||
*self.work_x.vget_unchecked_mut(irow) = *values.get_unchecked(p);
|
||||
}
|
||||
}
|
||||
|
||||
let lki = self.work_x[irow] / self.l.data.vals[self.l.data.p[irow]];
|
||||
self.work_x[irow] = N::zero();
|
||||
let mut diag = *self.work_x.vget_unchecked(k);
|
||||
*self.work_x.vget_unchecked_mut(k) = N::zero();
|
||||
|
||||
for p in self.l.data.p[irow] + 1..self.work_c[irow] {
|
||||
self.work_x[self.l.data.i[p]] -= self.l.data.vals[p] * lki;
|
||||
// Triangular solve.
|
||||
for irow in self.u.data.column_row_indices(k) {
|
||||
if irow >= k {
|
||||
continue;
|
||||
}
|
||||
|
||||
let lki = *self.work_x.vget_unchecked(irow)
|
||||
/ *self
|
||||
.l
|
||||
.data
|
||||
.vals
|
||||
.get_unchecked(*self.l.data.p.vget_unchecked(irow));
|
||||
*self.work_x.vget_unchecked_mut(irow) = N::zero();
|
||||
|
||||
for p in
|
||||
*self.l.data.p.vget_unchecked(irow) + 1..*self.work_c.vget_unchecked(irow)
|
||||
{
|
||||
*self
|
||||
.work_x
|
||||
.vget_unchecked_mut(*self.l.data.i.get_unchecked(p)) -=
|
||||
*self.l.data.vals.get_unchecked(p) * lki;
|
||||
}
|
||||
|
||||
diag -= lki * lki;
|
||||
let p = *self.work_c.vget_unchecked(irow);
|
||||
*self.work_c.vget_unchecked_mut(irow) += 1;
|
||||
*self.l.data.i.get_unchecked_mut(p) = k;
|
||||
*self.l.data.vals.get_unchecked_mut(p) = lki;
|
||||
}
|
||||
|
||||
diag -= lki * lki;
|
||||
let p = self.work_c[irow];
|
||||
self.work_c[irow] += 1;
|
||||
self.l.data.i[p] = k;
|
||||
self.l.data.vals[p] = lki;
|
||||
}
|
||||
if diag <= N::zero() {
|
||||
self.ok = false;
|
||||
return false;
|
||||
}
|
||||
|
||||
if diag <= N::zero() {
|
||||
self.ok = false;
|
||||
return false;
|
||||
// Deal with the diagonal element.
|
||||
let p = *self.work_c.vget_unchecked(k);
|
||||
*self.work_c.vget_unchecked_mut(k) += 1;
|
||||
*self.l.data.i.get_unchecked_mut(p) = k;
|
||||
*self.l.data.vals.get_unchecked_mut(p) = diag.sqrt();
|
||||
}
|
||||
|
||||
// Deal with the diagonal element.
|
||||
let p = self.work_c[k];
|
||||
self.work_c[k] += 1;
|
||||
self.l.data.i[p] = k;
|
||||
self.l.data.vals[p] = diag.sqrt();
|
||||
}
|
||||
|
||||
self.ok = true;
|
||||
|
|
Loading…
Reference in New Issue