forked from M-Labs/nalgebra
Implement matrix-scalar multiplication
This commit is contained in:
parent
dbdf5567fc
commit
7aeb663165
@ -1,7 +1,7 @@
|
|||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
|
|
||||||
use std::ops::{Add, Mul};
|
use std::ops::{Add, Mul, MulAssign};
|
||||||
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern,
|
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern,
|
||||||
spmm_pattern, spmm_csr_prealloc, spmm_csc_prealloc};
|
spmm_pattern, spmm_csr_prealloc, spmm_csc_prealloc};
|
||||||
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
||||||
@ -13,17 +13,16 @@ use crate::ops::{Op};
|
|||||||
/// See below for usage.
|
/// See below for usage.
|
||||||
macro_rules! impl_bin_op {
|
macro_rules! impl_bin_op {
|
||||||
($trait:ident, $method:ident,
|
($trait:ident, $method:ident,
|
||||||
<$($life:lifetime),*>($a:ident : $a_type:ty, $b:ident : $b_type:ty) -> $ret:ty $body:block)
|
<$($life:lifetime),* $(,)? $($scalar_type:ident)?>($a:ident : $a_type:ty, $b:ident : $b_type:ty) -> $ret:ty $body:block)
|
||||||
=>
|
=>
|
||||||
{
|
{
|
||||||
impl<$($life,)* T> $trait<$b_type> for $a_type
|
impl<$($life,)* $($scalar_type)?> $trait<$b_type> for $a_type
|
||||||
where
|
where
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
$($scalar_type: Scalar + ClosedAdd + ClosedMul + Zero + One)?
|
||||||
{
|
{
|
||||||
type Output = $ret;
|
type Output = $ret;
|
||||||
fn $method(self, rhs: $b_type) -> Self::Output {
|
fn $method(self, $b: $b_type) -> Self::Output {
|
||||||
let $a = self;
|
let $a = self;
|
||||||
let $b = rhs;
|
|
||||||
$body
|
$body
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -40,7 +39,7 @@ macro_rules! impl_add {
|
|||||||
/// CsrMatrix or CscMatrix.
|
/// CsrMatrix or CscMatrix.
|
||||||
macro_rules! impl_spadd {
|
macro_rules! impl_spadd {
|
||||||
($matrix_type:ident, $spadd_fn:ident) => {
|
($matrix_type:ident, $spadd_fn:ident) => {
|
||||||
impl_add!(<'a>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
impl_add!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||||
// If both matrices have the same pattern, then we can immediately re-use it
|
// If both matrices have the same pattern, then we can immediately re-use it
|
||||||
let pattern = if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
let pattern = if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
||||||
Arc::clone(a.pattern())
|
Arc::clone(a.pattern())
|
||||||
@ -56,7 +55,7 @@ macro_rules! impl_spadd {
|
|||||||
result
|
result
|
||||||
});
|
});
|
||||||
|
|
||||||
impl_add!(<'a>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
impl_add!(<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||||
let mut a = a;
|
let mut a = a;
|
||||||
if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
||||||
$spadd_fn(T::one(), &mut a, T::one(), Op::NoOp(b)).unwrap();
|
$spadd_fn(T::one(), &mut a, T::one(), Op::NoOp(b)).unwrap();
|
||||||
@ -66,10 +65,10 @@ macro_rules! impl_spadd {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
impl_add!(<'a>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
impl_add!(<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||||
b + a
|
b + a
|
||||||
});
|
});
|
||||||
impl_add!(<>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
impl_add!(<T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||||
a + &b
|
a + &b
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -88,7 +87,7 @@ macro_rules! impl_mul {
|
|||||||
/// CsrMatrix or CscMatrix.
|
/// CsrMatrix or CscMatrix.
|
||||||
macro_rules! impl_spmm {
|
macro_rules! impl_spmm {
|
||||||
($matrix_type:ident, $pattern_fn:expr, $spmm_fn:expr) => {
|
($matrix_type:ident, $pattern_fn:expr, $spmm_fn:expr) => {
|
||||||
impl_mul!(<'a>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||||
let pattern = $pattern_fn(a.pattern(), b.pattern());
|
let pattern = $pattern_fn(a.pattern(), b.pattern());
|
||||||
let values = vec![T::zero(); pattern.nnz()];
|
let values = vec![T::zero(); pattern.nnz()];
|
||||||
let mut result = $matrix_type::try_from_pattern_and_values(Arc::new(pattern), values)
|
let mut result = $matrix_type::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||||
@ -101,12 +100,85 @@ macro_rules! impl_spmm {
|
|||||||
.expect("Internal error: spmm failed (please debug).");
|
.expect("Internal error: spmm failed (please debug).");
|
||||||
result
|
result
|
||||||
});
|
});
|
||||||
impl_mul!(<'a>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { a * &b});
|
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { a * &b});
|
||||||
impl_mul!(<'a>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> { &a * b});
|
impl_mul!(<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> { &a * b});
|
||||||
impl_mul!(<>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { &a * &b});
|
impl_mul!(<T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { &a * &b});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_spmm!(CsrMatrix, spmm_pattern, spmm_csr_prealloc);
|
impl_spmm!(CsrMatrix, spmm_pattern, spmm_csr_prealloc);
|
||||||
// Need to switch order of operations for CSC pattern
|
// Need to switch order of operations for CSC pattern
|
||||||
impl_spmm!(CscMatrix, |a, b| spmm_pattern(b, a), spmm_csc_prealloc);
|
impl_spmm!(CscMatrix, |a, b| spmm_pattern(b, a), spmm_csc_prealloc);
|
||||||
|
|
||||||
|
/// Implements Scalar * Matrix operations for *concrete* scalar types. The reason this is necessary
|
||||||
|
/// is that we are not able to implement Mul<Matrix<T>> for all T generically due to orphan rules.
|
||||||
|
macro_rules! impl_concrete_scalar_matrix_mul {
|
||||||
|
($matrix_type:ident, $($scalar_type:ty),*) => {
|
||||||
|
// For each concrete scalar type, forward the implementation of scalar * matrix
|
||||||
|
// to matrix * scalar, which we have already implemented through generics
|
||||||
|
$(
|
||||||
|
impl_mul!(<>(a: $scalar_type, b: $matrix_type<$scalar_type>)
|
||||||
|
-> $matrix_type<$scalar_type> { b * a });
|
||||||
|
impl_mul!(<'a>(a: $scalar_type, b: &'a $matrix_type<$scalar_type>)
|
||||||
|
-> $matrix_type<$scalar_type> { b * a });
|
||||||
|
impl_mul!(<'a>(a: &'a $scalar_type, b: $matrix_type<$scalar_type>)
|
||||||
|
-> $matrix_type<$scalar_type> { b * (*a) });
|
||||||
|
impl_mul!(<'a>(a: &'a $scalar_type, b: &'a $matrix_type<$scalar_type>)
|
||||||
|
-> $matrix_type<$scalar_type> { b * *a });
|
||||||
|
)*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implements multiplication between matrix and scalar for various matrix types
|
||||||
|
macro_rules! impl_scalar_mul {
|
||||||
|
($matrix_type: ident) => {
|
||||||
|
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a T) -> $matrix_type<T> {
|
||||||
|
let values: Vec<_> = a.values()
|
||||||
|
.iter()
|
||||||
|
.map(|v_i| v_i.inlined_clone() * b.inlined_clone())
|
||||||
|
.collect();
|
||||||
|
$matrix_type::try_from_pattern_and_values(Arc::clone(a.pattern()), values).unwrap()
|
||||||
|
});
|
||||||
|
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: T) -> $matrix_type<T> {
|
||||||
|
a * &b
|
||||||
|
});
|
||||||
|
impl_mul!(<'a, T>(a: $matrix_type<T>, b: &'a T) -> $matrix_type<T> {
|
||||||
|
let mut a = a;
|
||||||
|
for value in a.values_mut() {
|
||||||
|
*value = b.inlined_clone() * value.inlined_clone();
|
||||||
|
}
|
||||||
|
a
|
||||||
|
});
|
||||||
|
impl_mul!(<T>(a: $matrix_type<T>, b: T) -> $matrix_type<T> {
|
||||||
|
a * &b
|
||||||
|
});
|
||||||
|
impl_concrete_scalar_matrix_mul!(
|
||||||
|
$matrix_type,
|
||||||
|
i8, i16, i32, i64, u8, u16, u32, u64, isize, usize, f32, f64);
|
||||||
|
|
||||||
|
impl<T> MulAssign<T> for $matrix_type<T>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
fn mul_assign(&mut self, scalar: T) {
|
||||||
|
for val in self.values_mut() {
|
||||||
|
*val *= scalar.inlined_clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> MulAssign<&'a T> for $matrix_type<T>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
fn mul_assign(&mut self, scalar: &'a T) {
|
||||||
|
for val in self.values_mut() {
|
||||||
|
*val *= scalar.inlined_clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_scalar_mul!(CsrMatrix);
|
||||||
|
impl_scalar_mul!(CscMatrix);
|
@ -3,6 +3,8 @@ use nalgebra_sparse::csr::CsrMatrix;
|
|||||||
use nalgebra_sparse::proptest::{csr, csc};
|
use nalgebra_sparse::proptest::{csr, csc};
|
||||||
use nalgebra_sparse::csc::CscMatrix;
|
use nalgebra_sparse::csc::CscMatrix;
|
||||||
use std::ops::RangeInclusive;
|
use std::ops::RangeInclusive;
|
||||||
|
use std::convert::{TryFrom};
|
||||||
|
use std::fmt::Debug;
|
||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! assert_panics {
|
macro_rules! assert_panics {
|
||||||
@ -29,6 +31,15 @@ pub const PROPTEST_MATRIX_DIM: RangeInclusive<usize> = 0..=6;
|
|||||||
pub const PROPTEST_MAX_NNZ: usize = 40;
|
pub const PROPTEST_MAX_NNZ: usize = 40;
|
||||||
pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5 ..= 5;
|
pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5 ..= 5;
|
||||||
|
|
||||||
|
pub fn value_strategy<T>() -> RangeInclusive<T>
|
||||||
|
where
|
||||||
|
T: TryFrom<i32>,
|
||||||
|
T::Error: Debug
|
||||||
|
{
|
||||||
|
let (start, end) = (PROPTEST_I32_VALUE_STRATEGY.start(), PROPTEST_I32_VALUE_STRATEGY.end());
|
||||||
|
T::try_from(*start).unwrap() ..= T::try_from(*end).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
||||||
csr(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
csr(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
||||||
}
|
}
|
||||||
|
@ -280,7 +280,6 @@ fn dense_gemm<'a>(beta: i32,
|
|||||||
}
|
}
|
||||||
|
|
||||||
proptest! {
|
proptest! {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spmm_csr_dense_agrees_with_dense_result(
|
fn spmm_csr_dense_agrees_with_dense_result(
|
||||||
SpmmCsrDenseArgs { c, beta, alpha, a, b }
|
SpmmCsrDenseArgs { c, beta, alpha, a, b }
|
||||||
@ -837,4 +836,94 @@ proptest! {
|
|||||||
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
|
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
|
||||||
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_mul_scalar((scalar, matrix) in (PROPTEST_I32_VALUE_STRATEGY, csr_strategy())) {
|
||||||
|
let dense = DMatrix::from(&matrix);
|
||||||
|
let dense_result = dense * scalar;
|
||||||
|
|
||||||
|
let result_owned_owned = matrix.clone() * scalar;
|
||||||
|
let result_owned_ref = matrix.clone() * &scalar;
|
||||||
|
let result_ref_owned = &matrix * scalar;
|
||||||
|
let result_ref_ref = &matrix * &scalar;
|
||||||
|
|
||||||
|
// Check that all the combinations of reference and owned variables return the same
|
||||||
|
// result
|
||||||
|
prop_assert_eq!(&result_owned_ref, &result_owned_owned);
|
||||||
|
prop_assert_eq!(&result_ref_owned, &result_owned_owned);
|
||||||
|
prop_assert_eq!(&result_ref_ref, &result_owned_owned);
|
||||||
|
|
||||||
|
// Check that this result is consistent with the dense result, and that the
|
||||||
|
// NNZ is the same as before
|
||||||
|
prop_assert_eq!(result_owned_owned.nnz(), matrix.nnz());
|
||||||
|
prop_assert_eq!(DMatrix::from(&result_owned_owned), dense_result);
|
||||||
|
|
||||||
|
// Finally, check mul-assign
|
||||||
|
let mut result_assign_owned = matrix.clone();
|
||||||
|
result_assign_owned *= scalar;
|
||||||
|
let mut result_assign_ref = matrix.clone();
|
||||||
|
result_assign_ref *= &scalar;
|
||||||
|
|
||||||
|
prop_assert_eq!(&result_assign_owned, &result_owned_owned);
|
||||||
|
prop_assert_eq!(&result_assign_ref, &result_owned_owned);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_mul_scalar((scalar, matrix) in (PROPTEST_I32_VALUE_STRATEGY, csc_strategy())) {
|
||||||
|
let dense = DMatrix::from(&matrix);
|
||||||
|
let dense_result = dense * scalar;
|
||||||
|
|
||||||
|
let result_owned_owned = matrix.clone() * scalar;
|
||||||
|
let result_owned_ref = matrix.clone() * &scalar;
|
||||||
|
let result_ref_owned = &matrix * scalar;
|
||||||
|
let result_ref_ref = &matrix * &scalar;
|
||||||
|
|
||||||
|
// Check that all the combinations of reference and owned variables return the same
|
||||||
|
// result
|
||||||
|
prop_assert_eq!(&result_owned_ref, &result_owned_owned);
|
||||||
|
prop_assert_eq!(&result_ref_owned, &result_owned_owned);
|
||||||
|
prop_assert_eq!(&result_ref_ref, &result_owned_owned);
|
||||||
|
|
||||||
|
// Check that this result is consistent with the dense result, and that the
|
||||||
|
// NNZ is the same as before
|
||||||
|
prop_assert_eq!(result_owned_owned.nnz(), matrix.nnz());
|
||||||
|
prop_assert_eq!(DMatrix::from(&result_owned_owned), dense_result);
|
||||||
|
|
||||||
|
// Finally, check mul-assign
|
||||||
|
let mut result_assign_owned = matrix.clone();
|
||||||
|
result_assign_owned *= scalar;
|
||||||
|
let mut result_assign_ref = matrix.clone();
|
||||||
|
result_assign_ref *= &scalar;
|
||||||
|
|
||||||
|
prop_assert_eq!(&result_assign_owned, &result_owned_owned);
|
||||||
|
prop_assert_eq!(&result_assign_ref, &result_owned_owned);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn scalar_mul_csr((scalar, matrix) in (PROPTEST_I32_VALUE_STRATEGY, csr_strategy())) {
|
||||||
|
// For scalar * matrix, we cannot generally implement this for any type T,
|
||||||
|
// so we have implemented this for the built in types separately. This requires
|
||||||
|
// us to also test these types separately. For validation, we check that
|
||||||
|
// scalar * matrix == matrix * scalar,
|
||||||
|
// which is sufficient for correctness if matrix * scalar is correctly implemented
|
||||||
|
// (which is tested separately).
|
||||||
|
// We only test for i32 here, because with our current implementation, the implementations
|
||||||
|
// for different types are completely identical and only rely on basic arithmetic
|
||||||
|
// operations
|
||||||
|
let result = &matrix * scalar;
|
||||||
|
prop_assert_eq!(&(scalar * matrix.clone()), &result);
|
||||||
|
prop_assert_eq!(&(scalar * &matrix), &result);
|
||||||
|
prop_assert_eq!(&(&scalar * matrix.clone()), &result);
|
||||||
|
prop_assert_eq!(&(&scalar * &matrix), &result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn scalar_mul_csc((scalar, matrix) in (PROPTEST_I32_VALUE_STRATEGY, csc_strategy())) {
|
||||||
|
// See comments for scalar_mul_csr
|
||||||
|
let result = &matrix * scalar;
|
||||||
|
prop_assert_eq!(&(scalar * matrix.clone()), &result);
|
||||||
|
prop_assert_eq!(&(scalar * &matrix), &result);
|
||||||
|
prop_assert_eq!(&(&scalar * matrix.clone()), &result);
|
||||||
|
prop_assert_eq!(&(&scalar * &matrix), &result);
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user