forked from M-Labs/nalgebra
Implement CSR-CSR matrix multiplication
This commit is contained in:
parent
d9cfe5cb3e
commit
b25848838b
@ -1,7 +1,7 @@
|
||||
use crate::csr::CsrMatrix;
|
||||
|
||||
use std::ops::Add;
|
||||
use crate::ops::serial::{spadd_csr, spadd_build_pattern};
|
||||
use std::ops::{Add, Mul};
|
||||
use crate::ops::serial::{spadd_csr, spadd_build_pattern, spmm_pattern, spmm_csr};
|
||||
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
||||
use num_traits::{Zero, One};
|
||||
use std::sync::Arc;
|
||||
@ -65,4 +65,43 @@ where
|
||||
fn add(self, rhs: CsrMatrix<T>) -> Self::Output {
|
||||
self + &rhs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper macro for implementing matrix multiplication for different matrix types
|
||||
/// See below for usage.
|
||||
macro_rules! impl_matrix_mul {
|
||||
(<$($life:lifetime),*>($a_name:ident : $a:ty, $b_name:ident : $b:ty) -> $ret:ty $body:block)
|
||||
=>
|
||||
{
|
||||
impl<$($life,)* T> Mul<$b> for $a
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
{
|
||||
type Output = $ret;
|
||||
fn mul(self, rhs: $b) -> Self::Output {
|
||||
let $a_name = self;
|
||||
let $b_name = rhs;
|
||||
$body
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl_matrix_mul!(<'a>(a: &'a CsrMatrix<T>, b: &'a CsrMatrix<T>) -> CsrMatrix<T> {
|
||||
let pattern = spmm_pattern(a.pattern(), b.pattern());
|
||||
let values = vec![T::zero(); pattern.nnz()];
|
||||
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||
.unwrap();
|
||||
spmm_csr(&mut result,
|
||||
T::zero(),
|
||||
T::one(),
|
||||
Transpose(false),
|
||||
a,
|
||||
Transpose(false),
|
||||
b)
|
||||
.expect("Internal error: spmm failed (please debug).");
|
||||
result
|
||||
});
|
||||
impl_matrix_mul!(<'a>(a: &'a CsrMatrix<T>, b: CsrMatrix<T>) -> CsrMatrix<T> { a * &b});
|
||||
impl_matrix_mul!(<'a>(a: CsrMatrix<T>, b: &'a CsrMatrix<T>) -> CsrMatrix<T> { &a * b});
|
||||
impl_matrix_mul!(<>(a: CsrMatrix<T>, b: CsrMatrix<T>) -> CsrMatrix<T> { &a * &b});
|
@ -292,7 +292,7 @@ proptest! {
|
||||
(a, b)
|
||||
in csr_strategy()
|
||||
.prop_flat_map(|a| {
|
||||
let b = csr(-5 ..= 5, Just(a.nrows()), Just(a.ncols()), 40);
|
||||
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), 40);
|
||||
(Just(a), b)
|
||||
}))
|
||||
{
|
||||
@ -425,4 +425,42 @@ proptest! {
|
||||
prop_assert!(result.is_err(),
|
||||
"The SPMM kernel executed successfully despite mismatch dimensions");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_mul_csr(
|
||||
// a and b have dimensions compatible for multiplication
|
||||
(a, b)
|
||||
in csr_strategy()
|
||||
.prop_flat_map(|a| {
|
||||
let max_nnz = PROPTEST_MAX_NNZ;
|
||||
let cols = PROPTEST_MATRIX_DIM;
|
||||
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.ncols()), cols, max_nnz);
|
||||
(Just(a), b)
|
||||
}))
|
||||
{
|
||||
// We use the dense result as the ground truth for the arithmetic result
|
||||
let c_dense = DMatrix::from(&a) * DMatrix::from(&b);
|
||||
// However, it's not enough only to cover the dense result, we also need to verify the
|
||||
// sparsity pattern. We can determine the exact sparsity pattern by using
|
||||
// dense arithmetic with positive integer values and extracting positive entries.
|
||||
let c_dense_pattern = dense_csr_pattern(a.pattern()) * dense_csr_pattern(b.pattern());
|
||||
let c_pattern = CsrMatrix::from(&c_dense_pattern).pattern().clone();
|
||||
|
||||
// Check each combination of owned matrices and references
|
||||
let c_owned_owned = a.clone() * b.clone();
|
||||
prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense);
|
||||
prop_assert_eq!(c_owned_owned.pattern(), &c_pattern);
|
||||
|
||||
let c_owned_ref = a.clone() * &b;
|
||||
prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense);
|
||||
prop_assert_eq!(c_owned_ref.pattern(), &c_pattern);
|
||||
|
||||
let c_ref_owned = &a * b.clone();
|
||||
prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense);
|
||||
prop_assert_eq!(c_ref_owned.pattern(), &c_pattern);
|
||||
|
||||
let c_ref_ref = &a * &b;
|
||||
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
|
||||
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user