forked from M-Labs/nalgebra
Replace Arc<SparsityPattern> with SparsityPattern
After much deliberation, I have come to the conclusion that the benefits do not really outweigh the added complexity. Even though the added complexity is relatively minor, it makes it somewhat more complicated to inter-op with other sparse linear algebra libraries in the future.
This commit is contained in:
parent
9b46a43c7f
commit
e655fed4fa
@ -1,6 +1,5 @@
|
||||
use std::mem::replace;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
use num_traits::One;
|
||||
|
||||
@ -18,7 +17,7 @@ use crate::pattern::SparsityPattern;
|
||||
/// is obtained by associating columns with the major dimension.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CsMatrix<T> {
|
||||
sparsity_pattern: Arc<SparsityPattern>,
|
||||
sparsity_pattern: SparsityPattern,
|
||||
values: Vec<T>
|
||||
}
|
||||
|
||||
@ -27,13 +26,13 @@ impl<T> CsMatrix<T> {
|
||||
#[inline]
|
||||
pub fn new(major_dim: usize, minor_dim: usize) -> Self {
|
||||
Self {
|
||||
sparsity_pattern: Arc::new(SparsityPattern::new(major_dim, minor_dim)),
|
||||
sparsity_pattern: SparsityPattern::new(major_dim, minor_dim),
|
||||
values: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn pattern(&self) -> &Arc<SparsityPattern> {
|
||||
pub fn pattern(&self) -> &SparsityPattern {
|
||||
&self.sparsity_pattern
|
||||
}
|
||||
|
||||
@ -50,24 +49,24 @@ impl<T> CsMatrix<T> {
|
||||
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
|
||||
#[inline]
|
||||
pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) {
|
||||
let pattern = self.pattern().as_ref();
|
||||
let pattern = self.pattern();
|
||||
(pattern.major_offsets(), pattern.minor_indices(), &self.values)
|
||||
}
|
||||
|
||||
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
|
||||
#[inline]
|
||||
pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
||||
let pattern = self.sparsity_pattern.as_ref();
|
||||
let pattern = &mut self.sparsity_pattern;
|
||||
(pattern.major_offsets(), pattern.minor_indices(), &mut self.values)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn pattern_and_values_mut(&mut self) -> (&Arc<SparsityPattern>, &mut [T]) {
|
||||
pub fn pattern_and_values_mut(&mut self) -> (&SparsityPattern, &mut [T]) {
|
||||
(&self.sparsity_pattern, &mut self.values)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn from_pattern_and_values(pattern: Arc<SparsityPattern>, values: Vec<T>)
|
||||
pub fn from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
|
||||
-> Self {
|
||||
assert_eq!(pattern.nnz(), values.len(), "Internal error: consumers should verify shape compatibility.");
|
||||
Self {
|
||||
@ -84,25 +83,14 @@ impl<T> CsMatrix<T> {
|
||||
Some(row_begin .. row_end)
|
||||
}
|
||||
|
||||
pub fn take_pattern_and_values(self) -> (Arc<SparsityPattern>, Vec<T>) {
|
||||
pub fn take_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
|
||||
(self.sparsity_pattern, self.values)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
|
||||
// Take an Arc to the pattern, which might be the sole reference to the data after
|
||||
// taking the values. This is important, because it might let us avoid cloning the data
|
||||
// further below.
|
||||
let pattern = self.sparsity_pattern;
|
||||
let values = self.values;
|
||||
|
||||
// Try to take the pattern out of the `Arc` if possible,
|
||||
// otherwise clone the pattern.
|
||||
let owned_pattern = Arc::try_unwrap(pattern)
|
||||
.unwrap_or_else(|arc| SparsityPattern::clone(&*arc));
|
||||
let (offsets, indices) = owned_pattern.disassemble();
|
||||
|
||||
(offsets, indices, values)
|
||||
let (offsets, indices) = self.sparsity_pattern.disassemble();
|
||||
(offsets, indices, self.values)
|
||||
}
|
||||
|
||||
/// Returns an entry for the given major/minor indices, or `None` if the indices are out
|
||||
@ -151,12 +139,12 @@ impl<T> CsMatrix<T> {
|
||||
|
||||
#[inline]
|
||||
pub fn lane_iter(&self) -> CsLaneIter<T> {
|
||||
CsLaneIter::new(self.pattern().as_ref(), self.values())
|
||||
CsLaneIter::new(self.pattern(), self.values())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn lane_iter_mut(&mut self) -> CsLaneIterMut<T> {
|
||||
CsLaneIterMut::new(self.sparsity_pattern.as_ref(), &mut self.values)
|
||||
CsLaneIterMut::new(&self.sparsity_pattern, &mut self.values)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
@ -190,7 +178,7 @@ impl<T> CsMatrix<T> {
|
||||
new_indices)
|
||||
.expect("Internal error: Sparsity pattern must always be valid.");
|
||||
|
||||
Self::from_pattern_and_values(Arc::new(new_pattern), new_values)
|
||||
Self::from_pattern_and_values(new_pattern, new_values)
|
||||
}
|
||||
}
|
||||
|
||||
@ -205,7 +193,7 @@ impl<T: Scalar + One> CsMatrix<T> {
|
||||
// TODO: We should skip checks here
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices)
|
||||
.unwrap();
|
||||
Self::from_pattern_and_values(Arc::new(pattern), values)
|
||||
Self::from_pattern_and_values(pattern, values)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -5,7 +5,6 @@ use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatter
|
||||
use crate::csr::CsrMatrix;
|
||||
use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut};
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::slice::{IterMut, Iter};
|
||||
use num_traits::{One};
|
||||
use nalgebra::Scalar;
|
||||
@ -95,14 +94,14 @@ impl<T> CscMatrix<T> {
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||
num_cols, num_rows, col_offsets, row_indices)
|
||||
.map_err(pattern_format_error_to_csc_error)?;
|
||||
Self::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||
Self::try_from_pattern_and_values(pattern, values)
|
||||
}
|
||||
|
||||
/// Try to construct a CSC matrix from a sparsity pattern and associated non-zero values.
|
||||
///
|
||||
/// Returns an error if the number of values does not match the number of minor indices
|
||||
/// in the pattern.
|
||||
pub fn try_from_pattern_and_values(pattern: Arc<SparsityPattern>, values: Vec<T>)
|
||||
pub fn try_from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
|
||||
-> Result<Self, SparseFormatError> {
|
||||
if pattern.nnz() == values.len() {
|
||||
Ok(Self {
|
||||
@ -212,7 +211,7 @@ impl<T> CscMatrix<T> {
|
||||
/// An iterator over columns in the matrix.
|
||||
pub fn col_iter(&self) -> CscColIter<T> {
|
||||
CscColIter {
|
||||
lane_iter: CsLaneIter::new(self.pattern().as_ref(), self.values())
|
||||
lane_iter: CsLaneIter::new(self.pattern(), self.values())
|
||||
}
|
||||
}
|
||||
|
||||
@ -258,7 +257,7 @@ impl<T> CscMatrix<T> {
|
||||
/// The sparsity pattern is stored internally inside an `Arc`. This allows users to re-use
|
||||
/// the same sparsity pattern for multiple matrices without storing the same pattern multiple
|
||||
/// times in memory.
|
||||
pub fn pattern(&self) -> &Arc<SparsityPattern> {
|
||||
pub fn pattern(&self) -> &SparsityPattern {
|
||||
self.cs.pattern()
|
||||
}
|
||||
|
||||
|
@ -7,7 +7,6 @@ use crate::cs::{CsMatrix, CsLaneIterMut, CsLaneIter, CsLane, CsLaneMut};
|
||||
use nalgebra::Scalar;
|
||||
use num_traits::{One};
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::slice::{IterMut, Iter};
|
||||
|
||||
/// A CSR representation of a sparse matrix.
|
||||
@ -97,14 +96,14 @@ impl<T> CsrMatrix<T> {
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||
num_rows, num_cols, row_offsets, col_indices)
|
||||
.map_err(pattern_format_error_to_csr_error)?;
|
||||
Self::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||
Self::try_from_pattern_and_values(pattern, values)
|
||||
}
|
||||
|
||||
/// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
|
||||
///
|
||||
/// Returns an error if the number of values does not match the number of minor indices
|
||||
/// in the pattern.
|
||||
pub fn try_from_pattern_and_values(pattern: Arc<SparsityPattern>, values: Vec<T>)
|
||||
pub fn try_from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>)
|
||||
-> Result<Self, SparseFormatError> {
|
||||
if pattern.nnz() == values.len() {
|
||||
Ok(Self {
|
||||
@ -214,7 +213,7 @@ impl<T> CsrMatrix<T> {
|
||||
/// An iterator over rows in the matrix.
|
||||
pub fn row_iter(&self) -> CsrRowIter<T> {
|
||||
CsrRowIter {
|
||||
lane_iter: CsLaneIter::new(self.pattern().as_ref(), self.values())
|
||||
lane_iter: CsLaneIter::new(self.pattern(), self.values())
|
||||
}
|
||||
}
|
||||
|
||||
@ -260,7 +259,7 @@ impl<T> CsrMatrix<T> {
|
||||
/// The sparsity pattern is stored internally inside an `Arc`. This allows users to re-use
|
||||
/// the same sparsity pattern for multiple matrices without storing the same pattern multiple
|
||||
/// times in memory.
|
||||
pub fn pattern(&self) -> &Arc<SparsityPattern> {
|
||||
pub fn pattern(&self) -> &SparsityPattern {
|
||||
self.cs.pattern()
|
||||
}
|
||||
|
||||
|
@ -5,26 +5,25 @@ use crate::pattern::SparsityPattern;
|
||||
use crate::csc::CscMatrix;
|
||||
use core::{mem, iter};
|
||||
use nalgebra::{Scalar, RealField, DMatrixSlice, DMatrixSliceMut, DMatrix};
|
||||
use std::sync::Arc;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use crate::ops::serial::spsolve_csc_lower_triangular;
|
||||
use crate::ops::Op;
|
||||
|
||||
pub struct CscSymbolicCholesky {
|
||||
// Pattern of the original matrix that was decomposed
|
||||
m_pattern: Arc<SparsityPattern>,
|
||||
m_pattern: SparsityPattern,
|
||||
l_pattern: SparsityPattern,
|
||||
// u in this context is L^T, so that M = L L^T
|
||||
u_pattern: SparsityPattern
|
||||
}
|
||||
|
||||
impl CscSymbolicCholesky {
|
||||
pub fn factor(pattern: &Arc<SparsityPattern>) -> Self {
|
||||
pub fn factor(pattern: SparsityPattern) -> Self {
|
||||
assert_eq!(pattern.major_dim(), pattern.minor_dim(),
|
||||
"Major and minor dimensions must be the same (square matrix).");
|
||||
let (l_pattern, u_pattern) = nonzero_pattern(&*pattern);
|
||||
let (l_pattern, u_pattern) = nonzero_pattern(&pattern);
|
||||
Self {
|
||||
m_pattern: Arc::clone(pattern),
|
||||
m_pattern: pattern,
|
||||
l_pattern,
|
||||
u_pattern,
|
||||
}
|
||||
@ -37,7 +36,7 @@ impl CscSymbolicCholesky {
|
||||
|
||||
pub struct CscCholesky<T> {
|
||||
// Pattern of the original matrix
|
||||
m_pattern: Arc<SparsityPattern>,
|
||||
m_pattern: SparsityPattern,
|
||||
l_factor: CscMatrix<T>,
|
||||
u_pattern: SparsityPattern,
|
||||
work_x: Vec<T>,
|
||||
@ -66,7 +65,7 @@ impl<T: RealField> CscCholesky<T> {
|
||||
|
||||
let l_nnz = symbolic.l_pattern.nnz();
|
||||
let l_values = vec![T::zero(); l_nnz];
|
||||
let l_factor = CscMatrix::try_from_pattern_and_values(Arc::new(symbolic.l_pattern), l_values)
|
||||
let l_factor = CscMatrix::try_from_pattern_and_values(symbolic.l_pattern, l_values)
|
||||
.unwrap();
|
||||
|
||||
let (nrows, ncols) = (l_factor.nrows(), l_factor.ncols());
|
||||
@ -86,7 +85,7 @@ impl<T: RealField> CscCholesky<T> {
|
||||
}
|
||||
|
||||
pub fn factor(matrix: &CscMatrix<T>) -> Result<Self, CholeskyError> {
|
||||
let symbolic = CscSymbolicCholesky::factor(&*matrix.pattern());
|
||||
let symbolic = CscSymbolicCholesky::factor(matrix.pattern().clone());
|
||||
Self::factor_numerical(symbolic, matrix.values())
|
||||
}
|
||||
|
||||
|
@ -7,7 +7,6 @@ use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern,
|
||||
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, Dim,
|
||||
DMatrixSlice, DMatrix, Dynamic};
|
||||
use num_traits::{Zero, One};
|
||||
use std::sync::Arc;
|
||||
use crate::ops::{Op};
|
||||
use nalgebra::base::storage::Storage;
|
||||
|
||||
@ -48,11 +47,7 @@ macro_rules! impl_sp_plus_minus {
|
||||
impl_bin_op!($trait, $method,
|
||||
<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||
// If both matrices have the same pattern, then we can immediately re-use it
|
||||
let pattern = if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
||||
Arc::clone(a.pattern())
|
||||
} else {
|
||||
Arc::new(spadd_pattern(a.pattern(), b.pattern()))
|
||||
};
|
||||
let pattern = spadd_pattern(a.pattern(), b.pattern());
|
||||
let values = vec![T::zero(); pattern.nnz()];
|
||||
// We are giving data that is valid by definition, so it is safe to unwrap below
|
||||
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
|
||||
@ -64,24 +59,12 @@ macro_rules! impl_sp_plus_minus {
|
||||
|
||||
impl_bin_op!($trait, $method,
|
||||
<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||
let mut a = a;
|
||||
if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
||||
$spadd_fn(T::one(), &mut a, $factor * T::one(), Op::NoOp(b)).unwrap();
|
||||
a
|
||||
} else {
|
||||
&a $sign b
|
||||
}
|
||||
});
|
||||
|
||||
impl_bin_op!($trait, $method,
|
||||
<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||
let mut b = b;
|
||||
if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
||||
$spadd_fn($factor * T::one(), &mut b, T::one(), Op::NoOp(a)).unwrap();
|
||||
b
|
||||
} else {
|
||||
a $sign &b
|
||||
}
|
||||
});
|
||||
impl_bin_op!($trait, $method, <T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||
a $sign &b
|
||||
@ -107,7 +90,7 @@ macro_rules! impl_spmm {
|
||||
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||
let pattern = $pattern_fn(a.pattern(), b.pattern());
|
||||
let values = vec![T::zero(); pattern.nnz()];
|
||||
let mut result = $matrix_type::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
|
||||
.unwrap();
|
||||
$spmm_fn(T::zero(),
|
||||
&mut result,
|
||||
@ -154,7 +137,7 @@ macro_rules! impl_scalar_mul {
|
||||
.iter()
|
||||
.map(|v_i| v_i.inlined_clone() * b.inlined_clone())
|
||||
.collect();
|
||||
$matrix_type::try_from_pattern_and_values(Arc::clone(a.pattern()), values).unwrap()
|
||||
$matrix_type::try_from_pattern_and_values(a.pattern().clone(), values).unwrap()
|
||||
});
|
||||
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: T) -> $matrix_type<T> {
|
||||
a * &b
|
||||
@ -251,7 +234,7 @@ macro_rules! impl_div {
|
||||
.iter()
|
||||
.map(|v_i| v_i.inlined_clone() / scalar.inlined_clone())
|
||||
.collect();
|
||||
$matrix_type::try_from_pattern_and_values(Arc::clone(matrix.pattern()), new_values)
|
||||
$matrix_type::try_from_pattern_and_values(matrix.pattern().clone(), new_values)
|
||||
.unwrap()
|
||||
});
|
||||
impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: &'a $matrix_type<T>, scalar: &'a T) -> $matrix_type<T> {
|
||||
|
@ -4,7 +4,6 @@ use crate::ops::serial::{OperationErrorType, OperationError};
|
||||
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice};
|
||||
use num_traits::{Zero, One};
|
||||
use crate::SparseEntryMut;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn spmm_cs_unexpected_entry() -> OperationError {
|
||||
OperationError::from_type_and_message(
|
||||
@ -73,17 +72,6 @@ pub fn spadd_cs_prealloc<T>(beta: T,
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
{
|
||||
|
||||
if Arc::ptr_eq(&c.pattern(), &a.inner_ref().pattern()) {
|
||||
// Special fast path: The two matrices have *exactly* the same sparsity pattern,
|
||||
// so we only need to sum the value arrays
|
||||
// TODO: Test this fast path
|
||||
for (c_ij, a_ij) in c.values_mut().iter_mut().zip(a.inner_ref().values()) {
|
||||
let (alpha, beta) = (alpha.inlined_clone(), beta.inlined_clone());
|
||||
*c_ij = beta * c_ij.inlined_clone() + alpha * a_ij.inlined_clone();
|
||||
}
|
||||
Ok(())
|
||||
} else {
|
||||
match a {
|
||||
Op::NoOp(a) => {
|
||||
for (mut c_lane_i, a_lane_i) in c.lane_iter_mut().zip(a.lane_iter()) {
|
||||
@ -130,7 +118,6 @@ pub fn spadd_cs_prealloc<T>(beta: T,
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper functionality for implementing CSR/CSC SPMM.
|
||||
|
@ -11,7 +11,6 @@ use std::iter::{repeat};
|
||||
use proptest::sample::{Index};
|
||||
use crate::csr::CsrMatrix;
|
||||
use crate::pattern::SparsityPattern;
|
||||
use std::sync::Arc;
|
||||
use crate::csc::CscMatrix;
|
||||
|
||||
fn dense_row_major_coord_strategy(nrows: usize, ncols: usize, nnz: usize)
|
||||
@ -291,7 +290,7 @@ where
|
||||
(Just(pattern), values)
|
||||
})
|
||||
.prop_map(|(pattern, values)| {
|
||||
CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||
CsrMatrix::try_from_pattern_and_values(pattern, values)
|
||||
.expect("Internal error: Generated CsrMatrix is invalid")
|
||||
})
|
||||
}
|
||||
@ -313,7 +312,7 @@ pub fn csc<T>(value_strategy: T,
|
||||
(Just(pattern), values)
|
||||
})
|
||||
.prop_map(|(pattern, values)| {
|
||||
CscMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||
CscMatrix::try_from_pattern_and_values(pattern, values)
|
||||
.expect("Internal error: Generated CscMatrix is invalid")
|
||||
})
|
||||
}
|
@ -17,12 +17,11 @@ use proptest::prelude::*;
|
||||
use matrixcompare::prop_assert_matrix_eq;
|
||||
|
||||
use std::panic::catch_unwind;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1
|
||||
fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
||||
let boolean_csr = CsrMatrix::try_from_pattern_and_values(
|
||||
Arc::new(pattern.clone()),
|
||||
pattern.clone(),
|
||||
vec![1; pattern.nnz()])
|
||||
.unwrap();
|
||||
DMatrix::from(&boolean_csr)
|
||||
@ -31,7 +30,7 @@ fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
||||
/// Represents the sparsity pattern of a CSC matrix as a dense matrix with 0/1
|
||||
fn dense_csc_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
||||
let boolean_csc = CscMatrix::try_from_pattern_and_values(
|
||||
Arc::new(pattern.clone()),
|
||||
pattern.clone(),
|
||||
vec![1; pattern.nnz()])
|
||||
.unwrap();
|
||||
DMatrix::from(&boolean_csc)
|
||||
@ -137,8 +136,8 @@ fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>>
|
||||
let beta = value_strategy.clone();
|
||||
(Just(c_pattern), Just(a_pattern), c_values, a_values, alpha, beta, trans_strategy())
|
||||
}).prop_map(|(c_pattern, a_pattern, c_values, a_values, alpha, beta, trans_a)| {
|
||||
let c = CsrMatrix::try_from_pattern_and_values(Arc::new(c_pattern), c_values).unwrap();
|
||||
let a = CsrMatrix::try_from_pattern_and_values(Arc::new(a_pattern), a_values).unwrap();
|
||||
let c = CsrMatrix::try_from_pattern_and_values(c_pattern, c_values).unwrap();
|
||||
let a = CsrMatrix::try_from_pattern_and_values(a_pattern, a_values).unwrap();
|
||||
|
||||
let a = if trans_a { Op::Transpose(a.transpose()) } else { Op::NoOp(a) };
|
||||
SpaddCsrArgs { c, beta, alpha, a }
|
||||
@ -222,15 +221,12 @@ fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
||||
let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()];
|
||||
let c_pattern = spmm_pattern(&a_pattern, &b_pattern);
|
||||
let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()];
|
||||
let a_pattern = Arc::new(a_pattern);
|
||||
let b_pattern = Arc::new(b_pattern);
|
||||
let c_pattern = Arc::new(c_pattern);
|
||||
let a = a_values.prop_map(move |values|
|
||||
CsrMatrix::try_from_pattern_and_values(Arc::clone(&a_pattern), values).unwrap());
|
||||
CsrMatrix::try_from_pattern_and_values(a_pattern.clone(), values).unwrap());
|
||||
let b = b_values.prop_map(move |values|
|
||||
CsrMatrix::try_from_pattern_and_values(Arc::clone(&b_pattern), values).unwrap());
|
||||
CsrMatrix::try_from_pattern_and_values(b_pattern.clone(), values).unwrap());
|
||||
let c = c_values.prop_map(move |values|
|
||||
CsrMatrix::try_from_pattern_and_values(Arc::clone(&c_pattern), values).unwrap());
|
||||
CsrMatrix::try_from_pattern_and_values(c_pattern.clone(), values).unwrap());
|
||||
let alpha = PROPTEST_I32_VALUE_STRATEGY;
|
||||
let beta = PROPTEST_I32_VALUE_STRATEGY;
|
||||
(c, beta, alpha, trans_strategy(), a, trans_strategy(), b)
|
||||
@ -383,16 +379,16 @@ proptest! {
|
||||
// corresponding to a and b, and convert them to dense matrices.
|
||||
// The sum of these dense matrices will then have non-zeros in exactly the same locations
|
||||
// as the result of "adding" the sparsity patterns
|
||||
let a_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(a.clone()), vec![1; a.nnz()])
|
||||
let a_csr = CsrMatrix::try_from_pattern_and_values(a.clone(), vec![1; a.nnz()])
|
||||
.unwrap();
|
||||
let a_dense = DMatrix::from(&a_csr);
|
||||
let b_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(b.clone()), vec![1; b.nnz()])
|
||||
let b_csr = CsrMatrix::try_from_pattern_and_values(b.clone(), vec![1; b.nnz()])
|
||||
.unwrap();
|
||||
let b_dense = DMatrix::from(&b_csr);
|
||||
let c_dense = a_dense + b_dense;
|
||||
let c_csr = CsrMatrix::from(&c_dense);
|
||||
|
||||
prop_assert_eq!(&pattern_result, c_csr.pattern().as_ref());
|
||||
prop_assert_eq!(&pattern_result, c_csr.pattern());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -492,16 +488,16 @@ proptest! {
|
||||
// corresponding to a and b, and convert them to dense matrices.
|
||||
// The product of these dense matrices will then have non-zeros in exactly the same locations
|
||||
// as the result of "multiplying" the sparsity patterns
|
||||
let a_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(a.clone()), vec![1; a.nnz()])
|
||||
let a_csr = CsrMatrix::try_from_pattern_and_values(a.clone(), vec![1; a.nnz()])
|
||||
.unwrap();
|
||||
let a_dense = DMatrix::from(&a_csr);
|
||||
let b_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(b.clone()), vec![1; b.nnz()])
|
||||
let b_csr = CsrMatrix::try_from_pattern_and_values(b.clone(), vec![1; b.nnz()])
|
||||
.unwrap();
|
||||
let b_dense = DMatrix::from(&b_csr);
|
||||
let c_dense = a_dense * b_dense;
|
||||
let c_csr = CsrMatrix::from(&c_dense);
|
||||
|
||||
prop_assert_eq!(&c_pattern, c_csr.pattern().as_ref());
|
||||
prop_assert_eq!(&c_pattern, c_csr.pattern());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
Loading…
Reference in New Issue
Block a user