Implement Sub for Csr/CscMatrix

This commit is contained in:
Andreas Longva 2021-01-05 14:59:54 +01:00
parent 7aeb663165
commit 0b4356eb0e
2 changed files with 108 additions and 25 deletions

View File

@ -1,10 +1,10 @@
use crate::csr::CsrMatrix; use crate::csr::CsrMatrix;
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use std::ops::{Add, Mul, MulAssign}; use std::ops::{Add, Mul, MulAssign, Sub, Neg};
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern,
spmm_pattern, spmm_csr_prealloc, spmm_csc_prealloc}; spmm_pattern, spmm_csr_prealloc, spmm_csc_prealloc};
use nalgebra::{ClosedAdd, ClosedMul, Scalar}; use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, Scalar};
use num_traits::{Zero, One}; use num_traits::{Zero, One};
use std::sync::Arc; use std::sync::Arc;
use crate::ops::{Op}; use crate::ops::{Op};
@ -18,7 +18,10 @@ macro_rules! impl_bin_op {
{ {
impl<$($life,)* $($scalar_type)?> $trait<$b_type> for $a_type impl<$($life,)* $($scalar_type)?> $trait<$b_type> for $a_type
where where
$($scalar_type: Scalar + ClosedAdd + ClosedMul + Zero + One)? // Note: The Signed bound is currently required because we delegate e.g.
// Sub to SpAdd with negative coefficients. This is not well-defined for
// unsigned data types.
$($scalar_type: Scalar + ClosedAdd + ClosedSub + ClosedMul + Zero + One + Neg<Output=T>)?
{ {
type Output = $ret; type Output = $ret;
fn $method(self, $b: $b_type) -> Self::Output { fn $method(self, $b: $b_type) -> Self::Output {
@ -29,17 +32,19 @@ macro_rules! impl_bin_op {
} }
} }
macro_rules! impl_add { /// Implements a +/- b for all combinations of reference and owned matrices, for
($($args:tt)*) => {
impl_bin_op!(Add, add, $($args)*);
}
}
/// Implements a + b for all combinations of reference and owned matrices, for
/// CsrMatrix or CscMatrix. /// CsrMatrix or CscMatrix.
macro_rules! impl_spadd { macro_rules! impl_sp_plus_minus {
($matrix_type:ident, $spadd_fn:ident) => { // We first match on some special-case syntax, and forward to the actual implementation
impl_add!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> { ($matrix_type:ident, $spadd_fn:ident, +) => {
impl_sp_plus_minus!(Add, add, $matrix_type, $spadd_fn, +, T::one());
};
($matrix_type:ident, $spadd_fn:ident, -) => {
impl_sp_plus_minus!(Sub, sub, $matrix_type, $spadd_fn, -, -T::one());
};
($trait:ident, $method:ident, $matrix_type:ident, $spadd_fn:ident, $sign:tt, $factor:expr) => {
impl_bin_op!($trait, $method,
<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
// If both matrices have the same pattern, then we can immediately re-use it // If both matrices have the same pattern, then we can immediately re-use it
let pattern = if Arc::ptr_eq(a.pattern(), b.pattern()) { let pattern = if Arc::ptr_eq(a.pattern(), b.pattern()) {
Arc::clone(a.pattern()) Arc::clone(a.pattern())
@ -51,31 +56,41 @@ macro_rules! impl_spadd {
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values) let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
.unwrap(); .unwrap();
$spadd_fn(T::zero(), &mut result, T::one(), Op::NoOp(&a)).unwrap(); $spadd_fn(T::zero(), &mut result, T::one(), Op::NoOp(&a)).unwrap();
$spadd_fn(T::one(), &mut result, T::one(), Op::NoOp(&b)).unwrap(); $spadd_fn(T::one(), &mut result, $factor * T::one(), Op::NoOp(&b)).unwrap();
result result
}); });
impl_add!(<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> { impl_bin_op!($trait, $method,
<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
let mut a = a; let mut a = a;
if Arc::ptr_eq(a.pattern(), b.pattern()) { if Arc::ptr_eq(a.pattern(), b.pattern()) {
$spadd_fn(T::one(), &mut a, T::one(), Op::NoOp(b)).unwrap(); $spadd_fn(T::one(), &mut a, $factor * T::one(), Op::NoOp(b)).unwrap();
a a
} else { } else {
&a + b &a $sign b
} }
}); });
impl_add!(<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { impl_bin_op!($trait, $method,
b + a <'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
let mut b = b;
if Arc::ptr_eq(a.pattern(), b.pattern()) {
$spadd_fn($factor * T::one(), &mut b, T::one(), Op::NoOp(a)).unwrap();
b
} else {
a $sign &b
}
}); });
impl_add!(<T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { impl_bin_op!($trait, $method, <T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
a + &b a $sign &b
}); });
} }
} }
impl_spadd!(CsrMatrix, spadd_csr_prealloc); impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, +);
impl_spadd!(CscMatrix, spadd_csc_prealloc); impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, -);
impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, +);
impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, -);
macro_rules! impl_mul { macro_rules! impl_mul {
($($args:tt)*) => { ($($args:tt)*) => {
@ -154,7 +169,7 @@ macro_rules! impl_scalar_mul {
}); });
impl_concrete_scalar_matrix_mul!( impl_concrete_scalar_matrix_mul!(
$matrix_type, $matrix_type,
i8, i16, i32, i64, u8, u16, u32, u64, isize, usize, f32, f64); i8, i16, i32, i64, isize, f32, f64);
impl<T> MulAssign<T> for $matrix_type<T> impl<T> MulAssign<T> for $matrix_type<T>
where where
@ -182,3 +197,5 @@ macro_rules! impl_scalar_mul {
impl_scalar_mul!(CsrMatrix); impl_scalar_mul!(CsrMatrix);
impl_scalar_mul!(CscMatrix); impl_scalar_mul!(CscMatrix);
// TODO: Neg, Div

View File

@ -417,6 +417,39 @@ proptest! {
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern); prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
} }
#[test]
fn csr_sub_csr(
// a and b have the same dimensions
(a, b)
in csr_strategy()
.prop_flat_map(|a| {
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
(Just(a), b)
}))
{
// See comments in csr_add_csr for rationale for checking the pattern this way
let c_dense = DMatrix::from(&a) - DMatrix::from(&b);
let c_dense_pattern = dense_csr_pattern(a.pattern()) + dense_csr_pattern(b.pattern());
let c_pattern = CsrMatrix::from(&c_dense_pattern).pattern().clone();
// Check each combination of owned matrices and references
let c_owned_owned = a.clone() - b.clone();
prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense);
prop_assert_eq!(c_owned_owned.pattern(), &c_pattern);
let c_owned_ref = a.clone() - &b;
prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense);
prop_assert_eq!(c_owned_ref.pattern(), &c_pattern);
let c_ref_owned = &a - b.clone();
prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense);
prop_assert_eq!(c_ref_owned.pattern(), &c_pattern);
let c_ref_ref = &a - &b;
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
}
#[test] #[test]
fn spmm_pattern_test((a, b) in spmm_pattern_strategy()) fn spmm_pattern_test((a, b) in spmm_pattern_strategy())
{ {
@ -837,6 +870,39 @@ proptest! {
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern); prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
} }
#[test]
fn csc_sub_csc(
// a and b have the same dimensions
(a, b)
in csc_strategy()
.prop_flat_map(|a| {
let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
(Just(a), b)
}))
{
// See comments in csc_add_csc for rationale for checking the pattern this way
let c_dense = DMatrix::from(&a) - DMatrix::from(&b);
let c_dense_pattern = dense_csc_pattern(a.pattern()) + dense_csc_pattern(b.pattern());
let c_pattern = CscMatrix::from(&c_dense_pattern).pattern().clone();
// Check each combination of owned matrices and references
let c_owned_owned = a.clone() - b.clone();
prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense);
prop_assert_eq!(c_owned_owned.pattern(), &c_pattern);
let c_owned_ref = a.clone() - &b;
prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense);
prop_assert_eq!(c_owned_ref.pattern(), &c_pattern);
let c_ref_owned = &a - b.clone();
prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense);
prop_assert_eq!(c_ref_owned.pattern(), &c_pattern);
let c_ref_ref = &a - &b;
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
}
#[test] #[test]
fn csr_mul_scalar((scalar, matrix) in (PROPTEST_I32_VALUE_STRATEGY, csr_strategy())) { fn csr_mul_scalar((scalar, matrix) in (PROPTEST_I32_VALUE_STRATEGY, csr_strategy())) {
let dense = DMatrix::from(&matrix); let dense = DMatrix::from(&matrix);