Add prealloc suffix to spmm_csr and spadd_csr
The suffix is intended to communicate that these methods assume `preallocated` storage, i.e. they try to store the result in a matrix which already has the correct sparsity pattern for the operation.
This commit is contained in:
parent
4af3fcbdd3
commit
66cbd26702
|
@ -1,7 +1,7 @@
|
||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
|
|
||||||
use std::ops::{Add, Mul};
|
use std::ops::{Add, Mul};
|
||||||
use crate::ops::serial::{spadd_csr, spadd_pattern, spmm_pattern, spmm_csr};
|
use crate::ops::serial::{spadd_csr_prealloc, spadd_pattern, spmm_pattern, spmm_csr_prealloc};
|
||||||
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
||||||
use num_traits::{Zero, One};
|
use num_traits::{Zero, One};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
@ -21,8 +21,8 @@ where
|
||||||
// We are giving data that is valid by definition, so it is safe to unwrap below
|
// We are giving data that is valid by definition, so it is safe to unwrap below
|
||||||
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
spadd_csr(T::zero(), &mut result, T::one(), Op::NoOp(&self)).unwrap();
|
spadd_csr_prealloc(T::zero(), &mut result, T::one(), Op::NoOp(&self)).unwrap();
|
||||||
spadd_csr(T::one(), &mut result, T::one(), Op::NoOp(&rhs)).unwrap();
|
spadd_csr_prealloc(T::one(), &mut result, T::one(), Op::NoOp(&rhs)).unwrap();
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -35,7 +35,7 @@ where
|
||||||
|
|
||||||
fn add(mut self, rhs: &'a CsrMatrix<T>) -> Self::Output {
|
fn add(mut self, rhs: &'a CsrMatrix<T>) -> Self::Output {
|
||||||
if Arc::ptr_eq(self.pattern(), rhs.pattern()) {
|
if Arc::ptr_eq(self.pattern(), rhs.pattern()) {
|
||||||
spadd_csr(T::one(), &mut self, T::one(), Op::NoOp(rhs)).unwrap();
|
spadd_csr_prealloc(T::one(), &mut self, T::one(), Op::NoOp(rhs)).unwrap();
|
||||||
self
|
self
|
||||||
} else {
|
} else {
|
||||||
&self + rhs
|
&self + rhs
|
||||||
|
@ -90,7 +90,7 @@ impl_matrix_mul!(<'a>(a: &'a CsrMatrix<T>, b: &'a CsrMatrix<T>) -> CsrMatrix<T>
|
||||||
let values = vec![T::zero(); pattern.nnz()];
|
let values = vec![T::zero(); pattern.nnz()];
|
||||||
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
spmm_csr(T::zero(),
|
spmm_csr_prealloc(T::zero(),
|
||||||
&mut result,
|
&mut result,
|
||||||
T::one(),
|
T::one(),
|
||||||
Op::NoOp(a),
|
Op::NoOp(a),
|
||||||
|
|
|
@ -87,11 +87,11 @@ fn spadd_csr_unexpected_entry() -> OperationError {
|
||||||
///
|
///
|
||||||
/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is
|
/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is
|
||||||
/// returned.
|
/// returned.
|
||||||
pub fn spadd_csr<T>(beta: T,
|
pub fn spadd_csr_prealloc<T>(beta: T,
|
||||||
c: &mut CsrMatrix<T>,
|
c: &mut CsrMatrix<T>,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
a: Op<&CsrMatrix<T>>)
|
a: Op<&CsrMatrix<T>>)
|
||||||
-> Result<(), OperationError>
|
-> Result<(), OperationError>
|
||||||
where
|
where
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
{
|
{
|
||||||
|
@ -161,7 +161,7 @@ fn spmm_csr_unexpected_entry() -> OperationError {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
|
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
|
||||||
pub fn spmm_csr<T>(
|
pub fn spmm_csr_prealloc<T>(
|
||||||
beta: T,
|
beta: T,
|
||||||
c: &mut CsrMatrix<T>,
|
c: &mut CsrMatrix<T>,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
|
@ -218,7 +218,7 @@ where
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
spmm_csr(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref()))
|
spmm_csr_prealloc(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
|
use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
|
||||||
PROPTEST_I32_VALUE_STRATEGY};
|
PROPTEST_I32_VALUE_STRATEGY};
|
||||||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_pattern, spmm_pattern, spadd_csr, spmm_csr};
|
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_pattern, spmm_pattern, spadd_csr_prealloc, spmm_csr_prealloc};
|
||||||
use nalgebra_sparse::ops::{Op};
|
use nalgebra_sparse::ops::{Op};
|
||||||
use nalgebra_sparse::csr::CsrMatrix;
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
|
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
|
||||||
|
@ -78,7 +78,7 @@ struct SpaddCsrArgs<T> {
|
||||||
a: Op<CsrMatrix<T>>,
|
a: Op<CsrMatrix<T>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn spadd_csr_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
||||||
let value_strategy = PROPTEST_I32_VALUE_STRATEGY;
|
let value_strategy = PROPTEST_I32_VALUE_STRATEGY;
|
||||||
|
|
||||||
spadd_pattern_strategy()
|
spadd_pattern_strategy()
|
||||||
|
@ -150,7 +150,7 @@ struct SpmmCsrArgs<T> {
|
||||||
b: Op<CsrMatrix<T>>,
|
b: Op<CsrMatrix<T>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn spmm_csr_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
||||||
spmm_pattern_strategy()
|
spmm_pattern_strategy()
|
||||||
.prop_flat_map(|(a_pattern, b_pattern)| {
|
.prop_flat_map(|(a_pattern, b_pattern)| {
|
||||||
let a_values = vec![PROPTEST_I32_VALUE_STRATEGY; a_pattern.nnz()];
|
let a_values = vec![PROPTEST_I32_VALUE_STRATEGY; a_pattern.nnz()];
|
||||||
|
@ -287,12 +287,12 @@ proptest! {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spadd_csr_test(SpaddCsrArgs { c, beta, alpha, a } in spadd_csr_args_strategy()) {
|
fn spadd_csr_prealloc_test(SpaddCsrArgs { c, beta, alpha, a } in spadd_csr_prealloc_args_strategy()) {
|
||||||
// Test that we get the expected result by comparing to an equivalent dense operation
|
// Test that we get the expected result by comparing to an equivalent dense operation
|
||||||
// (here we give in the C matrix, so the sparsity pattern is essentially fixed)
|
// (here we give in the C matrix, so the sparsity pattern is essentially fixed)
|
||||||
|
|
||||||
let mut c_sparse = c.clone();
|
let mut c_sparse = c.clone();
|
||||||
spadd_csr(beta, &mut c_sparse, alpha, a.as_ref()).unwrap();
|
spadd_csr_prealloc(beta, &mut c_sparse, alpha, a.as_ref()).unwrap();
|
||||||
|
|
||||||
let mut c_dense = DMatrix::from(&c);
|
let mut c_dense = DMatrix::from(&c);
|
||||||
let op_a_dense = match a {
|
let op_a_dense = match a {
|
||||||
|
@ -363,13 +363,13 @@ proptest! {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spmm_csr_test(SpmmCsrArgs { c, beta, alpha, a, b }
|
fn spmm_csr_prealloc_test(SpmmCsrArgs { c, beta, alpha, a, b }
|
||||||
in spmm_csr_args_strategy()
|
in spmm_csr_prealloc_args_strategy()
|
||||||
) {
|
) {
|
||||||
// Test that we get the expected result by comparing to an equivalent dense operation
|
// Test that we get the expected result by comparing to an equivalent dense operation
|
||||||
// (here we give in the C matrix, so the sparsity pattern is essentially fixed)
|
// (here we give in the C matrix, so the sparsity pattern is essentially fixed)
|
||||||
let mut c_sparse = c.clone();
|
let mut c_sparse = c.clone();
|
||||||
spmm_csr(beta, &mut c_sparse, alpha, a.as_ref(), b.as_ref()).unwrap();
|
spmm_csr_prealloc(beta, &mut c_sparse, alpha, a.as_ref(), b.as_ref()).unwrap();
|
||||||
|
|
||||||
let mut c_dense = DMatrix::from(&c);
|
let mut c_dense = DMatrix::from(&c);
|
||||||
let op_a_dense = match a {
|
let op_a_dense = match a {
|
||||||
|
@ -386,7 +386,7 @@ proptest! {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spmm_csr_panics_on_dim_mismatch(
|
fn spmm_csr_prealloc_panics_on_dim_mismatch(
|
||||||
(alpha, beta, c, a, b)
|
(alpha, beta, c, a, b)
|
||||||
in (PROPTEST_I32_VALUE_STRATEGY,
|
in (PROPTEST_I32_VALUE_STRATEGY,
|
||||||
PROPTEST_I32_VALUE_STRATEGY,
|
PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
@ -424,7 +424,7 @@ proptest! {
|
||||||
|
|
||||||
let result = catch_unwind(|| {
|
let result = catch_unwind(|| {
|
||||||
let mut spmm_result = c.clone();
|
let mut spmm_result = c.clone();
|
||||||
spmm_csr(beta, &mut spmm_result, alpha, a.as_ref(), b.as_ref()).unwrap();
|
spmm_csr_prealloc(beta, &mut spmm_result, alpha, a.as_ref(), b.as_ref()).unwrap();
|
||||||
});
|
});
|
||||||
|
|
||||||
prop_assert!(result.is_err(),
|
prop_assert!(result.is_err(),
|
||||||
|
@ -432,7 +432,7 @@ proptest! {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spadd_csr_panics_on_dim_mismatch(
|
fn spadd_csr_prealloc_panics_on_dim_mismatch(
|
||||||
(alpha, beta, c, op_a)
|
(alpha, beta, c, op_a)
|
||||||
in (PROPTEST_I32_VALUE_STRATEGY,
|
in (PROPTEST_I32_VALUE_STRATEGY,
|
||||||
PROPTEST_I32_VALUE_STRATEGY,
|
PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
@ -456,7 +456,7 @@ proptest! {
|
||||||
|
|
||||||
let result = catch_unwind(|| {
|
let result = catch_unwind(|| {
|
||||||
let mut spmm_result = c.clone();
|
let mut spmm_result = c.clone();
|
||||||
spadd_csr(beta, &mut spmm_result, alpha, op_a.as_ref()).unwrap();
|
spadd_csr_prealloc(beta, &mut spmm_result, alpha, op_a.as_ref()).unwrap();
|
||||||
});
|
});
|
||||||
|
|
||||||
prop_assert!(result.is_err(),
|
prop_assert!(result.is_err(),
|
||||||
|
|
Loading…
Reference in New Issue