Implement CSR-CSR matrix multiplication

This commit is contained in:
Andreas Longva 2020-12-16 16:17:42 +01:00
parent d9cfe5cb3e
commit b25848838b
2 changed files with 81 additions and 4 deletions

View File

@ -1,7 +1,7 @@
use crate::csr::CsrMatrix; use crate::csr::CsrMatrix;
use std::ops::Add; use std::ops::{Add, Mul};
use crate::ops::serial::{spadd_csr, spadd_build_pattern}; use crate::ops::serial::{spadd_csr, spadd_build_pattern, spmm_pattern, spmm_csr};
use nalgebra::{ClosedAdd, ClosedMul, Scalar}; use nalgebra::{ClosedAdd, ClosedMul, Scalar};
use num_traits::{Zero, One}; use num_traits::{Zero, One};
use std::sync::Arc; use std::sync::Arc;
@ -65,4 +65,43 @@ where
fn add(self, rhs: CsrMatrix<T>) -> Self::Output { fn add(self, rhs: CsrMatrix<T>) -> Self::Output {
self + &rhs self + &rhs
} }
} }
/// Helper macro for implementing matrix multiplication for different matrix types
/// See below for usage.
macro_rules! impl_matrix_mul {
(<$($life:lifetime),*>($a_name:ident : $a:ty, $b_name:ident : $b:ty) -> $ret:ty $body:block)
=>
{
impl<$($life,)* T> Mul<$b> for $a
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
type Output = $ret;
fn mul(self, rhs: $b) -> Self::Output {
let $a_name = self;
let $b_name = rhs;
$body
}
}
}
}
impl_matrix_mul!(<'a>(a: &'a CsrMatrix<T>, b: &'a CsrMatrix<T>) -> CsrMatrix<T> {
let pattern = spmm_pattern(a.pattern(), b.pattern());
let values = vec![T::zero(); pattern.nnz()];
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
.unwrap();
spmm_csr(&mut result,
T::zero(),
T::one(),
Transpose(false),
a,
Transpose(false),
b)
.expect("Internal error: spmm failed (please debug).");
result
});
impl_matrix_mul!(<'a>(a: &'a CsrMatrix<T>, b: CsrMatrix<T>) -> CsrMatrix<T> { a * &b});
impl_matrix_mul!(<'a>(a: CsrMatrix<T>, b: &'a CsrMatrix<T>) -> CsrMatrix<T> { &a * b});
impl_matrix_mul!(<>(a: CsrMatrix<T>, b: CsrMatrix<T>) -> CsrMatrix<T> { &a * &b});

View File

@ -292,7 +292,7 @@ proptest! {
(a, b) (a, b)
in csr_strategy() in csr_strategy()
.prop_flat_map(|a| { .prop_flat_map(|a| {
let b = csr(-5 ..= 5, Just(a.nrows()), Just(a.ncols()), 40); let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), 40);
(Just(a), b) (Just(a), b)
})) }))
{ {
@ -425,4 +425,42 @@ proptest! {
prop_assert!(result.is_err(), prop_assert!(result.is_err(),
"The SPMM kernel executed successfully despite mismatch dimensions"); "The SPMM kernel executed successfully despite mismatch dimensions");
} }
#[test]
fn csr_mul_csr(
// a and b have dimensions compatible for multiplication
(a, b)
in csr_strategy()
.prop_flat_map(|a| {
let max_nnz = PROPTEST_MAX_NNZ;
let cols = PROPTEST_MATRIX_DIM;
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.ncols()), cols, max_nnz);
(Just(a), b)
}))
{
// We use the dense result as the ground truth for the arithmetic result
let c_dense = DMatrix::from(&a) * DMatrix::from(&b);
// However, it's not enough only to cover the dense result, we also need to verify the
// sparsity pattern. We can determine the exact sparsity pattern by using
// dense arithmetic with positive integer values and extracting positive entries.
let c_dense_pattern = dense_csr_pattern(a.pattern()) * dense_csr_pattern(b.pattern());
let c_pattern = CsrMatrix::from(&c_dense_pattern).pattern().clone();
// Check each combination of owned matrices and references
let c_owned_owned = a.clone() * b.clone();
prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense);
prop_assert_eq!(c_owned_owned.pattern(), &c_pattern);
let c_owned_ref = a.clone() * &b;
prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense);
prop_assert_eq!(c_owned_ref.pattern(), &c_pattern);
let c_ref_owned = &a * b.clone();
prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense);
prop_assert_eq!(c_ref_owned.pattern(), &c_pattern);
let c_ref_ref = &a * &b;
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
}
} }