forked from M-Labs/nalgebra
Implement CSR-CSR matrix multiplication
This commit is contained in:
parent
d9cfe5cb3e
commit
b25848838b
@ -1,7 +1,7 @@
|
|||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
|
|
||||||
use std::ops::Add;
|
use std::ops::{Add, Mul};
|
||||||
use crate::ops::serial::{spadd_csr, spadd_build_pattern};
|
use crate::ops::serial::{spadd_csr, spadd_build_pattern, spmm_pattern, spmm_csr};
|
||||||
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
||||||
use num_traits::{Zero, One};
|
use num_traits::{Zero, One};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@ -66,3 +66,42 @@ where
|
|||||||
self + &rhs
|
self + &rhs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Helper macro for implementing matrix multiplication for different matrix types
|
||||||
|
/// See below for usage.
|
||||||
|
macro_rules! impl_matrix_mul {
|
||||||
|
(<$($life:lifetime),*>($a_name:ident : $a:ty, $b_name:ident : $b:ty) -> $ret:ty $body:block)
|
||||||
|
=>
|
||||||
|
{
|
||||||
|
impl<$($life,)* T> Mul<$b> for $a
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
type Output = $ret;
|
||||||
|
fn mul(self, rhs: $b) -> Self::Output {
|
||||||
|
let $a_name = self;
|
||||||
|
let $b_name = rhs;
|
||||||
|
$body
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_matrix_mul!(<'a>(a: &'a CsrMatrix<T>, b: &'a CsrMatrix<T>) -> CsrMatrix<T> {
|
||||||
|
let pattern = spmm_pattern(a.pattern(), b.pattern());
|
||||||
|
let values = vec![T::zero(); pattern.nnz()];
|
||||||
|
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||||
|
.unwrap();
|
||||||
|
spmm_csr(&mut result,
|
||||||
|
T::zero(),
|
||||||
|
T::one(),
|
||||||
|
Transpose(false),
|
||||||
|
a,
|
||||||
|
Transpose(false),
|
||||||
|
b)
|
||||||
|
.expect("Internal error: spmm failed (please debug).");
|
||||||
|
result
|
||||||
|
});
|
||||||
|
impl_matrix_mul!(<'a>(a: &'a CsrMatrix<T>, b: CsrMatrix<T>) -> CsrMatrix<T> { a * &b});
|
||||||
|
impl_matrix_mul!(<'a>(a: CsrMatrix<T>, b: &'a CsrMatrix<T>) -> CsrMatrix<T> { &a * b});
|
||||||
|
impl_matrix_mul!(<>(a: CsrMatrix<T>, b: CsrMatrix<T>) -> CsrMatrix<T> { &a * &b});
|
@ -292,7 +292,7 @@ proptest! {
|
|||||||
(a, b)
|
(a, b)
|
||||||
in csr_strategy()
|
in csr_strategy()
|
||||||
.prop_flat_map(|a| {
|
.prop_flat_map(|a| {
|
||||||
let b = csr(-5 ..= 5, Just(a.nrows()), Just(a.ncols()), 40);
|
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), 40);
|
||||||
(Just(a), b)
|
(Just(a), b)
|
||||||
}))
|
}))
|
||||||
{
|
{
|
||||||
@ -425,4 +425,42 @@ proptest! {
|
|||||||
prop_assert!(result.is_err(),
|
prop_assert!(result.is_err(),
|
||||||
"The SPMM kernel executed successfully despite mismatch dimensions");
|
"The SPMM kernel executed successfully despite mismatch dimensions");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_mul_csr(
|
||||||
|
// a and b have dimensions compatible for multiplication
|
||||||
|
(a, b)
|
||||||
|
in csr_strategy()
|
||||||
|
.prop_flat_map(|a| {
|
||||||
|
let max_nnz = PROPTEST_MAX_NNZ;
|
||||||
|
let cols = PROPTEST_MATRIX_DIM;
|
||||||
|
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.ncols()), cols, max_nnz);
|
||||||
|
(Just(a), b)
|
||||||
|
}))
|
||||||
|
{
|
||||||
|
// We use the dense result as the ground truth for the arithmetic result
|
||||||
|
let c_dense = DMatrix::from(&a) * DMatrix::from(&b);
|
||||||
|
// However, it's not enough only to cover the dense result, we also need to verify the
|
||||||
|
// sparsity pattern. We can determine the exact sparsity pattern by using
|
||||||
|
// dense arithmetic with positive integer values and extracting positive entries.
|
||||||
|
let c_dense_pattern = dense_csr_pattern(a.pattern()) * dense_csr_pattern(b.pattern());
|
||||||
|
let c_pattern = CsrMatrix::from(&c_dense_pattern).pattern().clone();
|
||||||
|
|
||||||
|
// Check each combination of owned matrices and references
|
||||||
|
let c_owned_owned = a.clone() * b.clone();
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense);
|
||||||
|
prop_assert_eq!(c_owned_owned.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_owned_ref = a.clone() * &b;
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense);
|
||||||
|
prop_assert_eq!(c_owned_ref.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_ref_owned = &a * b.clone();
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense);
|
||||||
|
prop_assert_eq!(c_ref_owned.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_ref_ref = &a * &b;
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
|
||||||
|
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user