diff --git a/nalgebra-sparse/src/ops/impl_std_ops.rs b/nalgebra-sparse/src/ops/impl_std_ops.rs
index 51d4f3bb..f20dbd52 100644
--- a/nalgebra-sparse/src/ops/impl_std_ops.rs
+++ b/nalgebra-sparse/src/ops/impl_std_ops.rs
@@ -1,7 +1,7 @@
 use crate::csr::CsrMatrix;
 use crate::csc::CscMatrix;
 
-use std::ops::{Add, Mul};
+use std::ops::{Add, Mul, MulAssign};
 use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern,
                          spmm_pattern, spmm_csr_prealloc, spmm_csc_prealloc};
 use nalgebra::{ClosedAdd, ClosedMul, Scalar};
@@ -13,17 +13,16 @@ use crate::ops::{Op};
 /// See below for usage.
 macro_rules! impl_bin_op {
     ($trait:ident, $method:ident,
-        <$($life:lifetime),*>($a:ident : $a_type:ty, $b:ident : $b_type:ty) -> $ret:ty $body:block)
+        <$($life:lifetime),* $(,)? $($scalar_type:ident)?>($a:ident : $a_type:ty, $b:ident : $b_type:ty) -> $ret:ty $body:block)
         =>
     {
-        impl<$($life,)* T> $trait<$b_type> for $a_type
+        impl<$($life,)* $($scalar_type)?> $trait<$b_type> for $a_type
         where
-            T: Scalar + ClosedAdd + ClosedMul + Zero + One
+            $($scalar_type: Scalar + ClosedAdd + ClosedMul + Zero + One)?
         {
             type Output = $ret;
-            fn $method(self, rhs: $b_type) -> Self::Output {
+            fn $method(self, $b: $b_type) -> Self::Output {
                 let $a = self;
-                let $b = rhs;
                 $body
             }
         }
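Note on the hunk above: `impl_bin_op!` now accepts an optional scalar-type identifier after the lifetimes, and the `Scalar + ClosedAdd + ClosedMul + Zero + One` bounds are emitted only when that identifier is supplied (the `$(...)?` repetitions). This lets the same matcher serve the generic impls below (`<'a, T>`, `<T>`) as well as the fully concrete ones (`<'a>`, `<>`) needed for scalar-left multiplication. As a rough sketch, not the literal macro output, an invocation such as `impl_mul!(<'a, T>(a: &'a CsrMatrix<T>, b: T) -> CsrMatrix<T> { a * &b })` expands to approximately:

```rust
// Approximate expansion inside impl_std_ops.rs (imports as in the file above).
// The `where` clause appears because a scalar type parameter `T` was supplied.
impl<'a, T> Mul<T> for &'a CsrMatrix<T>
where
    T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
    type Output = CsrMatrix<T>;

    fn mul(self, b: T) -> Self::Output {
        let a = self;
        // ... the body passed to the macro goes here ...
        a * &b
    }
}
```

For a concrete invocation like `impl_mul!(<>(a: i32, b: CsrMatrix<i32>) -> CsrMatrix<i32> { b * a })`, both optional fragments are absent, so no type parameter and no `where` clause are generated.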
@@ -40,7 +39,7 @@ macro_rules! impl_add {
 /// CsrMatrix or CscMatrix.
 macro_rules! impl_spadd {
     ($matrix_type:ident, $spadd_fn:ident) => {
-        impl_add!(<'a>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
+        impl_add!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
             // If both matrices have the same pattern, then we can immediately re-use it
             let pattern = if Arc::ptr_eq(a.pattern(), b.pattern()) {
                 Arc::clone(a.pattern())
@@ -56,7 +55,7 @@ macro_rules! impl_spadd {
             result
         });
 
-        impl_add!(<'a>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
+        impl_add!(<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
             let mut a = a;
             if Arc::ptr_eq(a.pattern(), b.pattern()) {
                 $spadd_fn(T::one(), &mut a, T::one(), Op::NoOp(b)).unwrap();
@@ -66,10 +65,10 @@ macro_rules! impl_spadd {
             }
         });
 
-        impl_add!(<'a>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
+        impl_add!(<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
             b + a
         });
-        impl_add!(<>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
+        impl_add!(<T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
             a + &b
         });
     }
@@ -88,7 +87,7 @@ macro_rules! impl_mul {
 /// CsrMatrix or CscMatrix.
 macro_rules! impl_spmm {
     ($matrix_type:ident, $pattern_fn:expr, $spmm_fn:expr) => {
-        impl_mul!(<'a>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
+        impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
             let pattern = $pattern_fn(a.pattern(), b.pattern());
             let values = vec![T::zero(); pattern.nnz()];
             let mut result = $matrix_type::try_from_pattern_and_values(Arc::new(pattern), values)
@@ -101,12 +100,85 @@ macro_rules! impl_spmm {
                 .expect("Internal error: spmm failed (please debug).");
             result
         });
-        impl_mul!(<'a>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { a * &b});
-        impl_mul!(<'a>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> { &a * b});
-        impl_mul!(<>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { &a * &b});
+        impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { a * &b});
+        impl_mul!(<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> { &a * b});
+        impl_mul!(<T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { &a * &b});
     }
 }
 
 impl_spmm!(CsrMatrix, spmm_pattern, spmm_csr_prealloc);
 // Need to switch order of operations for CSC pattern
-impl_spmm!(CscMatrix, |a, b| spmm_pattern(b, a), spmm_csc_prealloc);
\ No newline at end of file
+impl_spmm!(CscMatrix, |a, b| spmm_pattern(b, a), spmm_csc_prealloc);
+
+/// Implements Scalar * Matrix operations for *concrete* scalar types. The reason this is necessary
+/// is that we are not able to implement Mul<Matrix<T>> for all T generically due to orphan rules.
+macro_rules! impl_concrete_scalar_matrix_mul {
+    ($matrix_type:ident, $($scalar_type:ty),*) => {
+        // For each concrete scalar type, forward the implementation of scalar * matrix
+        // to matrix * scalar, which we have already implemented through generics
+        $(
+            impl_mul!(<>(a: $scalar_type, b: $matrix_type<$scalar_type>)
+                -> $matrix_type<$scalar_type> { b * a });
+            impl_mul!(<'a>(a: $scalar_type, b: &'a $matrix_type<$scalar_type>)
+                -> $matrix_type<$scalar_type> { b * a });
+            impl_mul!(<'a>(a: &'a $scalar_type, b: $matrix_type<$scalar_type>)
+                -> $matrix_type<$scalar_type> { b * (*a) });
+            impl_mul!(<'a>(a: &'a $scalar_type, b: &'a $matrix_type<$scalar_type>)
+                -> $matrix_type<$scalar_type> { b * *a });
+        )*
+    }
+}
+
+/// Implements multiplication between matrix and scalar for various matrix types
+macro_rules! impl_scalar_mul {
+    ($matrix_type: ident) => {
+        impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a T) -> $matrix_type<T> {
+            let values: Vec<_> = a.values()
+                .iter()
+                .map(|v_i| v_i.inlined_clone() * b.inlined_clone())
+                .collect();
+            $matrix_type::try_from_pattern_and_values(Arc::clone(a.pattern()), values).unwrap()
+        });
+        impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: T) -> $matrix_type<T> {
+            a * &b
+        });
+        impl_mul!(<'a, T>(a: $matrix_type<T>, b: &'a T) -> $matrix_type<T> {
+            let mut a = a;
+            for value in a.values_mut() {
+                *value = b.inlined_clone() * value.inlined_clone();
+            }
+            a
+        });
+        impl_mul!(<T>(a: $matrix_type<T>, b: T) -> $matrix_type<T> {
+            a * &b
+        });
+        impl_concrete_scalar_matrix_mul!(
+            $matrix_type,
+            i8, i16, i32, i64, u8, u16, u32, u64, isize, usize, f32, f64);
+
+        impl<T> MulAssign<T> for $matrix_type<T>
+        where
+            T: Scalar + ClosedAdd + ClosedMul + Zero + One
+        {
+            fn mul_assign(&mut self, scalar: T) {
+                for val in self.values_mut() {
+                    *val *= scalar.inlined_clone();
+                }
+            }
+        }
+
+        impl<'a, T> MulAssign<&'a T> for $matrix_type<T>
+        where
+            T: Scalar + ClosedAdd + ClosedMul + Zero + One
+        {
+            fn mul_assign(&mut self, scalar: &'a T) {
+                for val in self.values_mut() {
+                    *val *= scalar.inlined_clone();
+                }
+            }
+        }
+    }
+}
+
+impl_scalar_mul!(CsrMatrix);
+impl_scalar_mul!(CscMatrix);
\ No newline at end of file
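Taken together, `impl_scalar_mul!` and `impl_concrete_scalar_matrix_mul!` give both matrix types every owned/borrowed combination of `matrix * scalar`, `scalar * matrix` (for the listed primitive scalar types only) and `*=`. A minimal usage sketch against the public API; building the matrix through `CooMatrix` is just one convenient way to get a small example and is not specific to this patch:

```rust
use nalgebra_sparse::coo::CooMatrix;
use nalgebra_sparse::csr::CsrMatrix;

fn main() {
    // Build a small 2x2 CSR matrix with two explicitly stored entries.
    let mut coo = CooMatrix::<i32>::new(2, 2);
    coo.push(0, 0, 2);
    coo.push(1, 1, 3);
    let a = CsrMatrix::from(&coo);

    // matrix * scalar is generic over the scalar type T...
    let b = &a * 5;
    // ...while scalar * matrix is only available for the concrete primitive types.
    assert_eq!(5 * &a, b);

    // The MulAssign impls scale the stored values in place,
    // so the sparsity pattern (and hence nnz) is unchanged.
    let mut c = a.clone();
    c *= 5;
    assert_eq!(c, b);
    assert_eq!(c.nnz(), a.nnz());
}
```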
diff --git a/nalgebra-sparse/tests/common/mod.rs b/nalgebra-sparse/tests/common/mod.rs
index 7a559b2f..2b39ab93 100644
--- a/nalgebra-sparse/tests/common/mod.rs
+++ b/nalgebra-sparse/tests/common/mod.rs
@@ -3,6 +3,8 @@ use nalgebra_sparse::csr::CsrMatrix;
 use nalgebra_sparse::proptest::{csr, csc};
 use nalgebra_sparse::csc::CscMatrix;
 use std::ops::RangeInclusive;
+use std::convert::{TryFrom};
+use std::fmt::Debug;
 
 #[macro_export]
 macro_rules! assert_panics {
@@ -29,6 +31,15 @@ pub const PROPTEST_MATRIX_DIM: RangeInclusive<usize> = 0..=6;
 pub const PROPTEST_MAX_NNZ: usize = 40;
 pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5 ..= 5;
 
+pub fn value_strategy<T>() -> RangeInclusive<T>
+where
+    T: TryFrom<i32>,
+    T::Error: Debug
+{
+    let (start, end) = (PROPTEST_I32_VALUE_STRATEGY.start(), PROPTEST_I32_VALUE_STRATEGY.end());
+    T::try_from(*start).unwrap() ..= T::try_from(*end).unwrap()
+}
+
 pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
     csr(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
 }
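The new `value_strategy<T>` helper is not used by the tests in this patch (they operate on `i32` throughout), but it lets other test modules reuse the same `-5..=5` value range for any scalar convertible from `i32`. A hypothetical example of the intended usage, not part of this diff; note that with the current bounds the helper would panic for unsigned types, since e.g. `u32::try_from(-5)` fails:

```rust
// Hypothetical test, mirroring how csr_strategy() calls csr() above.
use crate::common::{value_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ};
use nalgebra_sparse::proptest::csr;
use proptest::prelude::*;

proptest! {
    #[test]
    fn i64_scaling_preserves_nnz(
        (scalar, matrix) in (value_strategy::<i64>(),
                             csr(value_strategy::<i64>(),
                                 PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ))
    ) {
        // Scaling only rewrites stored values; the sparsity pattern is untouched.
        prop_assert_eq!((&matrix * scalar).nnz(), matrix.nnz());
    }
}
```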
diff --git a/nalgebra-sparse/tests/unit_tests/ops.rs b/nalgebra-sparse/tests/unit_tests/ops.rs
index fddc5045..3cbee92f 100644
--- a/nalgebra-sparse/tests/unit_tests/ops.rs
+++ b/nalgebra-sparse/tests/unit_tests/ops.rs
@@ -280,7 +280,6 @@ fn dense_gemm<'a>(beta: i32,
 }
 
 proptest! {
-
     #[test]
     fn spmm_csr_dense_agrees_with_dense_result(
         SpmmCsrDenseArgs { c, beta, alpha, a, b }
@@ -837,4 +836,94 @@ proptest! {
         prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
         prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
     }
+
+    #[test]
+    fn csr_mul_scalar((scalar, matrix) in (PROPTEST_I32_VALUE_STRATEGY, csr_strategy())) {
+        let dense = DMatrix::from(&matrix);
+        let dense_result = dense * scalar;
+
+        let result_owned_owned = matrix.clone() * scalar;
+        let result_owned_ref = matrix.clone() * &scalar;
+        let result_ref_owned = &matrix * scalar;
+        let result_ref_ref = &matrix * &scalar;
+
+        // Check that all the combinations of reference and owned variables return the same
+        // result
+        prop_assert_eq!(&result_owned_ref, &result_owned_owned);
+        prop_assert_eq!(&result_ref_owned, &result_owned_owned);
+        prop_assert_eq!(&result_ref_ref, &result_owned_owned);
+
+        // Check that this result is consistent with the dense result, and that the
+        // NNZ is the same as before
+        prop_assert_eq!(result_owned_owned.nnz(), matrix.nnz());
+        prop_assert_eq!(DMatrix::from(&result_owned_owned), dense_result);
+
+        // Finally, check mul-assign
+        let mut result_assign_owned = matrix.clone();
+        result_assign_owned *= scalar;
+        let mut result_assign_ref = matrix.clone();
+        result_assign_ref *= &scalar;
+
+        prop_assert_eq!(&result_assign_owned, &result_owned_owned);
+        prop_assert_eq!(&result_assign_ref, &result_owned_owned);
+    }
+
+    #[test]
+    fn csc_mul_scalar((scalar, matrix) in (PROPTEST_I32_VALUE_STRATEGY, csc_strategy())) {
+        let dense = DMatrix::from(&matrix);
+        let dense_result = dense * scalar;
+
+        let result_owned_owned = matrix.clone() * scalar;
+        let result_owned_ref = matrix.clone() * &scalar;
+        let result_ref_owned = &matrix * scalar;
+        let result_ref_ref = &matrix * &scalar;
+
+        // Check that all the combinations of reference and owned variables return the same
+        // result
+        prop_assert_eq!(&result_owned_ref, &result_owned_owned);
+        prop_assert_eq!(&result_ref_owned, &result_owned_owned);
+        prop_assert_eq!(&result_ref_ref, &result_owned_owned);
+
+        // Check that this result is consistent with the dense result, and that the
+        // NNZ is the same as before
+        prop_assert_eq!(result_owned_owned.nnz(), matrix.nnz());
+        prop_assert_eq!(DMatrix::from(&result_owned_owned), dense_result);
+
+        // Finally, check mul-assign
+        let mut result_assign_owned = matrix.clone();
+        result_assign_owned *= scalar;
+        let mut result_assign_ref = matrix.clone();
+        result_assign_ref *= &scalar;
+
+        prop_assert_eq!(&result_assign_owned, &result_owned_owned);
+        prop_assert_eq!(&result_assign_ref, &result_owned_owned);
+    }
+
+    #[test]
+    fn scalar_mul_csr((scalar, matrix) in (PROPTEST_I32_VALUE_STRATEGY, csr_strategy())) {
+        // For scalar * matrix, we cannot generally implement this for any type T,
+        // so we have implemented this for the built-in types separately. This requires
+        // us to also test these types separately. For validation, we check that
+        //   scalar * matrix == matrix * scalar,
+        // which is sufficient for correctness if matrix * scalar is correctly implemented
+        // (which is tested separately).
+        // We only test for i32 here, because with our current implementation, the implementations
+        // for different types are completely identical and only rely on basic arithmetic
+        // operations
+        let result = &matrix * scalar;
+        prop_assert_eq!(&(scalar * matrix.clone()), &result);
+        prop_assert_eq!(&(scalar * &matrix), &result);
+        prop_assert_eq!(&(&scalar * matrix.clone()), &result);
+        prop_assert_eq!(&(&scalar * &matrix), &result);
+    }
+
+    #[test]
+    fn scalar_mul_csc((scalar, matrix) in (PROPTEST_I32_VALUE_STRATEGY, csc_strategy())) {
+        // See comments for scalar_mul_csr
+        let result = &matrix * scalar;
+        prop_assert_eq!(&(scalar * matrix.clone()), &result);
+        prop_assert_eq!(&(scalar * &matrix), &result);
+        prop_assert_eq!(&(&scalar * matrix.clone()), &result);
+        prop_assert_eq!(&(&scalar * &matrix), &result);
+    }
 }
\ No newline at end of file
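For context on the "orphan rules" comments in `scalar_mul_csr`/`scalar_mul_csc`: the blanket left-multiplication impl is rejected by Rust's coherence rules even inside nalgebra-sparse, because `Mul` is a foreign (std) trait and the implementing type would be a bare type parameter. The sketch below is illustrative only and deliberately does not compile; it is exactly the impl the compiler forbids with error E0210:

```rust
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
use nalgebra_sparse::csr::CsrMatrix;
use num_traits::{One, Zero};
use std::ops::Mul;

// error[E0210]: type parameter `T` must be covered by another type
// when it appears before the first local type (`CsrMatrix<T>`)
impl<T> Mul<CsrMatrix<T>> for T
where
    T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
    type Output = CsrMatrix<T>;

    fn mul(self, rhs: CsrMatrix<T>) -> CsrMatrix<T> {
        // Would simply forward to the generic `matrix * scalar` impl.
        rhs * self
    }
}
```

This is why `impl_concrete_scalar_matrix_mul!` enumerates the primitive scalar types one by one, and why testing the `i32` instantiation suffices: every concrete impl forwards to the same generic `matrix * scalar` code path.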