forked from M-Labs/nalgebra
Implement Neg, Div, DivAssign for Csr/CscMatrix
This commit is contained in:
parent
0b4356eb0e
commit
b7a7f967b8
@ -1,10 +1,10 @@
|
|||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
|
|
||||||
use std::ops::{Add, Mul, MulAssign, Sub, Neg};
|
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg};
|
||||||
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern,
|
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern,
|
||||||
spmm_pattern, spmm_csr_prealloc, spmm_csc_prealloc};
|
spmm_pattern, spmm_csr_prealloc, spmm_csc_prealloc};
|
||||||
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, Scalar};
|
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar};
|
||||||
use num_traits::{Zero, One};
|
use num_traits::{Zero, One};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use crate::ops::{Op};
|
use crate::ops::{Op};
|
||||||
@ -13,15 +13,15 @@ use crate::ops::{Op};
|
|||||||
/// See below for usage.
|
/// See below for usage.
|
||||||
macro_rules! impl_bin_op {
|
macro_rules! impl_bin_op {
|
||||||
($trait:ident, $method:ident,
|
($trait:ident, $method:ident,
|
||||||
<$($life:lifetime),* $(,)? $($scalar_type:ident)?>($a:ident : $a_type:ty, $b:ident : $b_type:ty) -> $ret:ty $body:block)
|
<$($life:lifetime),* $(,)? $($scalar_type:ident $(: $bounds:path)?)?>($a:ident : $a_type:ty, $b:ident : $b_type:ty) -> $ret:ty $body:block)
|
||||||
=>
|
=>
|
||||||
{
|
{
|
||||||
impl<$($life,)* $($scalar_type)?> $trait<$b_type> for $a_type
|
impl<$($life,)* $($scalar_type)?> $trait<$b_type> for $a_type
|
||||||
where
|
where
|
||||||
// Note: The Signed bound is currently required because we delegate e.g.
|
// Note: The Neg bound is currently required because we delegate e.g.
|
||||||
// Sub to SpAdd with negative coefficients. This is not well-defined for
|
// Sub to SpAdd with negative coefficients. This is not well-defined for
|
||||||
// unsigned data types.
|
// unsigned data types.
|
||||||
$($scalar_type: Scalar + ClosedAdd + ClosedSub + ClosedMul + Zero + One + Neg<Output=T>)?
|
$($scalar_type: $($bounds + )? Scalar + ClosedAdd + ClosedSub + ClosedMul + Zero + One + Neg<Output=T>)?
|
||||||
{
|
{
|
||||||
type Output = $ret;
|
type Output = $ret;
|
||||||
fn $method(self, $b: $b_type) -> Self::Output {
|
fn $method(self, $b: $b_type) -> Self::Output {
|
||||||
@ -29,7 +29,7 @@ macro_rules! impl_bin_op {
|
|||||||
$body
|
$body
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implements a +/- b for all combinations of reference and owned matrices, for
|
/// Implements a +/- b for all combinations of reference and owned matrices, for
|
||||||
@ -198,4 +198,81 @@ macro_rules! impl_scalar_mul {
|
|||||||
impl_scalar_mul!(CsrMatrix);
|
impl_scalar_mul!(CsrMatrix);
|
||||||
impl_scalar_mul!(CscMatrix);
|
impl_scalar_mul!(CscMatrix);
|
||||||
|
|
||||||
// TODO: Neg, Div
|
macro_rules! impl_neg {
|
||||||
|
($matrix_type:ident) => {
|
||||||
|
impl<T> Neg for $matrix_type<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Neg<Output=T>
|
||||||
|
{
|
||||||
|
type Output = $matrix_type<T>;
|
||||||
|
|
||||||
|
fn neg(mut self) -> Self::Output {
|
||||||
|
for v_i in self.values_mut() {
|
||||||
|
*v_i = -v_i.inlined_clone();
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Neg for &'a $matrix_type<T>
|
||||||
|
where
|
||||||
|
T: Scalar + Neg<Output=T>
|
||||||
|
{
|
||||||
|
type Output = $matrix_type<T>;
|
||||||
|
|
||||||
|
fn neg(self) -> Self::Output {
|
||||||
|
// TODO: This is inefficient. Ideally we'd have a method that would let us
|
||||||
|
// obtain both the sparsity pattern and values from the matrix,
|
||||||
|
// and then modify the values before creating a new matrix from the pattern
|
||||||
|
// and negated values.
|
||||||
|
- self.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_neg!(CsrMatrix);
|
||||||
|
impl_neg!(CscMatrix);
|
||||||
|
|
||||||
|
macro_rules! impl_div {
|
||||||
|
($matrix_type:ident) => {
|
||||||
|
impl_bin_op!(Div, div, <T: ClosedDiv>(matrix: $matrix_type<T>, scalar: T) -> $matrix_type<T> {
|
||||||
|
let mut matrix = matrix;
|
||||||
|
matrix /= scalar;
|
||||||
|
matrix
|
||||||
|
});
|
||||||
|
impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: $matrix_type<T>, scalar: &T) -> $matrix_type<T> {
|
||||||
|
matrix / scalar.inlined_clone()
|
||||||
|
});
|
||||||
|
impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: &'a $matrix_type<T>, scalar: T) -> $matrix_type<T> {
|
||||||
|
let new_values = matrix.values()
|
||||||
|
.iter()
|
||||||
|
.map(|v_i| v_i.inlined_clone() / scalar.inlined_clone())
|
||||||
|
.collect();
|
||||||
|
$matrix_type::try_from_pattern_and_values(Arc::clone(matrix.pattern()), new_values)
|
||||||
|
.unwrap()
|
||||||
|
});
|
||||||
|
impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: &'a $matrix_type<T>, scalar: &'a T) -> $matrix_type<T> {
|
||||||
|
matrix / scalar.inlined_clone()
|
||||||
|
});
|
||||||
|
|
||||||
|
impl<T> DivAssign<T> for $matrix_type<T>
|
||||||
|
where T : Scalar + ClosedAdd + ClosedMul + ClosedDiv + Zero + One
|
||||||
|
{
|
||||||
|
fn div_assign(&mut self, scalar: T) {
|
||||||
|
self.values_mut().iter_mut().for_each(|v_i| *v_i /= scalar.inlined_clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> DivAssign<&'a T> for $matrix_type<T>
|
||||||
|
where T : Scalar + ClosedAdd + ClosedMul + ClosedDiv + Zero + One
|
||||||
|
{
|
||||||
|
fn div_assign(&mut self, scalar: &'a T) {
|
||||||
|
*self /= scalar.inlined_clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_div!(CsrMatrix);
|
||||||
|
impl_div!(CscMatrix);
|
@ -40,6 +40,15 @@ where
|
|||||||
T::try_from(*start).unwrap() ..= T::try_from(*end).unwrap()
|
T::try_from(*start).unwrap() ..= T::try_from(*end).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn non_zero_i32_value_strategy() -> impl Strategy<Value=i32> {
|
||||||
|
let (start, end) = (PROPTEST_I32_VALUE_STRATEGY.start(), PROPTEST_I32_VALUE_STRATEGY.end());
|
||||||
|
assert!(start < &0);
|
||||||
|
assert!(end > &0);
|
||||||
|
// Note: we don't use RangeInclusive for the second range, because then we'd have different
|
||||||
|
// types, which would require boxing
|
||||||
|
(*start .. 0).prop_union(1 .. *end + 1)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
||||||
csr(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
csr(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use crate::common::{csc_strategy, csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
|
use crate::common::{csc_strategy, csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
|
||||||
PROPTEST_I32_VALUE_STRATEGY};
|
PROPTEST_I32_VALUE_STRATEGY, non_zero_i32_value_strategy};
|
||||||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spmm_csc_dense, spadd_pattern, spmm_pattern,
|
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spmm_csc_dense, spadd_pattern, spmm_pattern,
|
||||||
spadd_csr_prealloc, spadd_csc_prealloc,
|
spadd_csr_prealloc, spadd_csc_prealloc,
|
||||||
spmm_csr_prealloc, spmm_csc_prealloc};
|
spmm_csr_prealloc, spmm_csc_prealloc};
|
||||||
@ -992,4 +992,99 @@ proptest! {
|
|||||||
prop_assert_eq!(&(&scalar * matrix.clone()), &result);
|
prop_assert_eq!(&(&scalar * matrix.clone()), &result);
|
||||||
prop_assert_eq!(&(&scalar * &matrix), &result);
|
prop_assert_eq!(&(&scalar * &matrix), &result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_neg(csr in csr_strategy()) {
|
||||||
|
let result = &csr - 2 * &csr;
|
||||||
|
prop_assert_eq!(-&csr, result.clone());
|
||||||
|
prop_assert_eq!(-csr, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_neg(csc in csc_strategy()) {
|
||||||
|
let result = &csc - 2 * &csc;
|
||||||
|
prop_assert_eq!(-&csc, result.clone());
|
||||||
|
prop_assert_eq!(-csc, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_div((csr, divisor) in (csr_strategy(), non_zero_i32_value_strategy())) {
|
||||||
|
let result_owned_owned = csr.clone() / divisor;
|
||||||
|
let result_owned_ref = csr.clone() / &divisor;
|
||||||
|
let result_ref_owned = &csr / divisor;
|
||||||
|
let result_ref_ref = &csr / &divisor;
|
||||||
|
|
||||||
|
// Verify that all results are the same
|
||||||
|
prop_assert_eq!(&result_owned_ref, &result_owned_owned);
|
||||||
|
prop_assert_eq!(&result_ref_owned, &result_owned_owned);
|
||||||
|
prop_assert_eq!(&result_ref_ref, &result_owned_owned);
|
||||||
|
|
||||||
|
// Check that NNZ was left unchanged
|
||||||
|
prop_assert_eq!(result_owned_owned.nnz(), csr.nnz());
|
||||||
|
|
||||||
|
// Then compare against the equivalent dense result
|
||||||
|
let dense_result = DMatrix::from(&csr) / divisor;
|
||||||
|
prop_assert_eq!(DMatrix::from(&result_owned_owned), dense_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_div((csc, divisor) in (csc_strategy(), non_zero_i32_value_strategy())) {
|
||||||
|
let result_owned_owned = csc.clone() / divisor;
|
||||||
|
let result_owned_ref = csc.clone() / &divisor;
|
||||||
|
let result_ref_owned = &csc / divisor;
|
||||||
|
let result_ref_ref = &csc / &divisor;
|
||||||
|
|
||||||
|
// Verify that all results are the same
|
||||||
|
prop_assert_eq!(&result_owned_ref, &result_owned_owned);
|
||||||
|
prop_assert_eq!(&result_ref_owned, &result_owned_owned);
|
||||||
|
prop_assert_eq!(&result_ref_ref, &result_owned_owned);
|
||||||
|
|
||||||
|
// Check that NNZ was left unchanged
|
||||||
|
prop_assert_eq!(result_owned_owned.nnz(), csc.nnz());
|
||||||
|
|
||||||
|
// Then compare against the equivalent dense result
|
||||||
|
let dense_result = DMatrix::from(&csc) / divisor;
|
||||||
|
prop_assert_eq!(DMatrix::from(&result_owned_owned), dense_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_div_assign((csr, divisor) in (csr_strategy(), non_zero_i32_value_strategy())) {
|
||||||
|
let result_owned = {
|
||||||
|
let mut csr = csr.clone();
|
||||||
|
csr /= divisor;
|
||||||
|
csr
|
||||||
|
};
|
||||||
|
|
||||||
|
let result_ref = {
|
||||||
|
let mut csr = csr.clone();
|
||||||
|
csr /= &divisor;
|
||||||
|
csr
|
||||||
|
};
|
||||||
|
|
||||||
|
let expected_result = csr / divisor;
|
||||||
|
|
||||||
|
prop_assert_eq!(&result_owned, &expected_result);
|
||||||
|
prop_assert_eq!(&result_ref, &expected_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_div_assign((csc, divisor) in (csc_strategy(), non_zero_i32_value_strategy())) {
|
||||||
|
let result_owned = {
|
||||||
|
let mut csc = csc.clone();
|
||||||
|
csc /= divisor;
|
||||||
|
csc
|
||||||
|
};
|
||||||
|
|
||||||
|
let result_ref = {
|
||||||
|
let mut csc = csc.clone();
|
||||||
|
csc /= &divisor;
|
||||||
|
csc
|
||||||
|
};
|
||||||
|
|
||||||
|
let expected_result = csc / divisor;
|
||||||
|
|
||||||
|
prop_assert_eq!(&result_owned, &expected_result);
|
||||||
|
prop_assert_eq!(&result_ref, &expected_result);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user