diff --git a/nalgebra-sparse/src/ops/impl_std_ops.rs b/nalgebra-sparse/src/ops/impl_std_ops.rs index 9309ff93..d62519e9 100644 --- a/nalgebra-sparse/src/ops/impl_std_ops.rs +++ b/nalgebra-sparse/src/ops/impl_std_ops.rs @@ -4,7 +4,7 @@ use crate::csc::CscMatrix; use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg}; use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_csr_pattern, spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense, spmm_csc_pattern}; use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, MatrixMN, Dim, - DMatrixSlice, DMatrixSliceMut, DMatrix, Dynamic, DefaultAllocator, U1}; + Dynamic, DefaultAllocator, U1}; use nalgebra::allocator::{Allocator}; use nalgebra::constraint::{DimEq, ShapeConstraint}; use num_traits::{Zero, One}; @@ -264,8 +264,34 @@ impl_div!(CsrMatrix); impl_div!(CscMatrix); macro_rules! impl_spmm_cs_dense { - ($matrix_type:ident, $spmm_fn:ident) => { - impl<'a, T, R, C, S> Mul<&'a Matrix> for &'a $matrix_type + ($matrix_type_name:ident, $spmm_fn:ident) => { + // Implement ref-ref + impl_spmm_cs_dense!(&'a $matrix_type_name, &'a Matrix, $spmm_fn, |lhs, rhs| { + let (_, ncols) = rhs.data.shape(); + let nrows = Dynamic::new(lhs.nrows()); + let mut result = MatrixMN::::zeros_generic(nrows, ncols); + $spmm_fn(T::zero(), &mut result, T::one(), Op::NoOp(lhs), Op::NoOp(rhs)); + result + }); + + // Implement the other combinations by deferring to ref-ref + impl_spmm_cs_dense!(&'a $matrix_type_name, Matrix, $spmm_fn, |lhs, rhs| { + lhs * &rhs + }); + impl_spmm_cs_dense!($matrix_type_name, &'a Matrix, $spmm_fn, |lhs, rhs| { + &lhs * rhs + }); + impl_spmm_cs_dense!($matrix_type_name, Matrix, $spmm_fn, |lhs, rhs| { + &lhs * &rhs + }); + }; + + // Main body of the macro. The first pattern just forwards to this pattern but with + // different arguments + ($sparse_matrix_type:ty, $dense_matrix_type:ty, $spmm_fn:ident, + |$lhs:ident, $rhs:ident| $body:tt) => + { + impl<'a, T, R, C, S> Mul<$dense_matrix_type> for $sparse_matrix_type where T: Scalar + ClosedMul + ClosedAdd + ClosedSub + ClosedDiv + Neg + Zero + One, R: Dim, @@ -287,16 +313,10 @@ macro_rules! impl_spmm_cs_dense { // we also get a vector (and not a matrix) type Output = MatrixMN; - fn mul(self, rhs: &'a Matrix) -> Self::Output { - // let rhs = rhs.into(); - let (_, ncols) = rhs.data.shape(); - let nrows = Dynamic::new(self.nrows()); - let mut result = MatrixMN::::zeros_generic(nrows, ncols); - { - // let result: DMatrixSliceMut<_> = (&mut result).into(); - $spmm_fn(T::zero(), &mut result, T::one(), Op::NoOp(self), Op::NoOp(rhs)); - } - result + fn mul(self, rhs: $dense_matrix_type) -> Self::Output { + let $lhs = self; + let $rhs = rhs; + $body } } } diff --git a/nalgebra-sparse/tests/unit_tests/ops.rs b/nalgebra-sparse/tests/unit_tests/ops.rs index 0bc6b42a..4e789eb3 100644 --- a/nalgebra-sparse/tests/unit_tests/ops.rs +++ b/nalgebra-sparse/tests/unit_tests/ops.rs @@ -1123,7 +1123,11 @@ proptest! { (Just(a), b) })) { - prop_assert_eq!(&a * &b, &DMatrix::from(&a) * &b); + let expected = DMatrix::from(&a) * &b; + prop_assert_eq!(&a * &b, expected.clone()); + prop_assert_eq!(&a * b.clone(), expected.clone()); + prop_assert_eq!(a.clone() * &b, expected.clone()); + prop_assert_eq!(a.clone() * b.clone(), expected.clone()); } #[test] @@ -1137,7 +1141,11 @@ proptest! { (Just(a), b) })) { - prop_assert_eq!(&a * &b, &DMatrix::from(&a) * &b); + let expected = DMatrix::from(&a) * &b; + prop_assert_eq!(&a * &b, expected.clone()); + prop_assert_eq!(&a * b.clone(), expected.clone()); + prop_assert_eq!(a.clone() * &b, expected.clone()); + prop_assert_eq!(a.clone() * b.clone(), expected.clone()); } #[test]