diff --git a/nalgebra-sparse/src/ops/impl_std_ops.rs b/nalgebra-sparse/src/ops/impl_std_ops.rs index b8841bf2..9309ff93 100644 --- a/nalgebra-sparse/src/ops/impl_std_ops.rs +++ b/nalgebra-sparse/src/ops/impl_std_ops.rs @@ -3,8 +3,10 @@ use crate::csc::CscMatrix; use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg}; use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_csr_pattern, spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense, spmm_csc_pattern}; -use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, Dim, - DMatrixSlice, DMatrix, Dynamic}; +use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, MatrixMN, Dim, + DMatrixSlice, DMatrixSliceMut, DMatrix, Dynamic, DefaultAllocator, U1}; +use nalgebra::allocator::{Allocator}; +use nalgebra::constraint::{DimEq, ShapeConstraint}; use num_traits::{Zero, One}; use crate::ops::{Op}; use nalgebra::base::storage::Storage; @@ -265,20 +267,35 @@ macro_rules! impl_spmm_cs_dense { ($matrix_type:ident, $spmm_fn:ident) => { impl<'a, T, R, C, S> Mul<&'a Matrix> for &'a $matrix_type where - &'a Matrix: Into>, T: Scalar + ClosedMul + ClosedAdd + ClosedSub + ClosedDiv + Neg + Zero + One, R: Dim, C: Dim, S: Storage, + DefaultAllocator: Allocator, + // TODO: Is it possible to simplify these bounds? + ShapeConstraint: + // Bounds so that we can turn MatrixMN into a DMatrixSliceMut + DimEq>::Buffer as Storage>::RStride> + + DimEq + + DimEq>::Buffer as Storage>::CStride> + // Bounds so that we can turn &Matrix into a DMatrixSlice + + DimEq + + DimEq + + DimEq { - type Output = DMatrix; + // We need the column dimension to be generic, so that if RHS is a vector, then + // we also get a vector (and not a matrix) + type Output = MatrixMN; fn mul(self, rhs: &'a Matrix) -> Self::Output { - let rhs = rhs.into(); + // let rhs = rhs.into(); let (_, ncols) = rhs.data.shape(); let nrows = Dynamic::new(self.nrows()); - let mut result = Matrix::zeros_generic(nrows, ncols); - $spmm_fn(T::zero(), &mut result, T::one(), Op::NoOp(self), Op::NoOp(rhs)); + let mut result = MatrixMN::::zeros_generic(nrows, ncols); + { + // let result: DMatrixSliceMut<_> = (&mut result).into(); + $spmm_fn(T::zero(), &mut result, T::one(), Op::NoOp(self), Op::NoOp(rhs)); + } result } }