Replace Arc<SparsityPattern> with SparsityPattern

After much deliberation, I have come to the conclusion that the
benefits do not really outweigh the added complexity. Even though
the added complexity is relatively minor, it makes it somewhat
more complicated to inter-op with other sparse linear algebra
libraries in the future.
This commit is contained in:
Andreas Longva 2021-01-19 16:53:39 +01:00
parent 9b46a43c7f
commit e655fed4fa
8 changed files with 86 additions and 136 deletions

View File

@ -1,6 +1,5 @@
use std::mem::replace; use std::mem::replace;
use std::ops::Range; use std::ops::Range;
use std::sync::Arc;
use num_traits::One; use num_traits::One;
@ -18,7 +17,7 @@ use crate::pattern::SparsityPattern;
/// is obtained by associating columns with the major dimension. /// is obtained by associating columns with the major dimension.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsMatrix<T> { pub struct CsMatrix<T> {
sparsity_pattern: Arc<SparsityPattern>, sparsity_pattern: SparsityPattern,
values: Vec<T> values: Vec<T>
} }
@ -27,13 +26,13 @@ impl<T> CsMatrix<T> {
#[inline] #[inline]
pub fn new(major_dim: usize, minor_dim: usize) -> Self { pub fn new(major_dim: usize, minor_dim: usize) -> Self {
Self { Self {
sparsity_pattern: Arc::new(SparsityPattern::new(major_dim, minor_dim)), sparsity_pattern: SparsityPattern::new(major_dim, minor_dim),
values: vec![], values: vec![],
} }
} }
#[inline] #[inline]
pub fn pattern(&self) -> &Arc<SparsityPattern> { pub fn pattern(&self) -> &SparsityPattern {
&self.sparsity_pattern &self.sparsity_pattern
} }
@ -50,24 +49,24 @@ impl<T> CsMatrix<T> {
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`. /// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
#[inline] #[inline]
pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) { pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) {
let pattern = self.pattern().as_ref(); let pattern = self.pattern();
(pattern.major_offsets(), pattern.minor_indices(), &self.values) (pattern.major_offsets(), pattern.minor_indices(), &self.values)
} }
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`. /// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
#[inline] #[inline]
pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) { pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
let pattern = self.sparsity_pattern.as_ref(); let pattern = &mut self.sparsity_pattern;
(pattern.major_offsets(), pattern.minor_indices(), &mut self.values) (pattern.major_offsets(), pattern.minor_indices(), &mut self.values)
} }
#[inline] #[inline]
pub fn pattern_and_values_mut(&mut self) -> (&Arc<SparsityPattern>, &mut [T]) { pub fn pattern_and_values_mut(&mut self) -> (&SparsityPattern, &mut [T]) {
(&self.sparsity_pattern, &mut self.values) (&self.sparsity_pattern, &mut self.values)
} }
#[inline] #[inline]
pub fn from_pattern_and_values(pattern: Arc<SparsityPattern>, values: Vec<T>) pub fn from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
-> Self { -> Self {
assert_eq!(pattern.nnz(), values.len(), "Internal error: consumers should verify shape compatibility."); assert_eq!(pattern.nnz(), values.len(), "Internal error: consumers should verify shape compatibility.");
Self { Self {
@ -84,25 +83,14 @@ impl<T> CsMatrix<T> {
Some(row_begin .. row_end) Some(row_begin .. row_end)
} }
pub fn take_pattern_and_values(self) -> (Arc<SparsityPattern>, Vec<T>) { pub fn take_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
(self.sparsity_pattern, self.values) (self.sparsity_pattern, self.values)
} }
#[inline] #[inline]
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) { pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
// Take an Arc to the pattern, which might be the sole reference to the data after let (offsets, indices) = self.sparsity_pattern.disassemble();
// taking the values. This is important, because it might let us avoid cloning the data (offsets, indices, self.values)
// further below.
let pattern = self.sparsity_pattern;
let values = self.values;
// Try to take the pattern out of the `Arc` if possible,
// otherwise clone the pattern.
let owned_pattern = Arc::try_unwrap(pattern)
.unwrap_or_else(|arc| SparsityPattern::clone(&*arc));
let (offsets, indices) = owned_pattern.disassemble();
(offsets, indices, values)
} }
/// Returns an entry for the given major/minor indices, or `None` if the indices are out /// Returns an entry for the given major/minor indices, or `None` if the indices are out
@ -151,12 +139,12 @@ impl<T> CsMatrix<T> {
#[inline] #[inline]
pub fn lane_iter(&self) -> CsLaneIter<T> { pub fn lane_iter(&self) -> CsLaneIter<T> {
CsLaneIter::new(self.pattern().as_ref(), self.values()) CsLaneIter::new(self.pattern(), self.values())
} }
#[inline] #[inline]
pub fn lane_iter_mut(&mut self) -> CsLaneIterMut<T> { pub fn lane_iter_mut(&mut self) -> CsLaneIterMut<T> {
CsLaneIterMut::new(self.sparsity_pattern.as_ref(), &mut self.values) CsLaneIterMut::new(&self.sparsity_pattern, &mut self.values)
} }
#[inline] #[inline]
@ -190,7 +178,7 @@ impl<T> CsMatrix<T> {
new_indices) new_indices)
.expect("Internal error: Sparsity pattern must always be valid."); .expect("Internal error: Sparsity pattern must always be valid.");
Self::from_pattern_and_values(Arc::new(new_pattern), new_values) Self::from_pattern_and_values(new_pattern, new_values)
} }
} }
@ -205,7 +193,7 @@ impl<T: Scalar + One> CsMatrix<T> {
// TODO: We should skip checks here // TODO: We should skip checks here
let pattern = SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices) let pattern = SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices)
.unwrap(); .unwrap();
Self::from_pattern_and_values(Arc::new(pattern), values) Self::from_pattern_and_values(pattern, values)
} }
} }

View File

@ -5,7 +5,6 @@ use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatter
use crate::csr::CsrMatrix; use crate::csr::CsrMatrix;
use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut}; use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut};
use std::sync::Arc;
use std::slice::{IterMut, Iter}; use std::slice::{IterMut, Iter};
use num_traits::{One}; use num_traits::{One};
use nalgebra::Scalar; use nalgebra::Scalar;
@ -95,14 +94,14 @@ impl<T> CscMatrix<T> {
let pattern = SparsityPattern::try_from_offsets_and_indices( let pattern = SparsityPattern::try_from_offsets_and_indices(
num_cols, num_rows, col_offsets, row_indices) num_cols, num_rows, col_offsets, row_indices)
.map_err(pattern_format_error_to_csc_error)?; .map_err(pattern_format_error_to_csc_error)?;
Self::try_from_pattern_and_values(Arc::new(pattern), values) Self::try_from_pattern_and_values(pattern, values)
} }
/// Try to construct a CSC matrix from a sparsity pattern and associated non-zero values. /// Try to construct a CSC matrix from a sparsity pattern and associated non-zero values.
/// ///
/// Returns an error if the number of values does not match the number of minor indices /// Returns an error if the number of values does not match the number of minor indices
/// in the pattern. /// in the pattern.
pub fn try_from_pattern_and_values(pattern: Arc<SparsityPattern>, values: Vec<T>) pub fn try_from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
-> Result<Self, SparseFormatError> { -> Result<Self, SparseFormatError> {
if pattern.nnz() == values.len() { if pattern.nnz() == values.len() {
Ok(Self { Ok(Self {
@ -212,7 +211,7 @@ impl<T> CscMatrix<T> {
/// An iterator over columns in the matrix. /// An iterator over columns in the matrix.
pub fn col_iter(&self) -> CscColIter<T> { pub fn col_iter(&self) -> CscColIter<T> {
CscColIter { CscColIter {
lane_iter: CsLaneIter::new(self.pattern().as_ref(), self.values()) lane_iter: CsLaneIter::new(self.pattern(), self.values())
} }
} }
@ -258,7 +257,7 @@ impl<T> CscMatrix<T> {
/// The sparsity pattern is stored internally inside an `Arc`. This allows users to re-use /// The sparsity pattern is stored internally inside an `Arc`. This allows users to re-use
/// the same sparsity pattern for multiple matrices without storing the same pattern multiple /// the same sparsity pattern for multiple matrices without storing the same pattern multiple
/// times in memory. /// times in memory.
pub fn pattern(&self) -> &Arc<SparsityPattern> { pub fn pattern(&self) -> &SparsityPattern {
self.cs.pattern() self.cs.pattern()
} }

View File

@ -7,7 +7,6 @@ use crate::cs::{CsMatrix, CsLaneIterMut, CsLaneIter, CsLane, CsLaneMut};
use nalgebra::Scalar; use nalgebra::Scalar;
use num_traits::{One}; use num_traits::{One};
use std::sync::Arc;
use std::slice::{IterMut, Iter}; use std::slice::{IterMut, Iter};
/// A CSR representation of a sparse matrix. /// A CSR representation of a sparse matrix.
@ -97,14 +96,14 @@ impl<T> CsrMatrix<T> {
let pattern = SparsityPattern::try_from_offsets_and_indices( let pattern = SparsityPattern::try_from_offsets_and_indices(
num_rows, num_cols, row_offsets, col_indices) num_rows, num_cols, row_offsets, col_indices)
.map_err(pattern_format_error_to_csr_error)?; .map_err(pattern_format_error_to_csr_error)?;
Self::try_from_pattern_and_values(Arc::new(pattern), values) Self::try_from_pattern_and_values(pattern, values)
} }
/// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values. /// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
/// ///
/// Returns an error if the number of values does not match the number of minor indices /// Returns an error if the number of values does not match the number of minor indices
/// in the pattern. /// in the pattern.
pub fn try_from_pattern_and_values(pattern: Arc<SparsityPattern>, values: Vec<T>) pub fn try_from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
-> Result<Self, SparseFormatError> { -> Result<Self, SparseFormatError> {
if pattern.nnz() == values.len() { if pattern.nnz() == values.len() {
Ok(Self { Ok(Self {
@ -214,7 +213,7 @@ impl<T> CsrMatrix<T> {
/// An iterator over rows in the matrix. /// An iterator over rows in the matrix.
pub fn row_iter(&self) -> CsrRowIter<T> { pub fn row_iter(&self) -> CsrRowIter<T> {
CsrRowIter { CsrRowIter {
lane_iter: CsLaneIter::new(self.pattern().as_ref(), self.values()) lane_iter: CsLaneIter::new(self.pattern(), self.values())
} }
} }
@ -260,7 +259,7 @@ impl<T> CsrMatrix<T> {
/// The sparsity pattern is stored internally inside an `Arc`. This allows users to re-use /// The sparsity pattern is stored internally inside an `Arc`. This allows users to re-use
/// the same sparsity pattern for multiple matrices without storing the same pattern multiple /// the same sparsity pattern for multiple matrices without storing the same pattern multiple
/// times in memory. /// times in memory.
pub fn pattern(&self) -> &Arc<SparsityPattern> { pub fn pattern(&self) -> &SparsityPattern {
self.cs.pattern() self.cs.pattern()
} }

View File

@ -5,26 +5,25 @@ use crate::pattern::SparsityPattern;
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use core::{mem, iter}; use core::{mem, iter};
use nalgebra::{Scalar, RealField, DMatrixSlice, DMatrixSliceMut, DMatrix}; use nalgebra::{Scalar, RealField, DMatrixSlice, DMatrixSliceMut, DMatrix};
use std::sync::Arc;
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use crate::ops::serial::spsolve_csc_lower_triangular; use crate::ops::serial::spsolve_csc_lower_triangular;
use crate::ops::Op; use crate::ops::Op;
pub struct CscSymbolicCholesky { pub struct CscSymbolicCholesky {
// Pattern of the original matrix that was decomposed // Pattern of the original matrix that was decomposed
m_pattern: Arc<SparsityPattern>, m_pattern: SparsityPattern,
l_pattern: SparsityPattern, l_pattern: SparsityPattern,
// u in this context is L^T, so that M = L L^T // u in this context is L^T, so that M = L L^T
u_pattern: SparsityPattern u_pattern: SparsityPattern
} }
impl CscSymbolicCholesky { impl CscSymbolicCholesky {
pub fn factor(pattern: &Arc<SparsityPattern>) -> Self { pub fn factor(pattern: SparsityPattern) -> Self {
assert_eq!(pattern.major_dim(), pattern.minor_dim(), assert_eq!(pattern.major_dim(), pattern.minor_dim(),
"Major and minor dimensions must be the same (square matrix)."); "Major and minor dimensions must be the same (square matrix).");
let (l_pattern, u_pattern) = nonzero_pattern(&*pattern); let (l_pattern, u_pattern) = nonzero_pattern(&pattern);
Self { Self {
m_pattern: Arc::clone(pattern), m_pattern: pattern,
l_pattern, l_pattern,
u_pattern, u_pattern,
} }
@ -37,7 +36,7 @@ impl CscSymbolicCholesky {
pub struct CscCholesky<T> { pub struct CscCholesky<T> {
// Pattern of the original matrix // Pattern of the original matrix
m_pattern: Arc<SparsityPattern>, m_pattern: SparsityPattern,
l_factor: CscMatrix<T>, l_factor: CscMatrix<T>,
u_pattern: SparsityPattern, u_pattern: SparsityPattern,
work_x: Vec<T>, work_x: Vec<T>,
@ -66,7 +65,7 @@ impl<T: RealField> CscCholesky<T> {
let l_nnz = symbolic.l_pattern.nnz(); let l_nnz = symbolic.l_pattern.nnz();
let l_values = vec![T::zero(); l_nnz]; let l_values = vec![T::zero(); l_nnz];
let l_factor = CscMatrix::try_from_pattern_and_values(Arc::new(symbolic.l_pattern), l_values) let l_factor = CscMatrix::try_from_pattern_and_values(symbolic.l_pattern, l_values)
.unwrap(); .unwrap();
let (nrows, ncols) = (l_factor.nrows(), l_factor.ncols()); let (nrows, ncols) = (l_factor.nrows(), l_factor.ncols());
@ -86,7 +85,7 @@ impl<T: RealField> CscCholesky<T> {
} }
pub fn factor(matrix: &CscMatrix<T>) -> Result<Self, CholeskyError> { pub fn factor(matrix: &CscMatrix<T>) -> Result<Self, CholeskyError> {
let symbolic = CscSymbolicCholesky::factor(&*matrix.pattern()); let symbolic = CscSymbolicCholesky::factor(matrix.pattern().clone());
Self::factor_numerical(symbolic, matrix.values()) Self::factor_numerical(symbolic, matrix.values())
} }

View File

@ -7,7 +7,6 @@ use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern,
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, Dim, use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, Dim,
DMatrixSlice, DMatrix, Dynamic}; DMatrixSlice, DMatrix, Dynamic};
use num_traits::{Zero, One}; use num_traits::{Zero, One};
use std::sync::Arc;
use crate::ops::{Op}; use crate::ops::{Op};
use nalgebra::base::storage::Storage; use nalgebra::base::storage::Storage;
@ -48,11 +47,7 @@ macro_rules! impl_sp_plus_minus {
impl_bin_op!($trait, $method, impl_bin_op!($trait, $method,
<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> { <'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
// If both matrices have the same pattern, then we can immediately re-use it // If both matrices have the same pattern, then we can immediately re-use it
let pattern = if Arc::ptr_eq(a.pattern(), b.pattern()) { let pattern = spadd_pattern(a.pattern(), b.pattern());
Arc::clone(a.pattern())
} else {
Arc::new(spadd_pattern(a.pattern(), b.pattern()))
};
let values = vec![T::zero(); pattern.nnz()]; let values = vec![T::zero(); pattern.nnz()];
// We are giving data that is valid by definition, so it is safe to unwrap below // We are giving data that is valid by definition, so it is safe to unwrap below
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values) let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
@ -64,24 +59,12 @@ macro_rules! impl_sp_plus_minus {
impl_bin_op!($trait, $method, impl_bin_op!($trait, $method,
<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> { <'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
let mut a = a; &a $sign b
if Arc::ptr_eq(a.pattern(), b.pattern()) {
$spadd_fn(T::one(), &mut a, $factor * T::one(), Op::NoOp(b)).unwrap();
a
} else {
&a $sign b
}
}); });
impl_bin_op!($trait, $method, impl_bin_op!($trait, $method,
<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { <'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
let mut b = b; a $sign &b
if Arc::ptr_eq(a.pattern(), b.pattern()) {
$spadd_fn($factor * T::one(), &mut b, T::one(), Op::NoOp(a)).unwrap();
b
} else {
a $sign &b
}
}); });
impl_bin_op!($trait, $method, <T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { impl_bin_op!($trait, $method, <T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
a $sign &b a $sign &b
@ -107,7 +90,7 @@ macro_rules! impl_spmm {
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> { impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
let pattern = $pattern_fn(a.pattern(), b.pattern()); let pattern = $pattern_fn(a.pattern(), b.pattern());
let values = vec![T::zero(); pattern.nnz()]; let values = vec![T::zero(); pattern.nnz()];
let mut result = $matrix_type::try_from_pattern_and_values(Arc::new(pattern), values) let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
.unwrap(); .unwrap();
$spmm_fn(T::zero(), $spmm_fn(T::zero(),
&mut result, &mut result,
@ -154,7 +137,7 @@ macro_rules! impl_scalar_mul {
.iter() .iter()
.map(|v_i| v_i.inlined_clone() * b.inlined_clone()) .map(|v_i| v_i.inlined_clone() * b.inlined_clone())
.collect(); .collect();
$matrix_type::try_from_pattern_and_values(Arc::clone(a.pattern()), values).unwrap() $matrix_type::try_from_pattern_and_values(a.pattern().clone(), values).unwrap()
}); });
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: T) -> $matrix_type<T> { impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: T) -> $matrix_type<T> {
a * &b a * &b
@ -251,7 +234,7 @@ macro_rules! impl_div {
.iter() .iter()
.map(|v_i| v_i.inlined_clone() / scalar.inlined_clone()) .map(|v_i| v_i.inlined_clone() / scalar.inlined_clone())
.collect(); .collect();
$matrix_type::try_from_pattern_and_values(Arc::clone(matrix.pattern()), new_values) $matrix_type::try_from_pattern_and_values(matrix.pattern().clone(), new_values)
.unwrap() .unwrap()
}); });
impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: &'a $matrix_type<T>, scalar: &'a T) -> $matrix_type<T> { impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: &'a $matrix_type<T>, scalar: &'a T) -> $matrix_type<T> {

View File

@ -4,7 +4,6 @@ use crate::ops::serial::{OperationErrorType, OperationError};
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice}; use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice};
use num_traits::{Zero, One}; use num_traits::{Zero, One};
use crate::SparseEntryMut; use crate::SparseEntryMut;
use std::sync::Arc;
fn spmm_cs_unexpected_entry() -> OperationError { fn spmm_cs_unexpected_entry() -> OperationError {
OperationError::from_type_and_message( OperationError::from_type_and_message(
@ -73,64 +72,52 @@ pub fn spadd_cs_prealloc<T>(beta: T,
where where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One
{ {
match a {
if Arc::ptr_eq(&c.pattern(), &a.inner_ref().pattern()) { Op::NoOp(a) => {
// Special fast path: The two matrices have *exactly* the same sparsity pattern, for (mut c_lane_i, a_lane_i) in c.lane_iter_mut().zip(a.lane_iter()) {
// so we only need to sum the value arrays
// TODO: Test this fast path
for (c_ij, a_ij) in c.values_mut().iter_mut().zip(a.inner_ref().values()) {
let (alpha, beta) = (alpha.inlined_clone(), beta.inlined_clone());
*c_ij = beta * c_ij.inlined_clone() + alpha * a_ij.inlined_clone();
}
Ok(())
} else {
match a {
Op::NoOp(a) => {
for (mut c_lane_i, a_lane_i) in c.lane_iter_mut().zip(a.lane_iter()) {
if beta != T::one() {
for c_ij in c_lane_i.values_mut() {
*c_ij *= beta.inlined_clone();
}
}
let (mut c_minors, mut c_vals) = c_lane_i.indices_and_values_mut();
let (a_minors, a_vals) = (a_lane_i.minor_indices(), a_lane_i.values());
for (a_col, a_val) in a_minors.iter().zip(a_vals) {
// TODO: Use exponential search instead of linear search.
// If C has substantially more entries in the row than A, then a line search
// will needlessly visit many entries in C.
let (c_idx, _) = c_minors.iter()
.enumerate()
.find(|(_, c_col)| *c_col == a_col)
.ok_or_else(spadd_cs_unexpected_entry)?;
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
c_minors = &c_minors[c_idx ..];
c_vals = &mut c_vals[c_idx ..];
}
}
}
Op::Transpose(a) => {
if beta != T::one() { if beta != T::one() {
for c_ij in c.values_mut() { for c_ij in c_lane_i.values_mut() {
*c_ij *= beta.inlined_clone(); *c_ij *= beta.inlined_clone();
} }
} }
for (i, a_lane_i) in a.lane_iter().enumerate() { let (mut c_minors, mut c_vals) = c_lane_i.indices_and_values_mut();
for (&j, a_val) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) { let (a_minors, a_vals) = (a_lane_i.minor_indices(), a_lane_i.values());
let a_val = a_val.inlined_clone();
let alpha = alpha.inlined_clone(); for (a_col, a_val) in a_minors.iter().zip(a_vals) {
match c.get_entry_mut(j, i).unwrap() { // TODO: Use exponential search instead of linear search.
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val } // If C has substantially more entries in the row than A, then a line search
SparseEntryMut::Zero => return Err(spadd_cs_unexpected_entry()), // will needlessly visit many entries in C.
} let (c_idx, _) = c_minors.iter()
.enumerate()
.find(|(_, c_col)| *c_col == a_col)
.ok_or_else(spadd_cs_unexpected_entry)?;
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
c_minors = &c_minors[c_idx ..];
c_vals = &mut c_vals[c_idx ..];
}
}
}
Op::Transpose(a) => {
if beta != T::one() {
for c_ij in c.values_mut() {
*c_ij *= beta.inlined_clone();
}
}
for (i, a_lane_i) in a.lane_iter().enumerate() {
for (&j, a_val) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) {
let a_val = a_val.inlined_clone();
let alpha = alpha.inlined_clone();
match c.get_entry_mut(j, i).unwrap() {
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val }
SparseEntryMut::Zero => return Err(spadd_cs_unexpected_entry()),
} }
} }
} }
} }
Ok(())
} }
Ok(())
} }
/// Helper functionality for implementing CSR/CSC SPMM. /// Helper functionality for implementing CSR/CSC SPMM.

View File

@ -11,7 +11,6 @@ use std::iter::{repeat};
use proptest::sample::{Index}; use proptest::sample::{Index};
use crate::csr::CsrMatrix; use crate::csr::CsrMatrix;
use crate::pattern::SparsityPattern; use crate::pattern::SparsityPattern;
use std::sync::Arc;
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize) fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
@ -291,7 +290,7 @@ where
(Just(pattern), values) (Just(pattern), values)
}) })
.prop_map(|(pattern, values)| { .prop_map(|(pattern, values)| {
CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values) CsrMatrix::try_from_pattern_and_values(pattern, values)
.expect("Internal error: Generated CsrMatrix is invalid") .expect("Internal error: Generated CsrMatrix is invalid")
}) })
} }
@ -313,7 +312,7 @@ pub fn csc<T>(value_strategy: T,
(Just(pattern), values) (Just(pattern), values)
}) })
.prop_map(|(pattern, values)| { .prop_map(|(pattern, values)| {
CscMatrix::try_from_pattern_and_values(Arc::new(pattern), values) CscMatrix::try_from_pattern_and_values(pattern, values)
.expect("Internal error: Generated CscMatrix is invalid") .expect("Internal error: Generated CscMatrix is invalid")
}) })
} }

View File

@ -17,12 +17,11 @@ use proptest::prelude::*;
use matrixcompare::prop_assert_matrix_eq; use matrixcompare::prop_assert_matrix_eq;
use std::panic::catch_unwind; use std::panic::catch_unwind;
use std::sync::Arc;
/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1 /// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1
fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> { fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
let boolean_csr = CsrMatrix::try_from_pattern_and_values( let boolean_csr = CsrMatrix::try_from_pattern_and_values(
Arc::new(pattern.clone()), pattern.clone(),
vec![1; pattern.nnz()]) vec![1; pattern.nnz()])
.unwrap(); .unwrap();
DMatrix::from(&boolean_csr) DMatrix::from(&boolean_csr)
@ -31,7 +30,7 @@ fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
/// Represents the sparsity pattern of a CSC matrix as a dense matrix with 0/1 /// Represents the sparsity pattern of a CSC matrix as a dense matrix with 0/1
fn dense_csc_pattern(pattern: &SparsityPattern) -> DMatrix<i32> { fn dense_csc_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
let boolean_csc = CscMatrix::try_from_pattern_and_values( let boolean_csc = CscMatrix::try_from_pattern_and_values(
Arc::new(pattern.clone()), pattern.clone(),
vec![1; pattern.nnz()]) vec![1; pattern.nnz()])
.unwrap(); .unwrap();
DMatrix::from(&boolean_csc) DMatrix::from(&boolean_csc)
@ -137,8 +136,8 @@ fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>>
let beta = value_strategy.clone(); let beta = value_strategy.clone();
(Just(c_pattern), Just(a_pattern), c_values, a_values, alpha, beta, trans_strategy()) (Just(c_pattern), Just(a_pattern), c_values, a_values, alpha, beta, trans_strategy())
}).prop_map(|(c_pattern, a_pattern, c_values, a_values, alpha, beta, trans_a)| { }).prop_map(|(c_pattern, a_pattern, c_values, a_values, alpha, beta, trans_a)| {
let c = CsrMatrix::try_from_pattern_and_values(Arc::new(c_pattern), c_values).unwrap(); let c = CsrMatrix::try_from_pattern_and_values(c_pattern, c_values).unwrap();
let a = CsrMatrix::try_from_pattern_and_values(Arc::new(a_pattern), a_values).unwrap(); let a = CsrMatrix::try_from_pattern_and_values(a_pattern, a_values).unwrap();
let a = if trans_a { Op::Transpose(a.transpose()) } else { Op::NoOp(a) }; let a = if trans_a { Op::Transpose(a.transpose()) } else { Op::NoOp(a) };
SpaddCsrArgs { c, beta, alpha, a } SpaddCsrArgs { c, beta, alpha, a }
@ -222,15 +221,12 @@ fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()]; let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()];
let c_pattern = spmm_pattern(&a_pattern, &b_pattern); let c_pattern = spmm_pattern(&a_pattern, &b_pattern);
let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()]; let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()];
let a_pattern = Arc::new(a_pattern);
let b_pattern = Arc::new(b_pattern);
let c_pattern = Arc::new(c_pattern);
let a = a_values.prop_map(move |values| let a = a_values.prop_map(move |values|
CsrMatrix::try_from_pattern_and_values(Arc::clone(&a_pattern), values).unwrap()); CsrMatrix::try_from_pattern_and_values(a_pattern.clone(), values).unwrap());
let b = b_values.prop_map(move |values| let b = b_values.prop_map(move |values|
CsrMatrix::try_from_pattern_and_values(Arc::clone(&b_pattern), values).unwrap()); CsrMatrix::try_from_pattern_and_values(b_pattern.clone(), values).unwrap());
let c = c_values.prop_map(move |values| let c = c_values.prop_map(move |values|
CsrMatrix::try_from_pattern_and_values(Arc::clone(&c_pattern), values).unwrap()); CsrMatrix::try_from_pattern_and_values(c_pattern.clone(), values).unwrap());
let alpha = PROPTEST_I32_VALUE_STRATEGY; let alpha = PROPTEST_I32_VALUE_STRATEGY;
let beta = PROPTEST_I32_VALUE_STRATEGY; let beta = PROPTEST_I32_VALUE_STRATEGY;
(c, beta, alpha, trans_strategy(), a, trans_strategy(), b) (c, beta, alpha, trans_strategy(), a, trans_strategy(), b)
@ -383,16 +379,16 @@ proptest! {
// corresponding to a and b, and convert them to dense matrices. // corresponding to a and b, and convert them to dense matrices.
// The sum of these dense matrices will then have non-zeros in exactly the same locations // The sum of these dense matrices will then have non-zeros in exactly the same locations
// as the result of "adding" the sparsity patterns // as the result of "adding" the sparsity patterns
let a_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(a.clone()), vec![1; a.nnz()]) let a_csr = CsrMatrix::try_from_pattern_and_values(a.clone(), vec![1; a.nnz()])
.unwrap(); .unwrap();
let a_dense = DMatrix::from(&a_csr); let a_dense = DMatrix::from(&a_csr);
let b_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(b.clone()), vec![1; b.nnz()]) let b_csr = CsrMatrix::try_from_pattern_and_values(b.clone(), vec![1; b.nnz()])
.unwrap(); .unwrap();
let b_dense = DMatrix::from(&b_csr); let b_dense = DMatrix::from(&b_csr);
let c_dense = a_dense + b_dense; let c_dense = a_dense + b_dense;
let c_csr = CsrMatrix::from(&c_dense); let c_csr = CsrMatrix::from(&c_dense);
prop_assert_eq!(&pattern_result, c_csr.pattern().as_ref()); prop_assert_eq!(&pattern_result, c_csr.pattern());
} }
#[test] #[test]
@ -492,16 +488,16 @@ proptest! {
// corresponding to a and b, and convert them to dense matrices. // corresponding to a and b, and convert them to dense matrices.
// The product of these dense matrices will then have non-zeros in exactly the same locations // The product of these dense matrices will then have non-zeros in exactly the same locations
// as the result of "multiplying" the sparsity patterns // as the result of "multiplying" the sparsity patterns
let a_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(a.clone()), vec![1; a.nnz()]) let a_csr = CsrMatrix::try_from_pattern_and_values(a.clone(), vec![1; a.nnz()])
.unwrap(); .unwrap();
let a_dense = DMatrix::from(&a_csr); let a_dense = DMatrix::from(&a_csr);
let b_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(b.clone()), vec![1; b.nnz()]) let b_csr = CsrMatrix::try_from_pattern_and_values(b.clone(), vec![1; b.nnz()])
.unwrap(); .unwrap();
let b_dense = DMatrix::from(&b_csr); let b_dense = DMatrix::from(&b_csr);
let c_dense = a_dense * b_dense; let c_dense = a_dense * b_dense;
let c_csr = CsrMatrix::from(&c_dense); let c_csr = CsrMatrix::from(&c_dense);
prop_assert_eq!(&c_pattern, c_csr.pattern().as_ref()); prop_assert_eq!(&c_pattern, c_csr.pattern());
} }
#[test] #[test]