diff --git a/nalgebra-sparse/src/ops/impl_std_ops.rs b/nalgebra-sparse/src/ops/impl_std_ops.rs index f20dbd52..51b2dd8b 100644 --- a/nalgebra-sparse/src/ops/impl_std_ops.rs +++ b/nalgebra-sparse/src/ops/impl_std_ops.rs @@ -1,10 +1,10 @@ use crate::csr::CsrMatrix; use crate::csc::CscMatrix; -use std::ops::{Add, Mul, MulAssign}; +use std::ops::{Add, Mul, MulAssign, Sub, Neg}; use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_pattern, spmm_csr_prealloc, spmm_csc_prealloc}; -use nalgebra::{ClosedAdd, ClosedMul, Scalar}; +use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, Scalar}; use num_traits::{Zero, One}; use std::sync::Arc; use crate::ops::{Op}; @@ -18,7 +18,10 @@ macro_rules! impl_bin_op { { impl<$($life,)* $($scalar_type)?> $trait<$b_type> for $a_type where - $($scalar_type: Scalar + ClosedAdd + ClosedMul + Zero + One)? + // Note: The Signed bound is currently required because we delegate e.g. + // Sub to SpAdd with negative coefficients. This is not well-defined for + // unsigned data types. + $($scalar_type: Scalar + ClosedAdd + ClosedSub + ClosedMul + Zero + One + Neg)? { type Output = $ret; fn $method(self, $b: $b_type) -> Self::Output { @@ -29,17 +32,19 @@ macro_rules! impl_bin_op { } } -macro_rules! impl_add { - ($($args:tt)*) => { - impl_bin_op!(Add, add, $($args)*); - } -} - -/// Implements a + b for all combinations of reference and owned matrices, for +/// Implements a +/- b for all combinations of reference and owned matrices, for /// CsrMatrix or CscMatrix. -macro_rules! impl_spadd { - ($matrix_type:ident, $spadd_fn:ident) => { - impl_add!(<'a, T>(a: &'a $matrix_type, b: &'a $matrix_type) -> $matrix_type { +macro_rules! impl_sp_plus_minus { + // We first match on some special-case syntax, and forward to the actual implementation + ($matrix_type:ident, $spadd_fn:ident, +) => { + impl_sp_plus_minus!(Add, add, $matrix_type, $spadd_fn, +, T::one()); + }; + ($matrix_type:ident, $spadd_fn:ident, -) => { + impl_sp_plus_minus!(Sub, sub, $matrix_type, $spadd_fn, -, -T::one()); + }; + ($trait:ident, $method:ident, $matrix_type:ident, $spadd_fn:ident, $sign:tt, $factor:expr) => { + impl_bin_op!($trait, $method, + <'a, T>(a: &'a $matrix_type, b: &'a $matrix_type) -> $matrix_type { // If both matrices have the same pattern, then we can immediately re-use it let pattern = if Arc::ptr_eq(a.pattern(), b.pattern()) { Arc::clone(a.pattern()) @@ -51,31 +56,41 @@ macro_rules! impl_spadd { let mut result = $matrix_type::try_from_pattern_and_values(pattern, values) .unwrap(); $spadd_fn(T::zero(), &mut result, T::one(), Op::NoOp(&a)).unwrap(); - $spadd_fn(T::one(), &mut result, T::one(), Op::NoOp(&b)).unwrap(); + $spadd_fn(T::one(), &mut result, $factor * T::one(), Op::NoOp(&b)).unwrap(); result }); - impl_add!(<'a, T>(a: $matrix_type, b: &'a $matrix_type) -> $matrix_type { + impl_bin_op!($trait, $method, + <'a, T>(a: $matrix_type, b: &'a $matrix_type) -> $matrix_type { let mut a = a; if Arc::ptr_eq(a.pattern(), b.pattern()) { - $spadd_fn(T::one(), &mut a, T::one(), Op::NoOp(b)).unwrap(); + $spadd_fn(T::one(), &mut a, $factor * T::one(), Op::NoOp(b)).unwrap(); a } else { - &a + b + &a $sign b } }); - impl_add!(<'a, T>(a: &'a $matrix_type, b: $matrix_type) -> $matrix_type { - b + a + impl_bin_op!($trait, $method, + <'a, T>(a: &'a $matrix_type, b: $matrix_type) -> $matrix_type { + let mut b = b; + if Arc::ptr_eq(a.pattern(), b.pattern()) { + $spadd_fn($factor * T::one(), &mut b, T::one(), Op::NoOp(a)).unwrap(); + b + } else { + a $sign &b + } }); - impl_add!((a: $matrix_type, b: $matrix_type) -> $matrix_type { - a + &b + impl_bin_op!($trait, $method, (a: $matrix_type, b: $matrix_type) -> $matrix_type { + a $sign &b }); } } -impl_spadd!(CsrMatrix, spadd_csr_prealloc); -impl_spadd!(CscMatrix, spadd_csc_prealloc); +impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, +); +impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, -); +impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, +); +impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, -); macro_rules! impl_mul { ($($args:tt)*) => { @@ -154,7 +169,7 @@ macro_rules! impl_scalar_mul { }); impl_concrete_scalar_matrix_mul!( $matrix_type, - i8, i16, i32, i64, u8, u16, u32, u64, isize, usize, f32, f64); + i8, i16, i32, i64, isize, f32, f64); impl MulAssign for $matrix_type where @@ -181,4 +196,6 @@ macro_rules! impl_scalar_mul { } impl_scalar_mul!(CsrMatrix); -impl_scalar_mul!(CscMatrix); \ No newline at end of file +impl_scalar_mul!(CscMatrix); + +// TODO: Neg, Div \ No newline at end of file diff --git a/nalgebra-sparse/tests/unit_tests/ops.rs b/nalgebra-sparse/tests/unit_tests/ops.rs index 3cbee92f..76817f61 100644 --- a/nalgebra-sparse/tests/unit_tests/ops.rs +++ b/nalgebra-sparse/tests/unit_tests/ops.rs @@ -417,6 +417,39 @@ proptest! { prop_assert_eq!(c_ref_ref.pattern(), &c_pattern); } + #[test] + fn csr_sub_csr( + // a and b have the same dimensions + (a, b) + in csr_strategy() + .prop_flat_map(|a| { + let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ); + (Just(a), b) + })) + { + // See comments in csr_add_csr for rationale for checking the pattern this way + let c_dense = DMatrix::from(&a) - DMatrix::from(&b); + let c_dense_pattern = dense_csr_pattern(a.pattern()) + dense_csr_pattern(b.pattern()); + let c_pattern = CsrMatrix::from(&c_dense_pattern).pattern().clone(); + + // Check each combination of owned matrices and references + let c_owned_owned = a.clone() - b.clone(); + prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense); + prop_assert_eq!(c_owned_owned.pattern(), &c_pattern); + + let c_owned_ref = a.clone() - &b; + prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense); + prop_assert_eq!(c_owned_ref.pattern(), &c_pattern); + + let c_ref_owned = &a - b.clone(); + prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense); + prop_assert_eq!(c_ref_owned.pattern(), &c_pattern); + + let c_ref_ref = &a - &b; + prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense); + prop_assert_eq!(c_ref_ref.pattern(), &c_pattern); + } + #[test] fn spmm_pattern_test((a, b) in spmm_pattern_strategy()) { @@ -837,6 +870,39 @@ proptest! { prop_assert_eq!(c_ref_ref.pattern(), &c_pattern); } + #[test] + fn csc_sub_csc( + // a and b have the same dimensions + (a, b) + in csc_strategy() + .prop_flat_map(|a| { + let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ); + (Just(a), b) + })) + { + // See comments in csc_add_csc for rationale for checking the pattern this way + let c_dense = DMatrix::from(&a) - DMatrix::from(&b); + let c_dense_pattern = dense_csc_pattern(a.pattern()) + dense_csc_pattern(b.pattern()); + let c_pattern = CscMatrix::from(&c_dense_pattern).pattern().clone(); + + // Check each combination of owned matrices and references + let c_owned_owned = a.clone() - b.clone(); + prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense); + prop_assert_eq!(c_owned_owned.pattern(), &c_pattern); + + let c_owned_ref = a.clone() - &b; + prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense); + prop_assert_eq!(c_owned_ref.pattern(), &c_pattern); + + let c_ref_owned = &a - b.clone(); + prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense); + prop_assert_eq!(c_ref_owned.pattern(), &c_pattern); + + let c_ref_ref = &a - &b; + prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense); + prop_assert_eq!(c_ref_ref.pattern(), &c_pattern); + } + #[test] fn csr_mul_scalar((scalar, matrix) in (PROPTEST_I32_VALUE_STRATEGY, csr_strategy())) { let dense = DMatrix::from(&matrix);