forked from M-Labs/nalgebra
Replace Arc<SparsityPattern> with SparsityPattern
After much deliberation, I have come to the conclusion that the benefits do not really outweigh the added complexity. Even though the added complexity is relatively minor, it makes it somewhat more complicated to inter-op with other sparse linear algebra libraries in the future.
This commit is contained in:
parent
9b46a43c7f
commit
e655fed4fa
@ -1,6 +1,5 @@
|
|||||||
use std::mem::replace;
|
use std::mem::replace;
|
||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use num_traits::One;
|
use num_traits::One;
|
||||||
|
|
||||||
@ -18,7 +17,7 @@ use crate::pattern::SparsityPattern;
|
|||||||
/// is obtained by associating columns with the major dimension.
|
/// is obtained by associating columns with the major dimension.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct CsMatrix<T> {
|
pub struct CsMatrix<T> {
|
||||||
sparsity_pattern: Arc<SparsityPattern>,
|
sparsity_pattern: SparsityPattern,
|
||||||
values: Vec<T>
|
values: Vec<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -27,13 +26,13 @@ impl<T> CsMatrix<T> {
|
|||||||
#[inline]
|
#[inline]
|
||||||
pub fn new(major_dim: usize, minor_dim: usize) -> Self {
|
pub fn new(major_dim: usize, minor_dim: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
sparsity_pattern: Arc::new(SparsityPattern::new(major_dim, minor_dim)),
|
sparsity_pattern: SparsityPattern::new(major_dim, minor_dim),
|
||||||
values: vec![],
|
values: vec![],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn pattern(&self) -> &Arc<SparsityPattern> {
|
pub fn pattern(&self) -> &SparsityPattern {
|
||||||
&self.sparsity_pattern
|
&self.sparsity_pattern
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,24 +49,24 @@ impl<T> CsMatrix<T> {
|
|||||||
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
|
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) {
|
pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) {
|
||||||
let pattern = self.pattern().as_ref();
|
let pattern = self.pattern();
|
||||||
(pattern.major_offsets(), pattern.minor_indices(), &self.values)
|
(pattern.major_offsets(), pattern.minor_indices(), &self.values)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
|
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
||||||
let pattern = self.sparsity_pattern.as_ref();
|
let pattern = &mut self.sparsity_pattern;
|
||||||
(pattern.major_offsets(), pattern.minor_indices(), &mut self.values)
|
(pattern.major_offsets(), pattern.minor_indices(), &mut self.values)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn pattern_and_values_mut(&mut self) -> (&Arc<SparsityPattern>, &mut [T]) {
|
pub fn pattern_and_values_mut(&mut self) -> (&SparsityPattern, &mut [T]) {
|
||||||
(&self.sparsity_pattern, &mut self.values)
|
(&self.sparsity_pattern, &mut self.values)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn from_pattern_and_values(pattern: Arc<SparsityPattern>, values: Vec<T>)
|
pub fn from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
|
||||||
-> Self {
|
-> Self {
|
||||||
assert_eq!(pattern.nnz(), values.len(), "Internal error: consumers should verify shape compatibility.");
|
assert_eq!(pattern.nnz(), values.len(), "Internal error: consumers should verify shape compatibility.");
|
||||||
Self {
|
Self {
|
||||||
@ -84,25 +83,14 @@ impl<T> CsMatrix<T> {
|
|||||||
Some(row_begin .. row_end)
|
Some(row_begin .. row_end)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn take_pattern_and_values(self) -> (Arc<SparsityPattern>, Vec<T>) {
|
pub fn take_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
|
||||||
(self.sparsity_pattern, self.values)
|
(self.sparsity_pattern, self.values)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
|
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
|
||||||
// Take an Arc to the pattern, which might be the sole reference to the data after
|
let (offsets, indices) = self.sparsity_pattern.disassemble();
|
||||||
// taking the values. This is important, because it might let us avoid cloning the data
|
(offsets, indices, self.values)
|
||||||
// further below.
|
|
||||||
let pattern = self.sparsity_pattern;
|
|
||||||
let values = self.values;
|
|
||||||
|
|
||||||
// Try to take the pattern out of the `Arc` if possible,
|
|
||||||
// otherwise clone the pattern.
|
|
||||||
let owned_pattern = Arc::try_unwrap(pattern)
|
|
||||||
.unwrap_or_else(|arc| SparsityPattern::clone(&*arc));
|
|
||||||
let (offsets, indices) = owned_pattern.disassemble();
|
|
||||||
|
|
||||||
(offsets, indices, values)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns an entry for the given major/minor indices, or `None` if the indices are out
|
/// Returns an entry for the given major/minor indices, or `None` if the indices are out
|
||||||
@ -151,12 +139,12 @@ impl<T> CsMatrix<T> {
|
|||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn lane_iter(&self) -> CsLaneIter<T> {
|
pub fn lane_iter(&self) -> CsLaneIter<T> {
|
||||||
CsLaneIter::new(self.pattern().as_ref(), self.values())
|
CsLaneIter::new(self.pattern(), self.values())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn lane_iter_mut(&mut self) -> CsLaneIterMut<T> {
|
pub fn lane_iter_mut(&mut self) -> CsLaneIterMut<T> {
|
||||||
CsLaneIterMut::new(self.sparsity_pattern.as_ref(), &mut self.values)
|
CsLaneIterMut::new(&self.sparsity_pattern, &mut self.values)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
@ -190,7 +178,7 @@ impl<T> CsMatrix<T> {
|
|||||||
new_indices)
|
new_indices)
|
||||||
.expect("Internal error: Sparsity pattern must always be valid.");
|
.expect("Internal error: Sparsity pattern must always be valid.");
|
||||||
|
|
||||||
Self::from_pattern_and_values(Arc::new(new_pattern), new_values)
|
Self::from_pattern_and_values(new_pattern, new_values)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -205,7 +193,7 @@ impl<T: Scalar + One> CsMatrix<T> {
|
|||||||
// TODO: We should skip checks here
|
// TODO: We should skip checks here
|
||||||
let pattern = SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices)
|
let pattern = SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
Self::from_pattern_and_values(Arc::new(pattern), values)
|
Self::from_pattern_and_values(pattern, values)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,7 +5,6 @@ use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatter
|
|||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut};
|
use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut};
|
||||||
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::slice::{IterMut, Iter};
|
use std::slice::{IterMut, Iter};
|
||||||
use num_traits::{One};
|
use num_traits::{One};
|
||||||
use nalgebra::Scalar;
|
use nalgebra::Scalar;
|
||||||
@ -95,14 +94,14 @@ impl<T> CscMatrix<T> {
|
|||||||
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||||
num_cols, num_rows, col_offsets, row_indices)
|
num_cols, num_rows, col_offsets, row_indices)
|
||||||
.map_err(pattern_format_error_to_csc_error)?;
|
.map_err(pattern_format_error_to_csc_error)?;
|
||||||
Self::try_from_pattern_and_values(Arc::new(pattern), values)
|
Self::try_from_pattern_and_values(pattern, values)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Try to construct a CSC matrix from a sparsity pattern and associated non-zero values.
|
/// Try to construct a CSC matrix from a sparsity pattern and associated non-zero values.
|
||||||
///
|
///
|
||||||
/// Returns an error if the number of values does not match the number of minor indices
|
/// Returns an error if the number of values does not match the number of minor indices
|
||||||
/// in the pattern.
|
/// in the pattern.
|
||||||
pub fn try_from_pattern_and_values(pattern: Arc<SparsityPattern>, values: Vec<T>)
|
pub fn try_from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
|
||||||
-> Result<Self, SparseFormatError> {
|
-> Result<Self, SparseFormatError> {
|
||||||
if pattern.nnz() == values.len() {
|
if pattern.nnz() == values.len() {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -212,7 +211,7 @@ impl<T> CscMatrix<T> {
|
|||||||
/// An iterator over columns in the matrix.
|
/// An iterator over columns in the matrix.
|
||||||
pub fn col_iter(&self) -> CscColIter<T> {
|
pub fn col_iter(&self) -> CscColIter<T> {
|
||||||
CscColIter {
|
CscColIter {
|
||||||
lane_iter: CsLaneIter::new(self.pattern().as_ref(), self.values())
|
lane_iter: CsLaneIter::new(self.pattern(), self.values())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -258,7 +257,7 @@ impl<T> CscMatrix<T> {
|
|||||||
/// The sparsity pattern is stored internally inside an `Arc`. This allows users to re-use
|
/// The sparsity pattern is stored internally inside an `Arc`. This allows users to re-use
|
||||||
/// the same sparsity pattern for multiple matrices without storing the same pattern multiple
|
/// the same sparsity pattern for multiple matrices without storing the same pattern multiple
|
||||||
/// times in memory.
|
/// times in memory.
|
||||||
pub fn pattern(&self) -> &Arc<SparsityPattern> {
|
pub fn pattern(&self) -> &SparsityPattern {
|
||||||
self.cs.pattern()
|
self.cs.pattern()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,7 +7,6 @@ use crate::cs::{CsMatrix, CsLaneIterMut, CsLaneIter, CsLane, CsLaneMut};
|
|||||||
use nalgebra::Scalar;
|
use nalgebra::Scalar;
|
||||||
use num_traits::{One};
|
use num_traits::{One};
|
||||||
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::slice::{IterMut, Iter};
|
use std::slice::{IterMut, Iter};
|
||||||
|
|
||||||
/// A CSR representation of a sparse matrix.
|
/// A CSR representation of a sparse matrix.
|
||||||
@ -97,14 +96,14 @@ impl<T> CsrMatrix<T> {
|
|||||||
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||||
num_rows, num_cols, row_offsets, col_indices)
|
num_rows, num_cols, row_offsets, col_indices)
|
||||||
.map_err(pattern_format_error_to_csr_error)?;
|
.map_err(pattern_format_error_to_csr_error)?;
|
||||||
Self::try_from_pattern_and_values(Arc::new(pattern), values)
|
Self::try_from_pattern_and_values(pattern, values)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
|
/// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
|
||||||
///
|
///
|
||||||
/// Returns an error if the number of values does not match the number of minor indices
|
/// Returns an error if the number of values does not match the number of minor indices
|
||||||
/// in the pattern.
|
/// in the pattern.
|
||||||
pub fn try_from_pattern_and_values(pattern: Arc<SparsityPattern>, values: Vec<T>)
|
pub fn try_from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
|
||||||
-> Result<Self, SparseFormatError> {
|
-> Result<Self, SparseFormatError> {
|
||||||
if pattern.nnz() == values.len() {
|
if pattern.nnz() == values.len() {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -214,7 +213,7 @@ impl<T> CsrMatrix<T> {
|
|||||||
/// An iterator over rows in the matrix.
|
/// An iterator over rows in the matrix.
|
||||||
pub fn row_iter(&self) -> CsrRowIter<T> {
|
pub fn row_iter(&self) -> CsrRowIter<T> {
|
||||||
CsrRowIter {
|
CsrRowIter {
|
||||||
lane_iter: CsLaneIter::new(self.pattern().as_ref(), self.values())
|
lane_iter: CsLaneIter::new(self.pattern(), self.values())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -260,7 +259,7 @@ impl<T> CsrMatrix<T> {
|
|||||||
/// The sparsity pattern is stored internally inside an `Arc`. This allows users to re-use
|
/// The sparsity pattern is stored internally inside an `Arc`. This allows users to re-use
|
||||||
/// the same sparsity pattern for multiple matrices without storing the same pattern multiple
|
/// the same sparsity pattern for multiple matrices without storing the same pattern multiple
|
||||||
/// times in memory.
|
/// times in memory.
|
||||||
pub fn pattern(&self) -> &Arc<SparsityPattern> {
|
pub fn pattern(&self) -> &SparsityPattern {
|
||||||
self.cs.pattern()
|
self.cs.pattern()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,26 +5,25 @@ use crate::pattern::SparsityPattern;
|
|||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
use core::{mem, iter};
|
use core::{mem, iter};
|
||||||
use nalgebra::{Scalar, RealField, DMatrixSlice, DMatrixSliceMut, DMatrix};
|
use nalgebra::{Scalar, RealField, DMatrixSlice, DMatrixSliceMut, DMatrix};
|
||||||
use std::sync::Arc;
|
|
||||||
use std::fmt::{Display, Formatter};
|
use std::fmt::{Display, Formatter};
|
||||||
use crate::ops::serial::spsolve_csc_lower_triangular;
|
use crate::ops::serial::spsolve_csc_lower_triangular;
|
||||||
use crate::ops::Op;
|
use crate::ops::Op;
|
||||||
|
|
||||||
pub struct CscSymbolicCholesky {
|
pub struct CscSymbolicCholesky {
|
||||||
// Pattern of the original matrix that was decomposed
|
// Pattern of the original matrix that was decomposed
|
||||||
m_pattern: Arc<SparsityPattern>,
|
m_pattern: SparsityPattern,
|
||||||
l_pattern: SparsityPattern,
|
l_pattern: SparsityPattern,
|
||||||
// u in this context is L^T, so that M = L L^T
|
// u in this context is L^T, so that M = L L^T
|
||||||
u_pattern: SparsityPattern
|
u_pattern: SparsityPattern
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CscSymbolicCholesky {
|
impl CscSymbolicCholesky {
|
||||||
pub fn factor(pattern: &Arc<SparsityPattern>) -> Self {
|
pub fn factor(pattern: SparsityPattern) -> Self {
|
||||||
assert_eq!(pattern.major_dim(), pattern.minor_dim(),
|
assert_eq!(pattern.major_dim(), pattern.minor_dim(),
|
||||||
"Major and minor dimensions must be the same (square matrix).");
|
"Major and minor dimensions must be the same (square matrix).");
|
||||||
let (l_pattern, u_pattern) = nonzero_pattern(&*pattern);
|
let (l_pattern, u_pattern) = nonzero_pattern(&pattern);
|
||||||
Self {
|
Self {
|
||||||
m_pattern: Arc::clone(pattern),
|
m_pattern: pattern,
|
||||||
l_pattern,
|
l_pattern,
|
||||||
u_pattern,
|
u_pattern,
|
||||||
}
|
}
|
||||||
@ -37,7 +36,7 @@ impl CscSymbolicCholesky {
|
|||||||
|
|
||||||
pub struct CscCholesky<T> {
|
pub struct CscCholesky<T> {
|
||||||
// Pattern of the original matrix
|
// Pattern of the original matrix
|
||||||
m_pattern: Arc<SparsityPattern>,
|
m_pattern: SparsityPattern,
|
||||||
l_factor: CscMatrix<T>,
|
l_factor: CscMatrix<T>,
|
||||||
u_pattern: SparsityPattern,
|
u_pattern: SparsityPattern,
|
||||||
work_x: Vec<T>,
|
work_x: Vec<T>,
|
||||||
@ -66,7 +65,7 @@ impl<T: RealField> CscCholesky<T> {
|
|||||||
|
|
||||||
let l_nnz = symbolic.l_pattern.nnz();
|
let l_nnz = symbolic.l_pattern.nnz();
|
||||||
let l_values = vec![T::zero(); l_nnz];
|
let l_values = vec![T::zero(); l_nnz];
|
||||||
let l_factor = CscMatrix::try_from_pattern_and_values(Arc::new(symbolic.l_pattern), l_values)
|
let l_factor = CscMatrix::try_from_pattern_and_values(symbolic.l_pattern, l_values)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let (nrows, ncols) = (l_factor.nrows(), l_factor.ncols());
|
let (nrows, ncols) = (l_factor.nrows(), l_factor.ncols());
|
||||||
@ -86,7 +85,7 @@ impl<T: RealField> CscCholesky<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn factor(matrix: &CscMatrix<T>) -> Result<Self, CholeskyError> {
|
pub fn factor(matrix: &CscMatrix<T>) -> Result<Self, CholeskyError> {
|
||||||
let symbolic = CscSymbolicCholesky::factor(&*matrix.pattern());
|
let symbolic = CscSymbolicCholesky::factor(matrix.pattern().clone());
|
||||||
Self::factor_numerical(symbolic, matrix.values())
|
Self::factor_numerical(symbolic, matrix.values())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,7 +7,6 @@ use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern,
|
|||||||
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, Dim,
|
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, Dim,
|
||||||
DMatrixSlice, DMatrix, Dynamic};
|
DMatrixSlice, DMatrix, Dynamic};
|
||||||
use num_traits::{Zero, One};
|
use num_traits::{Zero, One};
|
||||||
use std::sync::Arc;
|
|
||||||
use crate::ops::{Op};
|
use crate::ops::{Op};
|
||||||
use nalgebra::base::storage::Storage;
|
use nalgebra::base::storage::Storage;
|
||||||
|
|
||||||
@ -48,11 +47,7 @@ macro_rules! impl_sp_plus_minus {
|
|||||||
impl_bin_op!($trait, $method,
|
impl_bin_op!($trait, $method,
|
||||||
<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||||
// If both matrices have the same pattern, then we can immediately re-use it
|
// If both matrices have the same pattern, then we can immediately re-use it
|
||||||
let pattern = if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
let pattern = spadd_pattern(a.pattern(), b.pattern());
|
||||||
Arc::clone(a.pattern())
|
|
||||||
} else {
|
|
||||||
Arc::new(spadd_pattern(a.pattern(), b.pattern()))
|
|
||||||
};
|
|
||||||
let values = vec![T::zero(); pattern.nnz()];
|
let values = vec![T::zero(); pattern.nnz()];
|
||||||
// We are giving data that is valid by definition, so it is safe to unwrap below
|
// We are giving data that is valid by definition, so it is safe to unwrap below
|
||||||
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
|
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
|
||||||
@ -64,24 +59,12 @@ macro_rules! impl_sp_plus_minus {
|
|||||||
|
|
||||||
impl_bin_op!($trait, $method,
|
impl_bin_op!($trait, $method,
|
||||||
<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||||
let mut a = a;
|
&a $sign b
|
||||||
if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
|
||||||
$spadd_fn(T::one(), &mut a, $factor * T::one(), Op::NoOp(b)).unwrap();
|
|
||||||
a
|
|
||||||
} else {
|
|
||||||
&a $sign b
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
impl_bin_op!($trait, $method,
|
impl_bin_op!($trait, $method,
|
||||||
<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||||
let mut b = b;
|
a $sign &b
|
||||||
if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
|
||||||
$spadd_fn($factor * T::one(), &mut b, T::one(), Op::NoOp(a)).unwrap();
|
|
||||||
b
|
|
||||||
} else {
|
|
||||||
a $sign &b
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
impl_bin_op!($trait, $method, <T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
impl_bin_op!($trait, $method, <T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||||
a $sign &b
|
a $sign &b
|
||||||
@ -107,7 +90,7 @@ macro_rules! impl_spmm {
|
|||||||
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||||
let pattern = $pattern_fn(a.pattern(), b.pattern());
|
let pattern = $pattern_fn(a.pattern(), b.pattern());
|
||||||
let values = vec![T::zero(); pattern.nnz()];
|
let values = vec![T::zero(); pattern.nnz()];
|
||||||
let mut result = $matrix_type::try_from_pattern_and_values(Arc::new(pattern), values)
|
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
$spmm_fn(T::zero(),
|
$spmm_fn(T::zero(),
|
||||||
&mut result,
|
&mut result,
|
||||||
@ -154,7 +137,7 @@ macro_rules! impl_scalar_mul {
|
|||||||
.iter()
|
.iter()
|
||||||
.map(|v_i| v_i.inlined_clone() * b.inlined_clone())
|
.map(|v_i| v_i.inlined_clone() * b.inlined_clone())
|
||||||
.collect();
|
.collect();
|
||||||
$matrix_type::try_from_pattern_and_values(Arc::clone(a.pattern()), values).unwrap()
|
$matrix_type::try_from_pattern_and_values(a.pattern().clone(), values).unwrap()
|
||||||
});
|
});
|
||||||
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: T) -> $matrix_type<T> {
|
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: T) -> $matrix_type<T> {
|
||||||
a * &b
|
a * &b
|
||||||
@ -251,7 +234,7 @@ macro_rules! impl_div {
|
|||||||
.iter()
|
.iter()
|
||||||
.map(|v_i| v_i.inlined_clone() / scalar.inlined_clone())
|
.map(|v_i| v_i.inlined_clone() / scalar.inlined_clone())
|
||||||
.collect();
|
.collect();
|
||||||
$matrix_type::try_from_pattern_and_values(Arc::clone(matrix.pattern()), new_values)
|
$matrix_type::try_from_pattern_and_values(matrix.pattern().clone(), new_values)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: &'a $matrix_type<T>, scalar: &'a T) -> $matrix_type<T> {
|
impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: &'a $matrix_type<T>, scalar: &'a T) -> $matrix_type<T> {
|
||||||
|
@ -4,7 +4,6 @@ use crate::ops::serial::{OperationErrorType, OperationError};
|
|||||||
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice};
|
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice};
|
||||||
use num_traits::{Zero, One};
|
use num_traits::{Zero, One};
|
||||||
use crate::SparseEntryMut;
|
use crate::SparseEntryMut;
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
fn spmm_cs_unexpected_entry() -> OperationError {
|
fn spmm_cs_unexpected_entry() -> OperationError {
|
||||||
OperationError::from_type_and_message(
|
OperationError::from_type_and_message(
|
||||||
@ -73,64 +72,52 @@ pub fn spadd_cs_prealloc<T>(beta: T,
|
|||||||
where
|
where
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
{
|
{
|
||||||
|
match a {
|
||||||
if Arc::ptr_eq(&c.pattern(), &a.inner_ref().pattern()) {
|
Op::NoOp(a) => {
|
||||||
// Special fast path: The two matrices have *exactly* the same sparsity pattern,
|
for (mut c_lane_i, a_lane_i) in c.lane_iter_mut().zip(a.lane_iter()) {
|
||||||
// so we only need to sum the value arrays
|
|
||||||
// TODO: Test this fast path
|
|
||||||
for (c_ij, a_ij) in c.values_mut().iter_mut().zip(a.inner_ref().values()) {
|
|
||||||
let (alpha, beta) = (alpha.inlined_clone(), beta.inlined_clone());
|
|
||||||
*c_ij = beta * c_ij.inlined_clone() + alpha * a_ij.inlined_clone();
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
match a {
|
|
||||||
Op::NoOp(a) => {
|
|
||||||
for (mut c_lane_i, a_lane_i) in c.lane_iter_mut().zip(a.lane_iter()) {
|
|
||||||
if beta != T::one() {
|
|
||||||
for c_ij in c_lane_i.values_mut() {
|
|
||||||
*c_ij *= beta.inlined_clone();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let (mut c_minors, mut c_vals) = c_lane_i.indices_and_values_mut();
|
|
||||||
let (a_minors, a_vals) = (a_lane_i.minor_indices(), a_lane_i.values());
|
|
||||||
|
|
||||||
for (a_col, a_val) in a_minors.iter().zip(a_vals) {
|
|
||||||
// TODO: Use exponential search instead of linear search.
|
|
||||||
// If C has substantially more entries in the row than A, then a line search
|
|
||||||
// will needlessly visit many entries in C.
|
|
||||||
let (c_idx, _) = c_minors.iter()
|
|
||||||
.enumerate()
|
|
||||||
.find(|(_, c_col)| *c_col == a_col)
|
|
||||||
.ok_or_else(spadd_cs_unexpected_entry)?;
|
|
||||||
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
|
|
||||||
c_minors = &c_minors[c_idx ..];
|
|
||||||
c_vals = &mut c_vals[c_idx ..];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Op::Transpose(a) => {
|
|
||||||
if beta != T::one() {
|
if beta != T::one() {
|
||||||
for c_ij in c.values_mut() {
|
for c_ij in c_lane_i.values_mut() {
|
||||||
*c_ij *= beta.inlined_clone();
|
*c_ij *= beta.inlined_clone();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (i, a_lane_i) in a.lane_iter().enumerate() {
|
let (mut c_minors, mut c_vals) = c_lane_i.indices_and_values_mut();
|
||||||
for (&j, a_val) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) {
|
let (a_minors, a_vals) = (a_lane_i.minor_indices(), a_lane_i.values());
|
||||||
let a_val = a_val.inlined_clone();
|
|
||||||
let alpha = alpha.inlined_clone();
|
for (a_col, a_val) in a_minors.iter().zip(a_vals) {
|
||||||
match c.get_entry_mut(j, i).unwrap() {
|
// TODO: Use exponential search instead of linear search.
|
||||||
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val }
|
// If C has substantially more entries in the row than A, then a line search
|
||||||
SparseEntryMut::Zero => return Err(spadd_cs_unexpected_entry()),
|
// will needlessly visit many entries in C.
|
||||||
}
|
let (c_idx, _) = c_minors.iter()
|
||||||
|
.enumerate()
|
||||||
|
.find(|(_, c_col)| *c_col == a_col)
|
||||||
|
.ok_or_else(spadd_cs_unexpected_entry)?;
|
||||||
|
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
|
||||||
|
c_minors = &c_minors[c_idx ..];
|
||||||
|
c_vals = &mut c_vals[c_idx ..];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Op::Transpose(a) => {
|
||||||
|
if beta != T::one() {
|
||||||
|
for c_ij in c.values_mut() {
|
||||||
|
*c_ij *= beta.inlined_clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i, a_lane_i) in a.lane_iter().enumerate() {
|
||||||
|
for (&j, a_val) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) {
|
||||||
|
let a_val = a_val.inlined_clone();
|
||||||
|
let alpha = alpha.inlined_clone();
|
||||||
|
match c.get_entry_mut(j, i).unwrap() {
|
||||||
|
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val }
|
||||||
|
SparseEntryMut::Zero => return Err(spadd_cs_unexpected_entry()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper functionality for implementing CSR/CSC SPMM.
|
/// Helper functionality for implementing CSR/CSC SPMM.
|
||||||
|
@ -11,7 +11,6 @@ use std::iter::{repeat};
|
|||||||
use proptest::sample::{Index};
|
use proptest::sample::{Index};
|
||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
use crate::pattern::SparsityPattern;
|
use crate::pattern::SparsityPattern;
|
||||||
use std::sync::Arc;
|
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
|
|
||||||
fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
|
fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
|
||||||
@ -291,7 +290,7 @@ where
|
|||||||
(Just(pattern), values)
|
(Just(pattern), values)
|
||||||
})
|
})
|
||||||
.prop_map(|(pattern, values)| {
|
.prop_map(|(pattern, values)| {
|
||||||
CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
CsrMatrix::try_from_pattern_and_values(pattern, values)
|
||||||
.expect("Internal error: Generated CsrMatrix is invalid")
|
.expect("Internal error: Generated CsrMatrix is invalid")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -313,7 +312,7 @@ pub fn csc<T>(value_strategy: T,
|
|||||||
(Just(pattern), values)
|
(Just(pattern), values)
|
||||||
})
|
})
|
||||||
.prop_map(|(pattern, values)| {
|
.prop_map(|(pattern, values)| {
|
||||||
CscMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
CscMatrix::try_from_pattern_and_values(pattern, values)
|
||||||
.expect("Internal error: Generated CscMatrix is invalid")
|
.expect("Internal error: Generated CscMatrix is invalid")
|
||||||
})
|
})
|
||||||
}
|
}
|
@ -17,12 +17,11 @@ use proptest::prelude::*;
|
|||||||
use matrixcompare::prop_assert_matrix_eq;
|
use matrixcompare::prop_assert_matrix_eq;
|
||||||
|
|
||||||
use std::panic::catch_unwind;
|
use std::panic::catch_unwind;
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1
|
/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1
|
||||||
fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
||||||
let boolean_csr = CsrMatrix::try_from_pattern_and_values(
|
let boolean_csr = CsrMatrix::try_from_pattern_and_values(
|
||||||
Arc::new(pattern.clone()),
|
pattern.clone(),
|
||||||
vec![1; pattern.nnz()])
|
vec![1; pattern.nnz()])
|
||||||
.unwrap();
|
.unwrap();
|
||||||
DMatrix::from(&boolean_csr)
|
DMatrix::from(&boolean_csr)
|
||||||
@ -31,7 +30,7 @@ fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
|||||||
/// Represents the sparsity pattern of a CSC matrix as a dense matrix with 0/1
|
/// Represents the sparsity pattern of a CSC matrix as a dense matrix with 0/1
|
||||||
fn dense_csc_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
fn dense_csc_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
||||||
let boolean_csc = CscMatrix::try_from_pattern_and_values(
|
let boolean_csc = CscMatrix::try_from_pattern_and_values(
|
||||||
Arc::new(pattern.clone()),
|
pattern.clone(),
|
||||||
vec![1; pattern.nnz()])
|
vec![1; pattern.nnz()])
|
||||||
.unwrap();
|
.unwrap();
|
||||||
DMatrix::from(&boolean_csc)
|
DMatrix::from(&boolean_csc)
|
||||||
@ -137,8 +136,8 @@ fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>>
|
|||||||
let beta = value_strategy.clone();
|
let beta = value_strategy.clone();
|
||||||
(Just(c_pattern), Just(a_pattern), c_values, a_values, alpha, beta, trans_strategy())
|
(Just(c_pattern), Just(a_pattern), c_values, a_values, alpha, beta, trans_strategy())
|
||||||
}).prop_map(|(c_pattern, a_pattern, c_values, a_values, alpha, beta, trans_a)| {
|
}).prop_map(|(c_pattern, a_pattern, c_values, a_values, alpha, beta, trans_a)| {
|
||||||
let c = CsrMatrix::try_from_pattern_and_values(Arc::new(c_pattern), c_values).unwrap();
|
let c = CsrMatrix::try_from_pattern_and_values(c_pattern, c_values).unwrap();
|
||||||
let a = CsrMatrix::try_from_pattern_and_values(Arc::new(a_pattern), a_values).unwrap();
|
let a = CsrMatrix::try_from_pattern_and_values(a_pattern, a_values).unwrap();
|
||||||
|
|
||||||
let a = if trans_a { Op::Transpose(a.transpose()) } else { Op::NoOp(a) };
|
let a = if trans_a { Op::Transpose(a.transpose()) } else { Op::NoOp(a) };
|
||||||
SpaddCsrArgs { c, beta, alpha, a }
|
SpaddCsrArgs { c, beta, alpha, a }
|
||||||
@ -222,15 +221,12 @@ fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
|||||||
let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()];
|
let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()];
|
||||||
let c_pattern = spmm_pattern(&a_pattern, &b_pattern);
|
let c_pattern = spmm_pattern(&a_pattern, &b_pattern);
|
||||||
let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()];
|
let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()];
|
||||||
let a_pattern = Arc::new(a_pattern);
|
|
||||||
let b_pattern = Arc::new(b_pattern);
|
|
||||||
let c_pattern = Arc::new(c_pattern);
|
|
||||||
let a = a_values.prop_map(move |values|
|
let a = a_values.prop_map(move |values|
|
||||||
CsrMatrix::try_from_pattern_and_values(Arc::clone(&a_pattern), values).unwrap());
|
CsrMatrix::try_from_pattern_and_values(a_pattern.clone(), values).unwrap());
|
||||||
let b = b_values.prop_map(move |values|
|
let b = b_values.prop_map(move |values|
|
||||||
CsrMatrix::try_from_pattern_and_values(Arc::clone(&b_pattern), values).unwrap());
|
CsrMatrix::try_from_pattern_and_values(b_pattern.clone(), values).unwrap());
|
||||||
let c = c_values.prop_map(move |values|
|
let c = c_values.prop_map(move |values|
|
||||||
CsrMatrix::try_from_pattern_and_values(Arc::clone(&c_pattern), values).unwrap());
|
CsrMatrix::try_from_pattern_and_values(c_pattern.clone(), values).unwrap());
|
||||||
let alpha = PROPTEST_I32_VALUE_STRATEGY;
|
let alpha = PROPTEST_I32_VALUE_STRATEGY;
|
||||||
let beta = PROPTEST_I32_VALUE_STRATEGY;
|
let beta = PROPTEST_I32_VALUE_STRATEGY;
|
||||||
(c, beta, alpha, trans_strategy(), a, trans_strategy(), b)
|
(c, beta, alpha, trans_strategy(), a, trans_strategy(), b)
|
||||||
@ -383,16 +379,16 @@ proptest! {
|
|||||||
// corresponding to a and b, and convert them to dense matrices.
|
// corresponding to a and b, and convert them to dense matrices.
|
||||||
// The sum of these dense matrices will then have non-zeros in exactly the same locations
|
// The sum of these dense matrices will then have non-zeros in exactly the same locations
|
||||||
// as the result of "adding" the sparsity patterns
|
// as the result of "adding" the sparsity patterns
|
||||||
let a_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(a.clone()), vec![1; a.nnz()])
|
let a_csr = CsrMatrix::try_from_pattern_and_values(a.clone(), vec![1; a.nnz()])
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let a_dense = DMatrix::from(&a_csr);
|
let a_dense = DMatrix::from(&a_csr);
|
||||||
let b_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(b.clone()), vec![1; b.nnz()])
|
let b_csr = CsrMatrix::try_from_pattern_and_values(b.clone(), vec![1; b.nnz()])
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let b_dense = DMatrix::from(&b_csr);
|
let b_dense = DMatrix::from(&b_csr);
|
||||||
let c_dense = a_dense + b_dense;
|
let c_dense = a_dense + b_dense;
|
||||||
let c_csr = CsrMatrix::from(&c_dense);
|
let c_csr = CsrMatrix::from(&c_dense);
|
||||||
|
|
||||||
prop_assert_eq!(&pattern_result, c_csr.pattern().as_ref());
|
prop_assert_eq!(&pattern_result, c_csr.pattern());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -492,16 +488,16 @@ proptest! {
|
|||||||
// corresponding to a and b, and convert them to dense matrices.
|
// corresponding to a and b, and convert them to dense matrices.
|
||||||
// The product of these dense matrices will then have non-zeros in exactly the same locations
|
// The product of these dense matrices will then have non-zeros in exactly the same locations
|
||||||
// as the result of "multiplying" the sparsity patterns
|
// as the result of "multiplying" the sparsity patterns
|
||||||
let a_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(a.clone()), vec![1; a.nnz()])
|
let a_csr = CsrMatrix::try_from_pattern_and_values(a.clone(), vec![1; a.nnz()])
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let a_dense = DMatrix::from(&a_csr);
|
let a_dense = DMatrix::from(&a_csr);
|
||||||
let b_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(b.clone()), vec![1; b.nnz()])
|
let b_csr = CsrMatrix::try_from_pattern_and_values(b.clone(), vec![1; b.nnz()])
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let b_dense = DMatrix::from(&b_csr);
|
let b_dense = DMatrix::from(&b_csr);
|
||||||
let c_dense = a_dense * b_dense;
|
let c_dense = a_dense * b_dense;
|
||||||
let c_csr = CsrMatrix::from(&c_dense);
|
let c_csr = CsrMatrix::from(&c_dense);
|
||||||
|
|
||||||
prop_assert_eq!(&c_pattern, c_csr.pattern().as_ref());
|
prop_assert_eq!(&c_pattern, c_csr.pattern());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
Loading…
Reference in New Issue
Block a user