Avoid bound-checking on cholesky decomposition.
This commit is contained in:
parent
9bf1d0280d
commit
50d0b64924
|
@ -21,7 +21,6 @@ where
|
||||||
// equal to `original_i.len()` at the end.
|
// equal to `original_i.len()` at the end.
|
||||||
original_p: Vec<usize>,
|
original_p: Vec<usize>,
|
||||||
original_i: Vec<usize>,
|
original_i: Vec<usize>,
|
||||||
original_len: usize, // Number of elements on the numerical value vector of the original matrix.
|
|
||||||
// Decomposition result.
|
// Decomposition result.
|
||||||
l: CsMatrix<N, D, D>,
|
l: CsMatrix<N, D, D>,
|
||||||
// Used only for the pattern.
|
// Used only for the pattern.
|
||||||
|
@ -63,7 +62,6 @@ where
|
||||||
CsCholesky {
|
CsCholesky {
|
||||||
original_p,
|
original_p,
|
||||||
original_i: m.data.i.clone(),
|
original_i: m.data.i.clone(),
|
||||||
original_len: m.data.i.len(),
|
|
||||||
l,
|
l,
|
||||||
u,
|
u,
|
||||||
ok: false,
|
ok: false,
|
||||||
|
@ -91,7 +89,7 @@ where
|
||||||
// Performs the numerical Cholesky decomposition given the set of numerical values.
|
// Performs the numerical Cholesky decomposition given the set of numerical values.
|
||||||
pub fn decompose(&mut self, values: &[N]) -> bool {
|
pub fn decompose(&mut self, values: &[N]) -> bool {
|
||||||
assert!(
|
assert!(
|
||||||
values.len() >= self.original_len,
|
values.len() >= self.original_i.len(),
|
||||||
"The set of values is too small."
|
"The set of values is too small."
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -100,20 +98,22 @@ where
|
||||||
|
|
||||||
// Perform the decomposition.
|
// Perform the decomposition.
|
||||||
for k in 0..self.l.nrows() {
|
for k in 0..self.l.nrows() {
|
||||||
|
unsafe {
|
||||||
// Scatter the k-th column of the original matrix with the values provided.
|
// Scatter the k-th column of the original matrix with the values provided.
|
||||||
let column_range = self.original_p[k]..self.original_p[k + 1];
|
let column_range =
|
||||||
|
*self.original_p.get_unchecked(k)..*self.original_p.get_unchecked(k + 1);
|
||||||
|
|
||||||
self.work_x[k] = N::zero();
|
*self.work_x.vget_unchecked_mut(k) = N::zero();
|
||||||
for p in column_range.clone() {
|
for p in column_range.clone() {
|
||||||
let irow = self.original_i[p];
|
let irow = *self.original_i.get_unchecked(p);
|
||||||
|
|
||||||
if irow <= k {
|
if irow <= k {
|
||||||
self.work_x[irow] = values[p];
|
*self.work_x.vget_unchecked_mut(irow) = *values.get_unchecked(p);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut diag = self.work_x[k];
|
let mut diag = *self.work_x.vget_unchecked(k);
|
||||||
self.work_x[k] = N::zero();
|
*self.work_x.vget_unchecked_mut(k) = N::zero();
|
||||||
|
|
||||||
// Triangular solve.
|
// Triangular solve.
|
||||||
for irow in self.u.data.column_row_indices(k) {
|
for irow in self.u.data.column_row_indices(k) {
|
||||||
|
@ -121,18 +121,28 @@ where
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let lki = self.work_x[irow] / self.l.data.vals[self.l.data.p[irow]];
|
let lki = *self.work_x.vget_unchecked(irow)
|
||||||
self.work_x[irow] = N::zero();
|
/ *self
|
||||||
|
.l
|
||||||
|
.data
|
||||||
|
.vals
|
||||||
|
.get_unchecked(*self.l.data.p.vget_unchecked(irow));
|
||||||
|
*self.work_x.vget_unchecked_mut(irow) = N::zero();
|
||||||
|
|
||||||
for p in self.l.data.p[irow] + 1..self.work_c[irow] {
|
for p in
|
||||||
self.work_x[self.l.data.i[p]] -= self.l.data.vals[p] * lki;
|
*self.l.data.p.vget_unchecked(irow) + 1..*self.work_c.vget_unchecked(irow)
|
||||||
|
{
|
||||||
|
*self
|
||||||
|
.work_x
|
||||||
|
.vget_unchecked_mut(*self.l.data.i.get_unchecked(p)) -=
|
||||||
|
*self.l.data.vals.get_unchecked(p) * lki;
|
||||||
}
|
}
|
||||||
|
|
||||||
diag -= lki * lki;
|
diag -= lki * lki;
|
||||||
let p = self.work_c[irow];
|
let p = *self.work_c.vget_unchecked(irow);
|
||||||
self.work_c[irow] += 1;
|
*self.work_c.vget_unchecked_mut(irow) += 1;
|
||||||
self.l.data.i[p] = k;
|
*self.l.data.i.get_unchecked_mut(p) = k;
|
||||||
self.l.data.vals[p] = lki;
|
*self.l.data.vals.get_unchecked_mut(p) = lki;
|
||||||
}
|
}
|
||||||
|
|
||||||
if diag <= N::zero() {
|
if diag <= N::zero() {
|
||||||
|
@ -141,10 +151,11 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deal with the diagonal element.
|
// Deal with the diagonal element.
|
||||||
let p = self.work_c[k];
|
let p = *self.work_c.vget_unchecked(k);
|
||||||
self.work_c[k] += 1;
|
*self.work_c.vget_unchecked_mut(k) += 1;
|
||||||
self.l.data.i[p] = k;
|
*self.l.data.i.get_unchecked_mut(p) = k;
|
||||||
self.l.data.vals[p] = diag.sqrt();
|
*self.l.data.vals.get_unchecked_mut(p) = diag.sqrt();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
self.ok = true;
|
self.ok = true;
|
||||||
|
|
Loading…
Reference in New Issue