Implement Sub for Csr/CscMatrix
This commit is contained in:
parent
7aeb663165
commit
0b4356eb0e
|
@ -1,10 +1,10 @@
|
||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
|
|
||||||
use std::ops::{Add, Mul, MulAssign};
|
use std::ops::{Add, Mul, MulAssign, Sub, Neg};
|
||||||
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern,
|
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern,
|
||||||
spmm_pattern, spmm_csr_prealloc, spmm_csc_prealloc};
|
spmm_pattern, spmm_csr_prealloc, spmm_csc_prealloc};
|
||||||
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, Scalar};
|
||||||
use num_traits::{Zero, One};
|
use num_traits::{Zero, One};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use crate::ops::{Op};
|
use crate::ops::{Op};
|
||||||
|
@ -18,7 +18,10 @@ macro_rules! impl_bin_op {
|
||||||
{
|
{
|
||||||
impl<$($life,)* $($scalar_type)?> $trait<$b_type> for $a_type
|
impl<$($life,)* $($scalar_type)?> $trait<$b_type> for $a_type
|
||||||
where
|
where
|
||||||
$($scalar_type: Scalar + ClosedAdd + ClosedMul + Zero + One)?
|
// Note: The Signed bound is currently required because we delegate e.g.
|
||||||
|
// Sub to SpAdd with negative coefficients. This is not well-defined for
|
||||||
|
// unsigned data types.
|
||||||
|
$($scalar_type: Scalar + ClosedAdd + ClosedSub + ClosedMul + Zero + One + Neg<Output=T>)?
|
||||||
{
|
{
|
||||||
type Output = $ret;
|
type Output = $ret;
|
||||||
fn $method(self, $b: $b_type) -> Self::Output {
|
fn $method(self, $b: $b_type) -> Self::Output {
|
||||||
|
@ -29,17 +32,19 @@ macro_rules! impl_bin_op {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! impl_add {
|
/// Implements a +/- b for all combinations of reference and owned matrices, for
|
||||||
($($args:tt)*) => {
|
|
||||||
impl_bin_op!(Add, add, $($args)*);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Implements a + b for all combinations of reference and owned matrices, for
|
|
||||||
/// CsrMatrix or CscMatrix.
|
/// CsrMatrix or CscMatrix.
|
||||||
macro_rules! impl_spadd {
|
macro_rules! impl_sp_plus_minus {
|
||||||
($matrix_type:ident, $spadd_fn:ident) => {
|
// We first match on some special-case syntax, and forward to the actual implementation
|
||||||
impl_add!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
($matrix_type:ident, $spadd_fn:ident, +) => {
|
||||||
|
impl_sp_plus_minus!(Add, add, $matrix_type, $spadd_fn, +, T::one());
|
||||||
|
};
|
||||||
|
($matrix_type:ident, $spadd_fn:ident, -) => {
|
||||||
|
impl_sp_plus_minus!(Sub, sub, $matrix_type, $spadd_fn, -, -T::one());
|
||||||
|
};
|
||||||
|
($trait:ident, $method:ident, $matrix_type:ident, $spadd_fn:ident, $sign:tt, $factor:expr) => {
|
||||||
|
impl_bin_op!($trait, $method,
|
||||||
|
<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||||
// If both matrices have the same pattern, then we can immediately re-use it
|
// If both matrices have the same pattern, then we can immediately re-use it
|
||||||
let pattern = if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
let pattern = if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
||||||
Arc::clone(a.pattern())
|
Arc::clone(a.pattern())
|
||||||
|
@ -51,31 +56,41 @@ macro_rules! impl_spadd {
|
||||||
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
|
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
$spadd_fn(T::zero(), &mut result, T::one(), Op::NoOp(&a)).unwrap();
|
$spadd_fn(T::zero(), &mut result, T::one(), Op::NoOp(&a)).unwrap();
|
||||||
$spadd_fn(T::one(), &mut result, T::one(), Op::NoOp(&b)).unwrap();
|
$spadd_fn(T::one(), &mut result, $factor * T::one(), Op::NoOp(&b)).unwrap();
|
||||||
result
|
result
|
||||||
});
|
});
|
||||||
|
|
||||||
impl_add!(<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
impl_bin_op!($trait, $method,
|
||||||
|
<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||||
let mut a = a;
|
let mut a = a;
|
||||||
if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
||||||
$spadd_fn(T::one(), &mut a, T::one(), Op::NoOp(b)).unwrap();
|
$spadd_fn(T::one(), &mut a, $factor * T::one(), Op::NoOp(b)).unwrap();
|
||||||
a
|
a
|
||||||
} else {
|
} else {
|
||||||
&a + b
|
&a $sign b
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
impl_add!(<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
impl_bin_op!($trait, $method,
|
||||||
b + a
|
<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||||
|
let mut b = b;
|
||||||
|
if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
||||||
|
$spadd_fn($factor * T::one(), &mut b, T::one(), Op::NoOp(a)).unwrap();
|
||||||
|
b
|
||||||
|
} else {
|
||||||
|
a $sign &b
|
||||||
|
}
|
||||||
});
|
});
|
||||||
impl_add!(<T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
impl_bin_op!($trait, $method, <T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||||
a + &b
|
a $sign &b
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_spadd!(CsrMatrix, spadd_csr_prealloc);
|
impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, +);
|
||||||
impl_spadd!(CscMatrix, spadd_csc_prealloc);
|
impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, -);
|
||||||
|
impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, +);
|
||||||
|
impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, -);
|
||||||
|
|
||||||
macro_rules! impl_mul {
|
macro_rules! impl_mul {
|
||||||
($($args:tt)*) => {
|
($($args:tt)*) => {
|
||||||
|
@ -154,7 +169,7 @@ macro_rules! impl_scalar_mul {
|
||||||
});
|
});
|
||||||
impl_concrete_scalar_matrix_mul!(
|
impl_concrete_scalar_matrix_mul!(
|
||||||
$matrix_type,
|
$matrix_type,
|
||||||
i8, i16, i32, i64, u8, u16, u32, u64, isize, usize, f32, f64);
|
i8, i16, i32, i64, isize, f32, f64);
|
||||||
|
|
||||||
impl<T> MulAssign<T> for $matrix_type<T>
|
impl<T> MulAssign<T> for $matrix_type<T>
|
||||||
where
|
where
|
||||||
|
@ -182,3 +197,5 @@ macro_rules! impl_scalar_mul {
|
||||||
|
|
||||||
impl_scalar_mul!(CsrMatrix);
|
impl_scalar_mul!(CsrMatrix);
|
||||||
impl_scalar_mul!(CscMatrix);
|
impl_scalar_mul!(CscMatrix);
|
||||||
|
|
||||||
|
// TODO: Neg, Div
|
|
@ -417,6 +417,39 @@ proptest! {
|
||||||
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_sub_csr(
|
||||||
|
// a and b have the same dimensions
|
||||||
|
(a, b)
|
||||||
|
in csr_strategy()
|
||||||
|
.prop_flat_map(|a| {
|
||||||
|
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
|
||||||
|
(Just(a), b)
|
||||||
|
}))
|
||||||
|
{
|
||||||
|
// See comments in csr_add_csr for rationale for checking the pattern this way
|
||||||
|
let c_dense = DMatrix::from(&a) - DMatrix::from(&b);
|
||||||
|
let c_dense_pattern = dense_csr_pattern(a.pattern()) + dense_csr_pattern(b.pattern());
|
||||||
|
let c_pattern = CsrMatrix::from(&c_dense_pattern).pattern().clone();
|
||||||
|
|
||||||
|
// Check each combination of owned matrices and references
|
||||||
|
let c_owned_owned = a.clone() - b.clone();
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense);
|
||||||
|
prop_assert_eq!(c_owned_owned.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_owned_ref = a.clone() - &b;
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense);
|
||||||
|
prop_assert_eq!(c_owned_ref.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_ref_owned = &a - b.clone();
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense);
|
||||||
|
prop_assert_eq!(c_ref_owned.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_ref_ref = &a - &b;
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
|
||||||
|
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spmm_pattern_test((a, b) in spmm_pattern_strategy())
|
fn spmm_pattern_test((a, b) in spmm_pattern_strategy())
|
||||||
{
|
{
|
||||||
|
@ -837,6 +870,39 @@ proptest! {
|
||||||
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_sub_csc(
|
||||||
|
// a and b have the same dimensions
|
||||||
|
(a, b)
|
||||||
|
in csc_strategy()
|
||||||
|
.prop_flat_map(|a| {
|
||||||
|
let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
|
||||||
|
(Just(a), b)
|
||||||
|
}))
|
||||||
|
{
|
||||||
|
// See comments in csc_add_csc for rationale for checking the pattern this way
|
||||||
|
let c_dense = DMatrix::from(&a) - DMatrix::from(&b);
|
||||||
|
let c_dense_pattern = dense_csc_pattern(a.pattern()) + dense_csc_pattern(b.pattern());
|
||||||
|
let c_pattern = CscMatrix::from(&c_dense_pattern).pattern().clone();
|
||||||
|
|
||||||
|
// Check each combination of owned matrices and references
|
||||||
|
let c_owned_owned = a.clone() - b.clone();
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense);
|
||||||
|
prop_assert_eq!(c_owned_owned.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_owned_ref = a.clone() - &b;
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense);
|
||||||
|
prop_assert_eq!(c_owned_ref.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_ref_owned = &a - b.clone();
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense);
|
||||||
|
prop_assert_eq!(c_ref_owned.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_ref_ref = &a - &b;
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
|
||||||
|
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn csr_mul_scalar((scalar, matrix) in (PROPTEST_I32_VALUE_STRATEGY, csr_strategy())) {
|
fn csr_mul_scalar((scalar, matrix) in (PROPTEST_I32_VALUE_STRATEGY, csr_strategy())) {
|
||||||
let dense = DMatrix::from(&matrix);
|
let dense = DMatrix::from(&matrix);
|
||||||
|
|
Loading…
Reference in New Issue