Extend CSC/CSR * Dense to work for combinations of ref and owned
This commit is contained in:
parent
74cd0283eb
commit
0bee9be6c7
|
@ -4,7 +4,7 @@ use crate::csc::CscMatrix;
|
||||||
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg};
|
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg};
|
||||||
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_csr_pattern, spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense, spmm_csc_pattern};
|
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_csr_pattern, spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense, spmm_csc_pattern};
|
||||||
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, MatrixMN, Dim,
|
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, MatrixMN, Dim,
|
||||||
DMatrixSlice, DMatrixSliceMut, DMatrix, Dynamic, DefaultAllocator, U1};
|
Dynamic, DefaultAllocator, U1};
|
||||||
use nalgebra::allocator::{Allocator};
|
use nalgebra::allocator::{Allocator};
|
||||||
use nalgebra::constraint::{DimEq, ShapeConstraint};
|
use nalgebra::constraint::{DimEq, ShapeConstraint};
|
||||||
use num_traits::{Zero, One};
|
use num_traits::{Zero, One};
|
||||||
|
@ -264,8 +264,34 @@ impl_div!(CsrMatrix);
|
||||||
impl_div!(CscMatrix);
|
impl_div!(CscMatrix);
|
||||||
|
|
||||||
macro_rules! impl_spmm_cs_dense {
|
macro_rules! impl_spmm_cs_dense {
|
||||||
($matrix_type:ident, $spmm_fn:ident) => {
|
($matrix_type_name:ident, $spmm_fn:ident) => {
|
||||||
impl<'a, T, R, C, S> Mul<&'a Matrix<T, R, C, S>> for &'a $matrix_type<T>
|
// Implement ref-ref
|
||||||
|
impl_spmm_cs_dense!(&'a $matrix_type_name<T>, &'a Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
|
||||||
|
let (_, ncols) = rhs.data.shape();
|
||||||
|
let nrows = Dynamic::new(lhs.nrows());
|
||||||
|
let mut result = MatrixMN::<T, Dynamic, C>::zeros_generic(nrows, ncols);
|
||||||
|
$spmm_fn(T::zero(), &mut result, T::one(), Op::NoOp(lhs), Op::NoOp(rhs));
|
||||||
|
result
|
||||||
|
});
|
||||||
|
|
||||||
|
// Implement the other combinations by deferring to ref-ref
|
||||||
|
impl_spmm_cs_dense!(&'a $matrix_type_name<T>, Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
|
||||||
|
lhs * &rhs
|
||||||
|
});
|
||||||
|
impl_spmm_cs_dense!($matrix_type_name<T>, &'a Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
|
||||||
|
&lhs * rhs
|
||||||
|
});
|
||||||
|
impl_spmm_cs_dense!($matrix_type_name<T>, Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
|
||||||
|
&lhs * &rhs
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
// Main body of the macro. The first pattern just forwards to this pattern but with
|
||||||
|
// different arguments
|
||||||
|
($sparse_matrix_type:ty, $dense_matrix_type:ty, $spmm_fn:ident,
|
||||||
|
|$lhs:ident, $rhs:ident| $body:tt) =>
|
||||||
|
{
|
||||||
|
impl<'a, T, R, C, S> Mul<$dense_matrix_type> for $sparse_matrix_type
|
||||||
where
|
where
|
||||||
T: Scalar + ClosedMul + ClosedAdd + ClosedSub + ClosedDiv + Neg + Zero + One,
|
T: Scalar + ClosedMul + ClosedAdd + ClosedSub + ClosedDiv + Neg + Zero + One,
|
||||||
R: Dim,
|
R: Dim,
|
||||||
|
@ -287,16 +313,10 @@ macro_rules! impl_spmm_cs_dense {
|
||||||
// we also get a vector (and not a matrix)
|
// we also get a vector (and not a matrix)
|
||||||
type Output = MatrixMN<T, Dynamic, C>;
|
type Output = MatrixMN<T, Dynamic, C>;
|
||||||
|
|
||||||
fn mul(self, rhs: &'a Matrix<T, R, C, S>) -> Self::Output {
|
fn mul(self, rhs: $dense_matrix_type) -> Self::Output {
|
||||||
// let rhs = rhs.into();
|
let $lhs = self;
|
||||||
let (_, ncols) = rhs.data.shape();
|
let $rhs = rhs;
|
||||||
let nrows = Dynamic::new(self.nrows());
|
$body
|
||||||
let mut result = MatrixMN::<T, Dynamic, C>::zeros_generic(nrows, ncols);
|
|
||||||
{
|
|
||||||
// let result: DMatrixSliceMut<_> = (&mut result).into();
|
|
||||||
$spmm_fn(T::zero(), &mut result, T::one(), Op::NoOp(self), Op::NoOp(rhs));
|
|
||||||
}
|
|
||||||
result
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1123,7 +1123,11 @@ proptest! {
|
||||||
(Just(a), b)
|
(Just(a), b)
|
||||||
}))
|
}))
|
||||||
{
|
{
|
||||||
prop_assert_eq!(&a * &b, &DMatrix::from(&a) * &b);
|
let expected = DMatrix::from(&a) * &b;
|
||||||
|
prop_assert_eq!(&a * &b, expected.clone());
|
||||||
|
prop_assert_eq!(&a * b.clone(), expected.clone());
|
||||||
|
prop_assert_eq!(a.clone() * &b, expected.clone());
|
||||||
|
prop_assert_eq!(a.clone() * b.clone(), expected.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -1137,7 +1141,11 @@ proptest! {
|
||||||
(Just(a), b)
|
(Just(a), b)
|
||||||
}))
|
}))
|
||||||
{
|
{
|
||||||
prop_assert_eq!(&a * &b, &DMatrix::from(&a) * &b);
|
let expected = DMatrix::from(&a) * &b;
|
||||||
|
prop_assert_eq!(&a * &b, expected.clone());
|
||||||
|
prop_assert_eq!(&a * b.clone(), expected.clone());
|
||||||
|
prop_assert_eq!(a.clone() * &b, expected.clone());
|
||||||
|
prop_assert_eq!(a.clone() * b.clone(), expected.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
Loading…
Reference in New Issue