diff --git a/nalgebra-sparse/src/ops/impl_std_ops.rs b/nalgebra-sparse/src/ops/impl_std_ops.rs index c8e9e800..17f357a6 100644 --- a/nalgebra-sparse/src/ops/impl_std_ops.rs +++ b/nalgebra-sparse/src/ops/impl_std_ops.rs @@ -1,7 +1,7 @@ use crate::csr::CsrMatrix; -use std::ops::Add; -use crate::ops::serial::{spadd_csr, spadd_build_pattern}; +use std::ops::{Add, Mul}; +use crate::ops::serial::{spadd_csr, spadd_build_pattern, spmm_pattern, spmm_csr}; use nalgebra::{ClosedAdd, ClosedMul, Scalar}; use num_traits::{Zero, One}; use std::sync::Arc; @@ -65,4 +65,43 @@ where fn add(self, rhs: CsrMatrix) -> Self::Output { self + &rhs } -} \ No newline at end of file +} + +/// Helper macro for implementing matrix multiplication for different matrix types +/// See below for usage. +macro_rules! impl_matrix_mul { + (<$($life:lifetime),*>($a_name:ident : $a:ty, $b_name:ident : $b:ty) -> $ret:ty $body:block) + => + { + impl<$($life,)* T> Mul<$b> for $a + where + T: Scalar + ClosedAdd + ClosedMul + Zero + One + { + type Output = $ret; + fn mul(self, rhs: $b) -> Self::Output { + let $a_name = self; + let $b_name = rhs; + $body + } + } + } +} + +impl_matrix_mul!(<'a>(a: &'a CsrMatrix, b: &'a CsrMatrix) -> CsrMatrix { + let pattern = spmm_pattern(a.pattern(), b.pattern()); + let values = vec![T::zero(); pattern.nnz()]; + let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values) + .unwrap(); + spmm_csr(&mut result, + T::zero(), + T::one(), + Transpose(false), + a, + Transpose(false), + b) + .expect("Internal error: spmm failed (please debug)."); + result +}); +impl_matrix_mul!(<'a>(a: &'a CsrMatrix, b: CsrMatrix) -> CsrMatrix { a * &b}); +impl_matrix_mul!(<'a>(a: CsrMatrix, b: &'a CsrMatrix) -> CsrMatrix { &a * b}); +impl_matrix_mul!(<>(a: CsrMatrix, b: CsrMatrix) -> CsrMatrix { &a * &b}); \ No newline at end of file diff --git a/nalgebra-sparse/tests/unit_tests/ops.rs b/nalgebra-sparse/tests/unit_tests/ops.rs index c1df496f..d5f2bba8 100644 --- a/nalgebra-sparse/tests/unit_tests/ops.rs +++ b/nalgebra-sparse/tests/unit_tests/ops.rs @@ -292,7 +292,7 @@ proptest! { (a, b) in csr_strategy() .prop_flat_map(|a| { - let b = csr(-5 ..= 5, Just(a.nrows()), Just(a.ncols()), 40); + let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), 40); (Just(a), b) })) { @@ -425,4 +425,42 @@ proptest! { prop_assert!(result.is_err(), "The SPMM kernel executed successfully despite mismatch dimensions"); } + + #[test] + fn csr_mul_csr( + // a and b have dimensions compatible for multiplication + (a, b) + in csr_strategy() + .prop_flat_map(|a| { + let max_nnz = PROPTEST_MAX_NNZ; + let cols = PROPTEST_MATRIX_DIM; + let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.ncols()), cols, max_nnz); + (Just(a), b) + })) + { + // We use the dense result as the ground truth for the arithmetic result + let c_dense = DMatrix::from(&a) * DMatrix::from(&b); + // However, it's not enough only to cover the dense result, we also need to verify the + // sparsity pattern. We can determine the exact sparsity pattern by using + // dense arithmetic with positive integer values and extracting positive entries. + let c_dense_pattern = dense_csr_pattern(a.pattern()) * dense_csr_pattern(b.pattern()); + let c_pattern = CsrMatrix::from(&c_dense_pattern).pattern().clone(); + + // Check each combination of owned matrices and references + let c_owned_owned = a.clone() * b.clone(); + prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense); + prop_assert_eq!(c_owned_owned.pattern(), &c_pattern); + + let c_owned_ref = a.clone() * &b; + prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense); + prop_assert_eq!(c_owned_ref.pattern(), &c_pattern); + + let c_ref_owned = &a * b.clone(); + prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense); + prop_assert_eq!(c_ref_owned.pattern(), &c_pattern); + + let c_ref_ref = &a * &b; + prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense); + prop_assert_eq!(c_ref_ref.pattern(), &c_pattern); + } } \ No newline at end of file