Implement CSR/CSC * Dense std operations
This commit is contained in:
parent
b7a7f967b8
commit
885480a634
|
@ -2,12 +2,14 @@ use crate::csr::CsrMatrix;
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
|
|
||||||
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg};
|
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg};
|
||||||
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern,
|
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_pattern,
|
||||||
spmm_pattern, spmm_csr_prealloc, spmm_csc_prealloc};
|
spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense};
|
||||||
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar};
|
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, Dim,
|
||||||
|
DMatrixSlice, DMatrix, Dynamic};
|
||||||
use num_traits::{Zero, One};
|
use num_traits::{Zero, One};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use crate::ops::{Op};
|
use crate::ops::{Op};
|
||||||
|
use nalgebra::base::storage::Storage;
|
||||||
|
|
||||||
/// Helper macro for implementing binary operators for different matrix types
|
/// Helper macro for implementing binary operators for different matrix types
|
||||||
/// See below for usage.
|
/// See below for usage.
|
||||||
|
@ -276,3 +278,30 @@ macro_rules! impl_div {
|
||||||
|
|
||||||
impl_div!(CsrMatrix);
|
impl_div!(CsrMatrix);
|
||||||
impl_div!(CscMatrix);
|
impl_div!(CscMatrix);
|
||||||
|
|
||||||
|
macro_rules! impl_spmm_cs_dense {
|
||||||
|
($matrix_type:ident, $spmm_fn:ident) => {
|
||||||
|
impl<'a, T, R, C, S> Mul<&'a Matrix<T, R, C, S>> for &'a $matrix_type<T>
|
||||||
|
where
|
||||||
|
&'a Matrix<T, R, C, S>: Into<DMatrixSlice<'a, T>>,
|
||||||
|
T: Scalar + ClosedMul + ClosedAdd + ClosedSub + ClosedDiv + Neg + Zero + One,
|
||||||
|
R: Dim,
|
||||||
|
C: Dim,
|
||||||
|
S: Storage<T, R, C>,
|
||||||
|
{
|
||||||
|
type Output = DMatrix<T>;
|
||||||
|
|
||||||
|
fn mul(self, rhs: &'a Matrix<T, R, C, S>) -> Self::Output {
|
||||||
|
let rhs = rhs.into();
|
||||||
|
let (_, ncols) = rhs.data.shape();
|
||||||
|
let nrows = Dynamic::new(self.nrows());
|
||||||
|
let mut result = Matrix::zeros_generic(nrows, ncols);
|
||||||
|
$spmm_fn(T::zero(), &mut result, T::one(), Op::NoOp(self), Op::NoOp(rhs));
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_spmm_cs_dense!(CsrMatrix, spmm_csr_dense);
|
||||||
|
impl_spmm_cs_dense!(CscMatrix, spmm_csc_dense);
|
|
@ -1087,4 +1087,32 @@ proptest! {
|
||||||
prop_assert_eq!(&result_ref, &expected_result);
|
prop_assert_eq!(&result_ref, &expected_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_mul_dense(
|
||||||
|
// a and b have dimensions compatible for multiplication
|
||||||
|
(a, b)
|
||||||
|
in csr_strategy()
|
||||||
|
.prop_flat_map(|a| {
|
||||||
|
let cols = PROPTEST_MATRIX_DIM;
|
||||||
|
let b = matrix(PROPTEST_I32_VALUE_STRATEGY, a.ncols(), cols);
|
||||||
|
(Just(a), b)
|
||||||
|
}))
|
||||||
|
{
|
||||||
|
prop_assert_eq!(&a * &b, &DMatrix::from(&a) * &b);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_mul_dense(
|
||||||
|
// a and b have dimensions compatible for multiplication
|
||||||
|
(a, b)
|
||||||
|
in csc_strategy()
|
||||||
|
.prop_flat_map(|a| {
|
||||||
|
let cols = PROPTEST_MATRIX_DIM;
|
||||||
|
let b = matrix(PROPTEST_I32_VALUE_STRATEGY, a.ncols(), cols);
|
||||||
|
(Just(a), b)
|
||||||
|
}))
|
||||||
|
{
|
||||||
|
prop_assert_eq!(&a * &b, &DMatrix::from(&a) * &b);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue