Implement Sub for Csr/CscMatrix
This commit is contained in:
parent
7aeb663165
commit
0b4356eb0e
@ -1,10 +1,10 @@
|
||||
use crate::csr::CsrMatrix;
|
||||
use crate::csc::CscMatrix;
|
||||
|
||||
use std::ops::{Add, Mul, MulAssign};
|
||||
use std::ops::{Add, Mul, MulAssign, Sub, Neg};
|
||||
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern,
|
||||
spmm_pattern, spmm_csr_prealloc, spmm_csc_prealloc};
|
||||
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
||||
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, Scalar};
|
||||
use num_traits::{Zero, One};
|
||||
use std::sync::Arc;
|
||||
use crate::ops::{Op};
|
||||
@ -18,7 +18,10 @@ macro_rules! impl_bin_op {
|
||||
{
|
||||
impl<$($life,)* $($scalar_type)?> $trait<$b_type> for $a_type
|
||||
where
|
||||
$($scalar_type: Scalar + ClosedAdd + ClosedMul + Zero + One)?
|
||||
// Note: The Signed bound is currently required because we delegate e.g.
|
||||
// Sub to SpAdd with negative coefficients. This is not well-defined for
|
||||
// unsigned data types.
|
||||
$($scalar_type: Scalar + ClosedAdd + ClosedSub + ClosedMul + Zero + One + Neg<Output=T>)?
|
||||
{
|
||||
type Output = $ret;
|
||||
fn $method(self, $b: $b_type) -> Self::Output {
|
||||
@ -29,17 +32,19 @@ macro_rules! impl_bin_op {
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_add {
|
||||
($($args:tt)*) => {
|
||||
impl_bin_op!(Add, add, $($args)*);
|
||||
}
|
||||
}
|
||||
|
||||
/// Implements a + b for all combinations of reference and owned matrices, for
|
||||
/// Implements a +/- b for all combinations of reference and owned matrices, for
|
||||
/// CsrMatrix or CscMatrix.
|
||||
macro_rules! impl_spadd {
|
||||
($matrix_type:ident, $spadd_fn:ident) => {
|
||||
impl_add!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||
macro_rules! impl_sp_plus_minus {
|
||||
// We first match on some special-case syntax, and forward to the actual implementation
|
||||
($matrix_type:ident, $spadd_fn:ident, +) => {
|
||||
impl_sp_plus_minus!(Add, add, $matrix_type, $spadd_fn, +, T::one());
|
||||
};
|
||||
($matrix_type:ident, $spadd_fn:ident, -) => {
|
||||
impl_sp_plus_minus!(Sub, sub, $matrix_type, $spadd_fn, -, -T::one());
|
||||
};
|
||||
($trait:ident, $method:ident, $matrix_type:ident, $spadd_fn:ident, $sign:tt, $factor:expr) => {
|
||||
impl_bin_op!($trait, $method,
|
||||
<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||
// If both matrices have the same pattern, then we can immediately re-use it
|
||||
let pattern = if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
||||
Arc::clone(a.pattern())
|
||||
@ -51,31 +56,41 @@ macro_rules! impl_spadd {
|
||||
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
|
||||
.unwrap();
|
||||
$spadd_fn(T::zero(), &mut result, T::one(), Op::NoOp(&a)).unwrap();
|
||||
$spadd_fn(T::one(), &mut result, T::one(), Op::NoOp(&b)).unwrap();
|
||||
$spadd_fn(T::one(), &mut result, $factor * T::one(), Op::NoOp(&b)).unwrap();
|
||||
result
|
||||
});
|
||||
|
||||
impl_add!(<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||
impl_bin_op!($trait, $method,
|
||||
<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||
let mut a = a;
|
||||
if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
||||
$spadd_fn(T::one(), &mut a, T::one(), Op::NoOp(b)).unwrap();
|
||||
$spadd_fn(T::one(), &mut a, $factor * T::one(), Op::NoOp(b)).unwrap();
|
||||
a
|
||||
} else {
|
||||
&a + b
|
||||
&a $sign b
|
||||
}
|
||||
});
|
||||
|
||||
impl_add!(<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||
b + a
|
||||
impl_bin_op!($trait, $method,
|
||||
<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||
let mut b = b;
|
||||
if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
||||
$spadd_fn($factor * T::one(), &mut b, T::one(), Op::NoOp(a)).unwrap();
|
||||
b
|
||||
} else {
|
||||
a $sign &b
|
||||
}
|
||||
});
|
||||
impl_add!(<T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||
a + &b
|
||||
impl_bin_op!($trait, $method, <T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||
a $sign &b
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl_spadd!(CsrMatrix, spadd_csr_prealloc);
|
||||
impl_spadd!(CscMatrix, spadd_csc_prealloc);
|
||||
impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, +);
|
||||
impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, -);
|
||||
impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, +);
|
||||
impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, -);
|
||||
|
||||
macro_rules! impl_mul {
|
||||
($($args:tt)*) => {
|
||||
@ -154,7 +169,7 @@ macro_rules! impl_scalar_mul {
|
||||
});
|
||||
impl_concrete_scalar_matrix_mul!(
|
||||
$matrix_type,
|
||||
i8, i16, i32, i64, u8, u16, u32, u64, isize, usize, f32, f64);
|
||||
i8, i16, i32, i64, isize, f32, f64);
|
||||
|
||||
impl<T> MulAssign<T> for $matrix_type<T>
|
||||
where
|
||||
@ -181,4 +196,6 @@ macro_rules! impl_scalar_mul {
|
||||
}
|
||||
|
||||
impl_scalar_mul!(CsrMatrix);
|
||||
impl_scalar_mul!(CscMatrix);
|
||||
impl_scalar_mul!(CscMatrix);
|
||||
|
||||
// TODO: Neg, Div
|
@ -417,6 +417,39 @@ proptest! {
|
||||
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_sub_csr(
|
||||
// a and b have the same dimensions
|
||||
(a, b)
|
||||
in csr_strategy()
|
||||
.prop_flat_map(|a| {
|
||||
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
|
||||
(Just(a), b)
|
||||
}))
|
||||
{
|
||||
// See comments in csr_add_csr for rationale for checking the pattern this way
|
||||
let c_dense = DMatrix::from(&a) - DMatrix::from(&b);
|
||||
let c_dense_pattern = dense_csr_pattern(a.pattern()) + dense_csr_pattern(b.pattern());
|
||||
let c_pattern = CsrMatrix::from(&c_dense_pattern).pattern().clone();
|
||||
|
||||
// Check each combination of owned matrices and references
|
||||
let c_owned_owned = a.clone() - b.clone();
|
||||
prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense);
|
||||
prop_assert_eq!(c_owned_owned.pattern(), &c_pattern);
|
||||
|
||||
let c_owned_ref = a.clone() - &b;
|
||||
prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense);
|
||||
prop_assert_eq!(c_owned_ref.pattern(), &c_pattern);
|
||||
|
||||
let c_ref_owned = &a - b.clone();
|
||||
prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense);
|
||||
prop_assert_eq!(c_ref_owned.pattern(), &c_pattern);
|
||||
|
||||
let c_ref_ref = &a - &b;
|
||||
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
|
||||
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spmm_pattern_test((a, b) in spmm_pattern_strategy())
|
||||
{
|
||||
@ -837,6 +870,39 @@ proptest! {
|
||||
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_sub_csc(
|
||||
// a and b have the same dimensions
|
||||
(a, b)
|
||||
in csc_strategy()
|
||||
.prop_flat_map(|a| {
|
||||
let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
|
||||
(Just(a), b)
|
||||
}))
|
||||
{
|
||||
// See comments in csc_add_csc for rationale for checking the pattern this way
|
||||
let c_dense = DMatrix::from(&a) - DMatrix::from(&b);
|
||||
let c_dense_pattern = dense_csc_pattern(a.pattern()) + dense_csc_pattern(b.pattern());
|
||||
let c_pattern = CscMatrix::from(&c_dense_pattern).pattern().clone();
|
||||
|
||||
// Check each combination of owned matrices and references
|
||||
let c_owned_owned = a.clone() - b.clone();
|
||||
prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense);
|
||||
prop_assert_eq!(c_owned_owned.pattern(), &c_pattern);
|
||||
|
||||
let c_owned_ref = a.clone() - &b;
|
||||
prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense);
|
||||
prop_assert_eq!(c_owned_ref.pattern(), &c_pattern);
|
||||
|
||||
let c_ref_owned = &a - b.clone();
|
||||
prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense);
|
||||
prop_assert_eq!(c_ref_owned.pattern(), &c_pattern);
|
||||
|
||||
let c_ref_ref = &a - &b;
|
||||
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
|
||||
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_mul_scalar((scalar, matrix) in (PROPTEST_I32_VALUE_STRATEGY, csr_strategy())) {
|
||||
let dense = DMatrix::from(&matrix);
|
||||
|
Loading…
Reference in New Issue
Block a user