forked from M-Labs/nalgebra
Preserve column dim type in CSR * Dense
This is necessary so that CSR * Vector == Vector (before it would also yield a DMatrix).
This commit is contained in:
parent
15c4382fa9
commit
1fa0de92ae
@ -3,8 +3,10 @@ use crate::csc::CscMatrix;
|
|||||||
|
|
||||||
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg};
|
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg};
|
||||||
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_csr_pattern, spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense, spmm_csc_pattern};
|
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_csr_pattern, spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense, spmm_csc_pattern};
|
||||||
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, Dim,
|
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, MatrixMN, Dim,
|
||||||
DMatrixSlice, DMatrix, Dynamic};
|
DMatrixSlice, DMatrixSliceMut, DMatrix, Dynamic, DefaultAllocator, U1};
|
||||||
|
use nalgebra::allocator::{Allocator};
|
||||||
|
use nalgebra::constraint::{DimEq, ShapeConstraint};
|
||||||
use num_traits::{Zero, One};
|
use num_traits::{Zero, One};
|
||||||
use crate::ops::{Op};
|
use crate::ops::{Op};
|
||||||
use nalgebra::base::storage::Storage;
|
use nalgebra::base::storage::Storage;
|
||||||
@ -265,20 +267,35 @@ macro_rules! impl_spmm_cs_dense {
|
|||||||
($matrix_type:ident, $spmm_fn:ident) => {
|
($matrix_type:ident, $spmm_fn:ident) => {
|
||||||
impl<'a, T, R, C, S> Mul<&'a Matrix<T, R, C, S>> for &'a $matrix_type<T>
|
impl<'a, T, R, C, S> Mul<&'a Matrix<T, R, C, S>> for &'a $matrix_type<T>
|
||||||
where
|
where
|
||||||
&'a Matrix<T, R, C, S>: Into<DMatrixSlice<'a, T>>,
|
|
||||||
T: Scalar + ClosedMul + ClosedAdd + ClosedSub + ClosedDiv + Neg + Zero + One,
|
T: Scalar + ClosedMul + ClosedAdd + ClosedSub + ClosedDiv + Neg + Zero + One,
|
||||||
R: Dim,
|
R: Dim,
|
||||||
C: Dim,
|
C: Dim,
|
||||||
S: Storage<T, R, C>,
|
S: Storage<T, R, C>,
|
||||||
|
DefaultAllocator: Allocator<T, Dynamic, C>,
|
||||||
|
// TODO: Is it possible to simplify these bounds?
|
||||||
|
ShapeConstraint:
|
||||||
|
// Bounds so that we can turn MatrixMN<T, Dynamic, C> into a DMatrixSliceMut
|
||||||
|
DimEq<U1, <<DefaultAllocator as Allocator<T, Dynamic, C>>::Buffer as Storage<T, Dynamic, C>>::RStride>
|
||||||
|
+ DimEq<C, Dynamic>
|
||||||
|
+ DimEq<Dynamic, <<DefaultAllocator as Allocator<T, Dynamic, C>>::Buffer as Storage<T, Dynamic, C>>::CStride>
|
||||||
|
// Bounds so that we can turn &Matrix<T, R, C, S> into a DMatrixSlice
|
||||||
|
+ DimEq<U1, S::RStride>
|
||||||
|
+ DimEq<R, Dynamic>
|
||||||
|
+ DimEq<Dynamic, S::CStride>
|
||||||
{
|
{
|
||||||
type Output = DMatrix<T>;
|
// We need the column dimension to be generic, so that if RHS is a vector, then
|
||||||
|
// we also get a vector (and not a matrix)
|
||||||
|
type Output = MatrixMN<T, Dynamic, C>;
|
||||||
|
|
||||||
fn mul(self, rhs: &'a Matrix<T, R, C, S>) -> Self::Output {
|
fn mul(self, rhs: &'a Matrix<T, R, C, S>) -> Self::Output {
|
||||||
let rhs = rhs.into();
|
// let rhs = rhs.into();
|
||||||
let (_, ncols) = rhs.data.shape();
|
let (_, ncols) = rhs.data.shape();
|
||||||
let nrows = Dynamic::new(self.nrows());
|
let nrows = Dynamic::new(self.nrows());
|
||||||
let mut result = Matrix::zeros_generic(nrows, ncols);
|
let mut result = MatrixMN::<T, Dynamic, C>::zeros_generic(nrows, ncols);
|
||||||
|
{
|
||||||
|
// let result: DMatrixSliceMut<_> = (&mut result).into();
|
||||||
$spmm_fn(T::zero(), &mut result, T::one(), Op::NoOp(self), Op::NoOp(rhs));
|
$spmm_fn(T::zero(), &mut result, T::one(), Op::NoOp(self), Op::NoOp(rhs));
|
||||||
|
}
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user