diff --git a/nalgebra-sparse/src/ops/impl_std_ops.rs b/nalgebra-sparse/src/ops/impl_std_ops.rs index b9ee23a4..12db5264 100644 --- a/nalgebra-sparse/src/ops/impl_std_ops.rs +++ b/nalgebra-sparse/src/ops/impl_std_ops.rs @@ -2,12 +2,14 @@ use crate::csr::CsrMatrix; use crate::csc::CscMatrix; use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg}; -use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, - spmm_pattern, spmm_csr_prealloc, spmm_csc_prealloc}; -use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar}; +use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_pattern, + spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense}; +use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, Dim, + DMatrixSlice, DMatrix, Dynamic}; use num_traits::{Zero, One}; use std::sync::Arc; use crate::ops::{Op}; +use nalgebra::base::storage::Storage; /// Helper macro for implementing binary operators for different matrix types /// See below for usage. @@ -275,4 +277,31 @@ macro_rules! impl_div { } impl_div!(CsrMatrix); -impl_div!(CscMatrix); \ No newline at end of file +impl_div!(CscMatrix); + +macro_rules! impl_spmm_cs_dense { + ($matrix_type:ident, $spmm_fn:ident) => { + impl<'a, T, R, C, S> Mul<&'a Matrix> for &'a $matrix_type + where + &'a Matrix: Into>, + T: Scalar + ClosedMul + ClosedAdd + ClosedSub + ClosedDiv + Neg + Zero + One, + R: Dim, + C: Dim, + S: Storage, + { + type Output = DMatrix; + + fn mul(self, rhs: &'a Matrix) -> Self::Output { + let rhs = rhs.into(); + let (_, ncols) = rhs.data.shape(); + let nrows = Dynamic::new(self.nrows()); + let mut result = Matrix::zeros_generic(nrows, ncols); + $spmm_fn(T::zero(), &mut result, T::one(), Op::NoOp(self), Op::NoOp(rhs)); + result + } + } + } +} + +impl_spmm_cs_dense!(CsrMatrix, spmm_csr_dense); +impl_spmm_cs_dense!(CscMatrix, spmm_csc_dense); \ No newline at end of file diff --git a/nalgebra-sparse/tests/unit_tests/ops.rs b/nalgebra-sparse/tests/unit_tests/ops.rs index 001945a9..a8dc248b 100644 --- a/nalgebra-sparse/tests/unit_tests/ops.rs +++ b/nalgebra-sparse/tests/unit_tests/ops.rs @@ -1087,4 +1087,32 @@ proptest! { prop_assert_eq!(&result_ref, &expected_result); } + #[test] + fn csr_mul_dense( + // a and b have dimensions compatible for multiplication + (a, b) + in csr_strategy() + .prop_flat_map(|a| { + let cols = PROPTEST_MATRIX_DIM; + let b = matrix(PROPTEST_I32_VALUE_STRATEGY, a.ncols(), cols); + (Just(a), b) + })) + { + prop_assert_eq!(&a * &b, &DMatrix::from(&a) * &b); + } + + #[test] + fn csc_mul_dense( + // a and b have dimensions compatible for multiplication + (a, b) + in csc_strategy() + .prop_flat_map(|a| { + let cols = PROPTEST_MATRIX_DIM; + let b = matrix(PROPTEST_I32_VALUE_STRATEGY, a.ncols(), cols); + (Just(a), b) + })) + { + prop_assert_eq!(&a * &b, &DMatrix::from(&a) * &b); + } + } \ No newline at end of file