Fix nalgebra-sparse.
This commit is contained in:
parent
148b164aaa
commit
0b9a1acea5
|
@ -3,7 +3,7 @@ use crate::ops::serial::spsolve_csc_lower_triangular;
|
||||||
use crate::ops::Op;
|
use crate::ops::Op;
|
||||||
use crate::pattern::SparsityPattern;
|
use crate::pattern::SparsityPattern;
|
||||||
use core::{iter, mem};
|
use core::{iter, mem};
|
||||||
use nalgebra::{DMatrix, DMatrixSlice, DMatrixSliceMut, RealField, Scalar};
|
use nalgebra::{DMatrix, DMatrixSlice, DMatrixSliceMut, RealField};
|
||||||
use std::fmt::{Display, Formatter};
|
use std::fmt::{Display, Formatter};
|
||||||
|
|
||||||
/// A symbolic sparse Cholesky factorization of a CSC matrix.
|
/// A symbolic sparse Cholesky factorization of a CSC matrix.
|
||||||
|
@ -209,15 +209,16 @@ impl<T: RealField> CscCholesky<T> {
|
||||||
let irow = *self.m_pattern.minor_indices().get_unchecked(p);
|
let irow = *self.m_pattern.minor_indices().get_unchecked(p);
|
||||||
|
|
||||||
if irow >= k {
|
if irow >= k {
|
||||||
*self.work_x.get_unchecked_mut(irow) = *values.get_unchecked(p);
|
*self.work_x.get_unchecked_mut(irow) = values.get_unchecked(p).clone();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for &j in self.u_pattern.lane(k) {
|
for &j in self.u_pattern.lane(k) {
|
||||||
let factor = -*self
|
let factor = -self
|
||||||
.l_factor
|
.l_factor
|
||||||
.values()
|
.values()
|
||||||
.get_unchecked(*self.work_c.get_unchecked(j));
|
.get_unchecked(*self.work_c.get_unchecked(j))
|
||||||
|
.clone();
|
||||||
*self.work_c.get_unchecked_mut(j) += 1;
|
*self.work_c.get_unchecked_mut(j) += 1;
|
||||||
|
|
||||||
if j < k {
|
if j < k {
|
||||||
|
@ -225,27 +226,27 @@ impl<T: RealField> CscCholesky<T> {
|
||||||
let col_j_entries = col_j.row_indices().iter().zip(col_j.values());
|
let col_j_entries = col_j.row_indices().iter().zip(col_j.values());
|
||||||
for (&z, val) in col_j_entries {
|
for (&z, val) in col_j_entries {
|
||||||
if z >= k {
|
if z >= k {
|
||||||
*self.work_x.get_unchecked_mut(z) += val.clone() * factor;
|
*self.work_x.get_unchecked_mut(z) += val.clone() * factor.clone();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let diag = *self.work_x.get_unchecked(k);
|
let diag = self.work_x.get_unchecked(k).clone();
|
||||||
|
|
||||||
if diag > T::zero() {
|
if diag > T::zero() {
|
||||||
let denom = diag.sqrt();
|
let denom = diag.sqrt();
|
||||||
|
|
||||||
{
|
{
|
||||||
let (offsets, _, values) = self.l_factor.csc_data_mut();
|
let (offsets, _, values) = self.l_factor.csc_data_mut();
|
||||||
*values.get_unchecked_mut(*offsets.get_unchecked(k)) = denom;
|
*values.get_unchecked_mut(*offsets.get_unchecked(k)) = denom.clone();
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut col_k = self.l_factor.col_mut(k);
|
let mut col_k = self.l_factor.col_mut(k);
|
||||||
let (col_k_rows, col_k_values) = col_k.rows_and_values_mut();
|
let (col_k_rows, col_k_values) = col_k.rows_and_values_mut();
|
||||||
let col_k_entries = col_k_rows.iter().zip(col_k_values);
|
let col_k_entries = col_k_rows.iter().zip(col_k_values);
|
||||||
for (&p, val) in col_k_entries {
|
for (&p, val) in col_k_entries {
|
||||||
*val = *self.work_x.get_unchecked(p) / denom;
|
*val = self.work_x.get_unchecked(p).clone() / denom.clone();
|
||||||
*self.work_x.get_unchecked_mut(p) = T::zero();
|
*self.work_x.get_unchecked_mut(p) = T::zero();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -165,13 +165,13 @@ fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
|
||||||
// a severe penalty)
|
// a severe penalty)
|
||||||
let diag_csc_index = l_col_k.row_indices().iter().position(|&i| i == k);
|
let diag_csc_index = l_col_k.row_indices().iter().position(|&i| i == k);
|
||||||
if let Some(diag_csc_index) = diag_csc_index {
|
if let Some(diag_csc_index) = diag_csc_index {
|
||||||
let l_kk = l_col_k.values()[diag_csc_index];
|
let l_kk = l_col_k.values()[diag_csc_index].clone();
|
||||||
|
|
||||||
if l_kk != T::zero() {
|
if l_kk != T::zero() {
|
||||||
// Update entry associated with diagonal
|
// Update entry associated with diagonal
|
||||||
x_col_j[k] /= l_kk;
|
x_col_j[k] /= l_kk;
|
||||||
// Copy value after updating (so we don't run into the borrow checker)
|
// Copy value after updating (so we don't run into the borrow checker)
|
||||||
let x_kj = x_col_j[k];
|
let x_kj = x_col_j[k].clone();
|
||||||
|
|
||||||
let row_indices = &l_col_k.row_indices()[(diag_csc_index + 1)..];
|
let row_indices = &l_col_k.row_indices()[(diag_csc_index + 1)..];
|
||||||
let l_values = &l_col_k.values()[(diag_csc_index + 1)..];
|
let l_values = &l_col_k.values()[(diag_csc_index + 1)..];
|
||||||
|
@ -179,7 +179,7 @@ fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
|
||||||
// Note: The remaining entries are below the diagonal
|
// Note: The remaining entries are below the diagonal
|
||||||
for (&i, l_ik) in row_indices.iter().zip(l_values) {
|
for (&i, l_ik) in row_indices.iter().zip(l_values) {
|
||||||
let x_ij = &mut x_col_j[i];
|
let x_ij = &mut x_col_j[i];
|
||||||
*x_ij -= l_ik.clone() * x_kj;
|
*x_ij -= l_ik.clone() * x_kj.clone();
|
||||||
}
|
}
|
||||||
|
|
||||||
x_col_j[k] = x_kj;
|
x_col_j[k] = x_kj;
|
||||||
|
@ -223,22 +223,22 @@ fn spsolve_csc_lower_triangular_transpose<T: RealField>(
|
||||||
// TODO: Can use exponential search here to quickly skip entries
|
// TODO: Can use exponential search here to quickly skip entries
|
||||||
let diag_csc_index = l_col_i.row_indices().iter().position(|&k| i == k);
|
let diag_csc_index = l_col_i.row_indices().iter().position(|&k| i == k);
|
||||||
if let Some(diag_csc_index) = diag_csc_index {
|
if let Some(diag_csc_index) = diag_csc_index {
|
||||||
let l_ii = l_col_i.values()[diag_csc_index];
|
let l_ii = l_col_i.values()[diag_csc_index].clone();
|
||||||
|
|
||||||
if l_ii != T::zero() {
|
if l_ii != T::zero() {
|
||||||
// // Update entry associated with diagonal
|
// // Update entry associated with diagonal
|
||||||
// x_col_j[k] /= a_kk;
|
// x_col_j[k] /= a_kk;
|
||||||
|
|
||||||
// Copy value after updating (so we don't run into the borrow checker)
|
// Copy value after updating (so we don't run into the borrow checker)
|
||||||
let mut x_ii = x_col_j[i];
|
let mut x_ii = x_col_j[i].clone();
|
||||||
|
|
||||||
let row_indices = &l_col_i.row_indices()[(diag_csc_index + 1)..];
|
let row_indices = &l_col_i.row_indices()[(diag_csc_index + 1)..];
|
||||||
let a_values = &l_col_i.values()[(diag_csc_index + 1)..];
|
let a_values = &l_col_i.values()[(diag_csc_index + 1)..];
|
||||||
|
|
||||||
// Note: The remaining entries are below the diagonal
|
// Note: The remaining entries are below the diagonal
|
||||||
for (&k, &l_ki) in row_indices.iter().zip(a_values) {
|
for (k, l_ki) in row_indices.iter().zip(a_values) {
|
||||||
let x_kj = x_col_j[k];
|
let x_kj = x_col_j[*k].clone();
|
||||||
x_ii -= l_ki * x_kj;
|
x_ii -= l_ki.clone() * x_kj;
|
||||||
}
|
}
|
||||||
|
|
||||||
x_col_j[i] = x_ii / l_ii;
|
x_col_j[i] = x_ii / l_ii;
|
||||||
|
|
Loading…
Reference in New Issue