From fe8592fde1752da91cdabb410c85ade51153ab10 Mon Sep 17 00:00:00 2001 From: Andreas Longva Date: Mon, 21 Dec 2020 15:09:29 +0100 Subject: [PATCH] Refactor ops to use new Op type instead of separate Transpose flag --- nalgebra-sparse/src/ops/impl_std_ops.rs | 14 +- nalgebra-sparse/src/ops/mod.rs | 52 ++++++- nalgebra-sparse/src/ops/serial/csr.rs | 186 ++++++++++++----------- nalgebra-sparse/src/ops/serial/mod.rs | 65 ++++---- nalgebra-sparse/tests/unit_tests/ops.rs | 193 ++++++++++++++---------- 5 files changed, 299 insertions(+), 211 deletions(-) diff --git a/nalgebra-sparse/src/ops/impl_std_ops.rs b/nalgebra-sparse/src/ops/impl_std_ops.rs index 34a7bcf5..c181464d 100644 --- a/nalgebra-sparse/src/ops/impl_std_ops.rs +++ b/nalgebra-sparse/src/ops/impl_std_ops.rs @@ -5,7 +5,7 @@ use crate::ops::serial::{spadd_csr, spadd_pattern, spmm_pattern, spmm_csr}; use nalgebra::{ClosedAdd, ClosedMul, Scalar}; use num_traits::{Zero, One}; use std::sync::Arc; -use crate::ops::Transpose; +use crate::ops::{Op}; impl<'a, T> Add<&'a CsrMatrix> for &'a CsrMatrix where @@ -21,8 +21,8 @@ where // We are giving data that is valid by definition, so it is safe to unwrap below let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values) .unwrap(); - spadd_csr(&mut result, T::zero(), T::one(), Transpose(false), &self).unwrap(); - spadd_csr(&mut result, T::one(), T::one(), Transpose(false), &rhs).unwrap(); + spadd_csr(&mut result, T::zero(), T::one(), Op::NoOp(&self)).unwrap(); + spadd_csr(&mut result, T::one(), T::one(), Op::NoOp(&rhs)).unwrap(); result } } @@ -35,7 +35,7 @@ where fn add(mut self, rhs: &'a CsrMatrix) -> Self::Output { if Arc::ptr_eq(self.pattern(), rhs.pattern()) { - spadd_csr(&mut self, T::one(), T::one(), Transpose(false), &rhs).unwrap(); + spadd_csr(&mut self, T::one(), T::one(), Op::NoOp(rhs)).unwrap(); self } else { &self + rhs @@ -93,10 +93,8 @@ impl_matrix_mul!(<'a>(a: &'a CsrMatrix, b: &'a CsrMatrix) -> CsrMatrix spmm_csr(&mut result, T::zero(), T::one(), - Transpose(false), - a, - Transpose(false), - b) + Op::NoOp(a), + Op::NoOp(b)) .expect("Internal error: spmm failed (please debug)."); result }); diff --git a/nalgebra-sparse/src/ops/mod.rs b/nalgebra-sparse/src/ops/mod.rs index 08939d8a..14a18dc1 100644 --- a/nalgebra-sparse/src/ops/mod.rs +++ b/nalgebra-sparse/src/ops/mod.rs @@ -4,14 +4,54 @@ mod impl_std_ops; pub mod serial; /// TODO -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct Transpose(pub bool); - -impl Transpose { +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Op { /// TODO - pub fn to_bool(&self) -> bool { - self.0 + NoOp(T), + /// TODO + Transpose(T), +} + +impl Op { + /// TODO + pub fn inner_ref(&self) -> &T { + match self { + Op::NoOp(obj) => &obj, + Op::Transpose(obj) => &obj + } + } + + /// TODO + pub fn as_ref(&self) -> Op<&T> { + match self { + Op::NoOp(obj) => Op::NoOp(&obj), + Op::Transpose(obj) => Op::Transpose(&obj) + } + } + + /// TODO + pub fn convert(self) -> Op + where T: Into + { + match self { + Op::NoOp(obj) => Op::NoOp(obj.into()), + Op::Transpose(obj) => Op::Transpose(obj.into()) + } + } + + /// TODO + /// TODO: Rewrite the other functions by leveraging this one + pub fn map_same_op U>(self, f: F) -> Op { + match self { + Op::NoOp(obj) => Op::NoOp(f(obj)), + Op::Transpose(obj) => Op::Transpose(f(obj)) + } } } +impl From for Op { + fn from(obj: T) -> Self { + Self::NoOp(obj) + } +} diff --git a/nalgebra-sparse/src/ops/serial/csr.rs b/nalgebra-sparse/src/ops/serial/csr.rs index 670711cd..e1b5a1c5 100644 --- a/nalgebra-sparse/src/ops/serial/csr.rs +++ b/nalgebra-sparse/src/ops/serial/csr.rs @@ -1,5 +1,5 @@ use crate::csr::CsrMatrix; -use crate::ops::{Transpose}; +use crate::ops::{Op}; use crate::SparseEntryMut; use crate::ops::serial::{OperationError, OperationErrorType}; use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut}; @@ -7,65 +7,71 @@ use num_traits::{Zero, One}; use std::sync::Arc; use std::borrow::Cow; -/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * trans(A) * trans(B)`. +/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`. pub fn spmm_csr_dense<'a, T>(c: impl Into>, beta: T, alpha: T, - trans_a: Transpose, - a: &CsrMatrix, - trans_b: Transpose, - b: impl Into>) + a: Op<&CsrMatrix>, + b: Op>>) where T: Scalar + ClosedAdd + ClosedMul + Zero + One { - spmm_csr_dense_(c.into(), beta, alpha, trans_a, a, trans_b, b.into()) + let b = b.convert(); + spmm_csr_dense_(c.into(), beta, alpha, a, b) } fn spmm_csr_dense_(mut c: DMatrixSliceMut, beta: T, alpha: T, - trans_a: Transpose, - a: &CsrMatrix, - trans_b: Transpose, - b: DMatrixSlice) + a: Op<&CsrMatrix>, + b: Op>) where T: Scalar + ClosedAdd + ClosedMul + Zero + One { - assert_compatible_spmm_dims!(c, a, b, trans_a, trans_b); + assert_compatible_spmm_dims!(c, a, b); - if trans_a.to_bool() { - // In this case, we have to pre-multiply C by beta - c *= beta; + match a { + Op::Transpose(ref a) => { + // In this case, we have to pre-multiply C by beta + c *= beta; - for k in 0..a.nrows() { - let a_row_k = a.row(k); - for (&i, a_ki) in a_row_k.col_indices().iter().zip(a_row_k.values()) { - let gamma_ki = alpha.inlined_clone() * a_ki.inlined_clone(); - let mut c_row_i = c.row_mut(i); - if trans_b.to_bool() { - let b_col_k = b.column(k); - for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) { - *c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone(); - } - } else { - let b_row_k = b.row(k); - for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) { - *c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone(); + for k in 0..a.nrows() { + let a_row_k = a.row(k); + for (&i, a_ki) in a_row_k.col_indices().iter().zip(a_row_k.values()) { + let gamma_ki = alpha.inlined_clone() * a_ki.inlined_clone(); + let mut c_row_i = c.row_mut(i); + match b { + Op::NoOp(ref b) => { + let b_row_k = b.row(k); + for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) { + *c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone(); + } + }, + Op::Transpose(ref b) => { + let b_col_k = b.column(k); + for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) { + *c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone(); + } + }, } } } - } - } else { - for j in 0..c.ncols() { - let mut c_col_j = c.column_mut(j); - for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.row_iter()) { - let mut dot_ij = T::zero(); - for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) { - let b_contrib = - if trans_b.to_bool() { b.index((j, k)) } else { b.index((k, j)) }; - dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone(); + }, + Op::NoOp(ref a) => { + for j in 0..c.ncols() { + let mut c_col_j = c.column_mut(j); + for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.row_iter()) { + let mut dot_ij = T::zero(); + for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) { + let b_contrib = + match b { + Op::NoOp(ref b) => b.index((k, j)), + Op::Transpose(ref b) => b.index((j, k)) + }; + dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone(); + } + *c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij; } - *c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij; } } } @@ -77,32 +83,31 @@ fn spadd_csr_unexpected_entry() -> OperationError { String::from("Found entry in `a` that is not present in `c`.")) } -/// Sparse matrix addition `C <- beta * C + alpha * trans(A)`. +/// Sparse matrix addition `C <- beta * C + alpha * op(A)`. /// /// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is /// returned. pub fn spadd_csr(c: &mut CsrMatrix, beta: T, alpha: T, - trans_a: Transpose, - a: &CsrMatrix) + a: Op<&CsrMatrix>) -> Result<(), OperationError> where T: Scalar + ClosedAdd + ClosedMul + Zero + One { - assert_compatible_spadd_dims!(c, a, trans_a); + assert_compatible_spadd_dims!(c, a); // TODO: Change CsrMatrix::pattern() to return `&Arc` instead of `Arc` - if Arc::ptr_eq(&c.pattern(), &a.pattern()) { + if Arc::ptr_eq(&c.pattern(), &a.inner_ref().pattern()) { // Special fast path: The two matrices have *exactly* the same sparsity pattern, // so we only need to sum the value arrays - for (c_ij, a_ij) in c.values_mut().iter_mut().zip(a.values()) { + for (c_ij, a_ij) in c.values_mut().iter_mut().zip(a.inner_ref().values()) { let (alpha, beta) = (alpha.inlined_clone(), beta.inlined_clone()); *c_ij = beta * c_ij.inlined_clone() + alpha * a_ij.inlined_clone(); } Ok(()) } else { - if trans_a.to_bool() + if let Op::Transpose(a) = a { if beta != T::one() { for c_ij in c.values_mut() { @@ -120,7 +125,7 @@ where } } } - } else { + } else if let Op::NoOp(a) = a { for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) { if beta != T::one() { for c_ij in c_row_i.values_mut() { @@ -160,56 +165,61 @@ pub fn spmm_csr<'a, T>( c: &mut CsrMatrix, beta: T, alpha: T, - trans_a: Transpose, - a: &CsrMatrix, - trans_b: Transpose, - b: &CsrMatrix) + a: Op<&CsrMatrix>, + b: Op<&CsrMatrix>) -> Result<(), OperationError> where T: Scalar + ClosedAdd + ClosedMul + Zero + One { - assert_compatible_spmm_dims!(c, a, b, trans_a, trans_b); + assert_compatible_spmm_dims!(c, a, b); - if !trans_a.to_bool() && !trans_b.to_bool() { - for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) { - for c_ij in c_row_i.values_mut() { - *c_ij = beta.inlined_clone() * c_ij.inlined_clone(); - } + use Op::{NoOp, Transpose}; - for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) { - let b_row_k = b.row(k); - let (mut c_row_i_cols, mut c_row_i_values) = c_row_i.cols_and_values_mut(); - let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone(); - for (j, b_kj) in b_row_k.col_indices().iter().zip(b_row_k.values()) { - // Determine the location in C to append the value - let (c_local_idx, _) = c_row_i_cols.iter() - .enumerate() - .find(|(_, c_col)| *c_col == j) - .ok_or_else(spmm_csr_unexpected_entry)?; + match (&a, &b) { + (NoOp(ref a), NoOp(ref b)) => { + for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) { + for c_ij in c_row_i.values_mut() { + *c_ij = beta.inlined_clone() * c_ij.inlined_clone(); + } - c_row_i_values[c_local_idx] += alpha_aik.inlined_clone() * b_kj.inlined_clone(); - c_row_i_cols = &c_row_i_cols[c_local_idx ..]; - c_row_i_values = &mut c_row_i_values[c_local_idx ..]; + for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) { + let b_row_k = b.row(k); + let (mut c_row_i_cols, mut c_row_i_values) = c_row_i.cols_and_values_mut(); + let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone(); + for (j, b_kj) in b_row_k.col_indices().iter().zip(b_row_k.values()) { + // Determine the location in C to append the value + let (c_local_idx, _) = c_row_i_cols.iter() + .enumerate() + .find(|(_, c_col)| *c_col == j) + .ok_or_else(spmm_csr_unexpected_entry)?; + + c_row_i_values[c_local_idx] += alpha_aik.inlined_clone() * b_kj.inlined_clone(); + c_row_i_cols = &c_row_i_cols[c_local_idx ..]; + c_row_i_values = &mut c_row_i_values[c_local_idx ..]; + } } } - } - Ok(()) - } else { - // Currently we handle transposition by explicitly precomputing transposed matrices - // and calling the operation again without transposition - // TODO: At least use workspaces to allow control of allocations. Maybe - // consider implementing certain patterns (like A^T * B) explicitly - let (a, b) = { - use Cow::*; - match (trans_a, trans_b) { - (Transpose(false), Transpose(false)) => unreachable!(), - (Transpose(true), Transpose(false)) => (Owned(a.transpose()), Borrowed(b)), - (Transpose(false), Transpose(true)) => (Borrowed(a), Owned(b.transpose())), - (Transpose(true), Transpose(true)) => (Owned(a.transpose()), Owned(b.transpose())) - } - }; + Ok(()) + }, + _ => { + // Currently we handle transposition by explicitly precomputing transposed matrices + // and calling the operation again without transposition + // TODO: At least use workspaces to allow control of allocations. Maybe + // consider implementing certain patterns (like A^T * B) explicitly + let a_ref: &CsrMatrix = a.inner_ref(); + let b_ref: &CsrMatrix = b.inner_ref(); + let (a, b) = { + use Cow::*; + match (&a, &b) { + (NoOp(_), NoOp(_)) => unreachable!(), + (Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)), + (NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())), + (Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose())) + } + }; - spmm_csr(c, beta, alpha, Transpose(false), a.as_ref(), Transpose(false), b.as_ref()) + spmm_csr(c, beta, alpha, NoOp(a.as_ref()), NoOp(b.as_ref())) + } } } diff --git a/nalgebra-sparse/src/ops/serial/mod.rs b/nalgebra-sparse/src/ops/serial/mod.rs index 8ac22ac8..a58ba9a3 100644 --- a/nalgebra-sparse/src/ops/serial/mod.rs +++ b/nalgebra-sparse/src/ops/serial/mod.rs @@ -2,46 +2,47 @@ #[macro_use] macro_rules! assert_compatible_spmm_dims { - ($c:expr, $a:expr, $b:expr, $trans_a:expr, $trans_b:expr) => { - use crate::ops::Transpose; - match ($trans_a, $trans_b) { - (Transpose(false), Transpose(false)) => { - assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()"); - assert_eq!($c.ncols(), $b.ncols(), "C.ncols() != B.ncols()"); - assert_eq!($a.ncols(), $b.nrows(), "A.ncols() != B.nrows()"); - }, - (Transpose(true), Transpose(false)) => { - assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()"); - assert_eq!($c.ncols(), $b.ncols(), "C.ncols() != B.ncols()"); - assert_eq!($a.nrows(), $b.nrows(), "A.nrows() != B.nrows()"); - }, - (Transpose(false), Transpose(true)) => { - assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()"); - assert_eq!($c.ncols(), $b.nrows(), "C.ncols() != B.nrows()"); - assert_eq!($a.ncols(), $b.ncols(), "A.ncols() != B.ncols()"); - }, - (Transpose(true), Transpose(true)) => { - assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()"); - assert_eq!($c.ncols(), $b.nrows(), "C.ncols() != B.nrows()"); - assert_eq!($a.nrows(), $b.ncols(), "A.nrows() != B.ncols()"); + ($c:expr, $a:expr, $b:expr) => { + { + use crate::ops::Op::{NoOp, Transpose}; + match (&$a, &$b) { + (NoOp(ref a), NoOp(ref b)) => { + assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()"); + assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()"); + assert_eq!(a.ncols(), b.nrows(), "A.ncols() != B.nrows()"); + }, + (Transpose(ref a), NoOp(ref b)) => { + assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()"); + assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()"); + assert_eq!(a.nrows(), b.nrows(), "A.nrows() != B.nrows()"); + }, + (NoOp(ref a), Transpose(ref b)) => { + assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()"); + assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()"); + assert_eq!(a.ncols(), b.ncols(), "A.ncols() != B.ncols()"); + }, + (Transpose(ref a), Transpose(ref b)) => { + assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()"); + assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()"); + assert_eq!(a.nrows(), b.ncols(), "A.nrows() != B.ncols()"); + } } } - } } #[macro_use] macro_rules! assert_compatible_spadd_dims { - ($c:expr, $a:expr, $trans_a:expr) => { - use crate::ops::Transpose; - match $trans_a { - Transpose(false) => { - assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()"); - assert_eq!($c.ncols(), $a.ncols(), "C.ncols() != A.ncols()"); + ($c:expr, $a:expr) => { + use crate::ops::Op; + match $a { + Op::NoOp(a) => { + assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()"); + assert_eq!($c.ncols(), a.ncols(), "C.ncols() != A.ncols()"); }, - Transpose(true) => { - assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()"); - assert_eq!($c.ncols(), $a.nrows(), "C.ncols() != A.nrows()"); + Op::Transpose(a) => { + assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()"); + assert_eq!($c.ncols(), a.nrows(), "C.ncols() != A.nrows()"); } } diff --git a/nalgebra-sparse/tests/unit_tests/ops.rs b/nalgebra-sparse/tests/unit_tests/ops.rs index 16482171..55953491 100644 --- a/nalgebra-sparse/tests/unit_tests/ops.rs +++ b/nalgebra-sparse/tests/unit_tests/ops.rs @@ -1,7 +1,7 @@ use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ, PROPTEST_I32_VALUE_STRATEGY}; use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_pattern, spmm_pattern, spadd_csr, spmm_csr}; -use nalgebra_sparse::ops::{Transpose}; +use nalgebra_sparse::ops::{Op}; use nalgebra_sparse::csr::CsrMatrix; use nalgebra_sparse::proptest::{csr, sparsity_pattern}; use nalgebra_sparse::pattern::SparsityPattern; @@ -28,10 +28,8 @@ struct SpmmCsrDenseArgs { c: DMatrix, beta: T, alpha: T, - trans_a: Transpose, - a: CsrMatrix, - trans_b: Transpose, - b: DMatrix, + a: Op>, + b: Op>, } /// Returns matrices C, A and B with compatible dimensions such that it can be used @@ -48,10 +46,10 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy> (c_matrix_strategy, common_dim, trans_strategy.clone(), trans_strategy.clone()) .prop_flat_map(move |(c, common_dim, trans_a, trans_b)| { let a_shape = - if trans_a.to_bool() { (common_dim, c.nrows()) } + if trans_a { (common_dim, c.nrows()) } else { (c.nrows(), common_dim) }; let b_shape = - if trans_b.to_bool() { (c.ncols(), common_dim) } + if trans_b { (c.ncols(), common_dim) } else { (common_dim, c.ncols()) }; let a = csr(value_strategy.clone(), Just(a_shape.0), Just(a_shape.1), max_nnz); let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1); @@ -66,10 +64,8 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy> c, beta, alpha, - trans_a, - a, - trans_b, - b, + a: if trans_a { Op::Transpose(a) } else { Op::NoOp(a) }, + b: if trans_b { Op::Transpose(b) } else { Op::NoOp(b) }, } }) } @@ -79,14 +75,13 @@ struct SpaddCsrArgs { c: CsrMatrix, beta: T, alpha: T, - trans_a: Transpose, - a: CsrMatrix, + a: Op>, } fn spadd_csr_args_strategy() -> impl Strategy> { let value_strategy = PROPTEST_I32_VALUE_STRATEGY; - spadd_build_pattern_strategy() + spadd_pattern_strategy() .prop_flat_map(move |(a_pattern, b_pattern)| { let c_pattern = spadd_pattern(&a_pattern, &b_pattern); @@ -99,8 +94,8 @@ fn spadd_csr_args_strategy() -> impl Strategy> { let c = CsrMatrix::try_from_pattern_and_values(Arc::new(c_pattern), c_values).unwrap(); let a = CsrMatrix::try_from_pattern_and_values(Arc::new(a_pattern), a_values).unwrap(); - let a = if trans_a.to_bool() { a.transpose() } else { a }; - SpaddCsrArgs { c, beta, alpha, trans_a, a } + let a = if trans_a { Op::Transpose(a.transpose()) } else { Op::NoOp(a) }; + SpaddCsrArgs { c, beta, alpha, a } }) } @@ -108,8 +103,20 @@ fn dense_strategy() -> impl Strategy> { matrix(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM) } -fn trans_strategy() -> impl Strategy + Clone { - proptest::bool::ANY.prop_map(Transpose) +fn trans_strategy() -> impl Strategy + Clone { + proptest::bool::ANY +} + +/// Wraps the values of the given strategy in `Op`, producing both transposed and non-transposed +/// values. +fn op_strategy(strategy: S) -> impl Strategy> { + let is_transposed = proptest::bool::ANY; + (strategy, is_transposed) + .prop_map(|(obj, is_trans)| if is_trans { + Op::Transpose(obj) + } else { + Op::NoOp(obj) + }) } fn pattern_strategy() -> impl Strategy { @@ -117,7 +124,7 @@ fn pattern_strategy() -> impl Strategy { } /// Constructs pairs (a, b) where a and b have the same dimensions -fn spadd_build_pattern_strategy() -> impl Strategy { +fn spadd_pattern_strategy() -> impl Strategy { pattern_strategy() .prop_flat_map(|a| { let b = sparsity_pattern(Just(a.major_dim()), Just(a.minor_dim()), PROPTEST_MAX_NNZ); @@ -139,10 +146,8 @@ struct SpmmCsrArgs { c: CsrMatrix, beta: T, alpha: T, - trans_a: Transpose, - a: CsrMatrix, - trans_b: Transpose, - b: CsrMatrix + a: Op>, + b: Op>, } fn spmm_csr_args_strategy() -> impl Strategy> { @@ -170,10 +175,8 @@ fn spmm_csr_args_strategy() -> impl Strategy> { c, beta, alpha, - trans_a, - a: if trans_a.to_bool() { a.transpose() } else { a }, - trans_b, - b: if trans_b.to_bool() { b.transpose() } else { b } + a: if trans_a { Op::Transpose(a.transpose()) } else { Op::NoOp(a) }, + b: if trans_b { Op::Transpose(b.transpose()) } else { Op::NoOp(b) } } }) } @@ -182,52 +185,67 @@ fn spmm_csr_args_strategy() -> impl Strategy> { fn dense_gemm<'a>(c: impl Into>, beta: i32, alpha: i32, - trans_a: Transpose, - a: impl Into>, - trans_b: Transpose, - b: impl Into>) + a: Op>>, + b: Op>>) { let mut c = c.into(); - let a = a.into(); - let b = b.into(); + let a = a.convert(); + let b = b.convert(); - match (trans_a, trans_b) { - (Transpose(false), Transpose(false)) => c.gemm(alpha, &a, &b, beta), - (Transpose(true), Transpose(false)) => c.gemm(alpha, &a.transpose(), &b, beta), - (Transpose(false), Transpose(true)) => c.gemm(alpha, &a, &b.transpose(), beta), - (Transpose(true), Transpose(true)) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta) - }; + use Op::{NoOp, Transpose}; + match (a, b) { + (NoOp(a), NoOp(b)) => c.gemm(alpha, &a, &b, beta), + (Transpose(a), NoOp(b)) => c.gemm(alpha, &a.transpose(), &b, beta), + (NoOp(a), Transpose(b)) => c.gemm(alpha, &a, &b.transpose(), beta), + (Transpose(a), Transpose(b)) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta) + } } proptest! { #[test] fn spmm_csr_dense_agrees_with_dense_result( - SpmmCsrDenseArgs { c, beta, alpha, trans_a, a, trans_b, b } + SpmmCsrDenseArgs { c, beta, alpha, a, b } in spmm_csr_dense_args_strategy() ) { let mut spmm_result = c.clone(); - spmm_csr_dense(&mut spmm_result, beta, alpha, trans_a, &a, trans_b, &b); + spmm_csr_dense(&mut spmm_result, beta, alpha, a.as_ref(), b.as_ref()); let mut gemm_result = c.clone(); - dense_gemm(&mut gemm_result, beta, alpha, trans_a, &DMatrix::from(&a), trans_b, &b); + let a_dense = a.map_same_op(|a| DMatrix::from(&a)); + dense_gemm(&mut gemm_result, beta, alpha, a_dense.as_ref(), b.as_ref()); prop_assert_eq!(spmm_result, gemm_result); } #[test] fn spmm_csr_dense_panics_on_dim_mismatch( - (alpha, beta, c, a, b, trans_a, trans_b) - in (-5 ..= 5, -5 ..= 5, dense_strategy(), csr_strategy(), - dense_strategy(), trans_strategy(), trans_strategy()) + (alpha, beta, c, a, b) + in (PROPTEST_I32_VALUE_STRATEGY, + PROPTEST_I32_VALUE_STRATEGY, + dense_strategy(), + op_strategy(csr_strategy()), + op_strategy(dense_strategy())) ) { // We refer to `A * B` as the "product" - let product_rows = if trans_a.to_bool() { a.ncols() } else { a.nrows() }; - let product_cols = if trans_b.to_bool() { b.nrows() } else { b.ncols() }; + let product_rows = match &a { + Op::NoOp(ref a) => a.nrows(), + Op::Transpose(ref a) => a.ncols(), + }; + let product_cols = match &b { + Op::NoOp(ref b) => b.ncols(), + Op::Transpose(ref b) => b.nrows(), + }; // Determine the common dimension in the product // from the perspective of a and b, respectively - let product_a_common = if trans_a.to_bool() { a.nrows() } else { a.ncols() }; - let product_b_common = if trans_b.to_bool() { b.ncols() } else { b.nrows() }; + let product_a_common = match &a { + Op::NoOp(ref a) => a.ncols(), + Op::Transpose(ref a) => a.nrows(), + }; + let product_b_common = match &b { + Op::NoOp(ref b) => b.nrows(), + Op::Transpose(ref b) => b.ncols() + }; let dims_are_compatible = product_rows == c.nrows() && product_cols == c.ncols() @@ -239,7 +257,7 @@ proptest! { let result = catch_unwind(|| { let mut spmm_result = c.clone(); - spmm_csr_dense(&mut spmm_result, beta, alpha, trans_a, &a, trans_b, &b); + spmm_csr_dense(&mut spmm_result, beta, alpha, a.as_ref(), b.as_ref()); }); prop_assert!(result.is_err(), @@ -247,7 +265,7 @@ proptest! { } #[test] - fn spadd_pattern_test((a, b) in spadd_build_pattern_strategy()) + fn spadd_pattern_test((a, b) in spadd_pattern_strategy()) { // (a, b) are dimensionally compatible patterns let pattern_result = spadd_pattern(&a, &b); @@ -269,16 +287,18 @@ proptest! { } #[test] - fn spadd_csr_test(SpaddCsrArgs { c, beta, alpha, trans_a, a } in spadd_csr_args_strategy()) { + fn spadd_csr_test(SpaddCsrArgs { c, beta, alpha, a } in spadd_csr_args_strategy()) { // Test that we get the expected result by comparing to an equivalent dense operation // (here we give in the C matrix, so the sparsity pattern is essentially fixed) let mut c_sparse = c.clone(); - spadd_csr(&mut c_sparse, beta, alpha, trans_a, &a).unwrap(); + spadd_csr(&mut c_sparse, beta, alpha, a.as_ref()).unwrap(); let mut c_dense = DMatrix::from(&c); - let op_a_dense = DMatrix::from(&a); - let op_a_dense = if trans_a.to_bool() { op_a_dense.transpose() } else { op_a_dense }; + let op_a_dense = match a { + Op::NoOp(a) => DMatrix::from(&a), + Op::Transpose(a) => DMatrix::from(&a).transpose(), + }; c_dense = beta * c_dense + alpha * &op_a_dense; prop_assert_eq!(&DMatrix::from(&c_sparse), &c_dense); @@ -343,19 +363,23 @@ proptest! { } #[test] - fn spmm_csr_test(SpmmCsrArgs { c, beta, alpha, trans_a, a, trans_b, b } + fn spmm_csr_test(SpmmCsrArgs { c, beta, alpha, a, b } in spmm_csr_args_strategy() ) { // Test that we get the expected result by comparing to an equivalent dense operation // (here we give in the C matrix, so the sparsity pattern is essentially fixed) let mut c_sparse = c.clone(); - spmm_csr(&mut c_sparse, beta, alpha, trans_a, &a, trans_b, &b).unwrap(); + spmm_csr(&mut c_sparse, beta, alpha, a.as_ref(), b.as_ref()).unwrap(); let mut c_dense = DMatrix::from(&c); - let op_a_dense = DMatrix::from(&a); - let op_a_dense = if trans_a.to_bool() { op_a_dense.transpose() } else { op_a_dense }; - let op_b_dense = DMatrix::from(&b); - let op_b_dense = if trans_b.to_bool() { op_b_dense.transpose() } else { op_b_dense }; + let op_a_dense = match a { + Op::NoOp(ref a) => DMatrix::from(a), + Op::Transpose(ref a) => DMatrix::from(a).transpose(), + }; + let op_b_dense = match b { + Op::NoOp(ref b) => DMatrix::from(b), + Op::Transpose(ref b) => DMatrix::from(b).transpose(), + }; c_dense = beta * c_dense + alpha * &op_a_dense * op_b_dense; prop_assert_eq!(&DMatrix::from(&c_sparse), &c_dense); @@ -363,22 +387,32 @@ proptest! { #[test] fn spmm_csr_panics_on_dim_mismatch( - (alpha, beta, c, a, b, trans_a, trans_b) + (alpha, beta, c, a, b) in (PROPTEST_I32_VALUE_STRATEGY, PROPTEST_I32_VALUE_STRATEGY, csr_strategy(), - csr_strategy(), - csr_strategy(), - trans_strategy(), - trans_strategy()) + op_strategy(csr_strategy()), + op_strategy(csr_strategy())) ) { // We refer to `A * B` as the "product" - let product_rows = if trans_a.to_bool() { a.ncols() } else { a.nrows() }; - let product_cols = if trans_b.to_bool() { b.nrows() } else { b.ncols() }; + let product_rows = match &a { + Op::NoOp(ref a) => a.nrows(), + Op::Transpose(ref a) => a.ncols(), + }; + let product_cols = match &b { + Op::NoOp(ref b) => b.ncols(), + Op::Transpose(ref b) => b.nrows(), + }; // Determine the common dimension in the product // from the perspective of a and b, respectively - let product_a_common = if trans_a.to_bool() { a.nrows() } else { a.ncols() }; - let product_b_common = if trans_b.to_bool() { b.ncols() } else { b.nrows() }; + let product_a_common = match &a { + Op::NoOp(ref a) => a.ncols(), + Op::Transpose(ref a) => a.nrows(), + }; + let product_b_common = match &b { + Op::NoOp(ref b) => b.nrows(), + Op::Transpose(ref b) => b.ncols(), + }; let dims_are_compatible = product_rows == c.nrows() && product_cols == c.ncols() @@ -390,7 +424,7 @@ proptest! { let result = catch_unwind(|| { let mut spmm_result = c.clone(); - spmm_csr(&mut spmm_result, beta, alpha, trans_a, &a, trans_b, &b).unwrap(); + spmm_csr(&mut spmm_result, beta, alpha, a.as_ref(), b.as_ref()).unwrap(); }); prop_assert!(result.is_err(), @@ -399,15 +433,20 @@ proptest! { #[test] fn spadd_csr_panics_on_dim_mismatch( - (alpha, beta, c, a, trans_a) + (alpha, beta, c, op_a) in (PROPTEST_I32_VALUE_STRATEGY, PROPTEST_I32_VALUE_STRATEGY, csr_strategy(), - csr_strategy(), - trans_strategy()) + op_strategy(csr_strategy())) ) { - let op_a_rows = if trans_a.to_bool() { a.ncols() } else { a.nrows() }; - let op_a_cols = if trans_a.to_bool() { a.nrows() } else { a.ncols() }; + let op_a_rows = match &op_a { + &Op::NoOp(ref a) => a.nrows(), + &Op::Transpose(ref a) => a.ncols() + }; + let op_a_cols = match &op_a { + &Op::NoOp(ref a) => a.ncols(), + &Op::Transpose(ref a) => a.nrows() + }; let dims_are_compatible = c.nrows() == op_a_rows && c.ncols() == op_a_cols; @@ -417,7 +456,7 @@ proptest! { let result = catch_unwind(|| { let mut spmm_result = c.clone(); - spadd_csr(&mut spmm_result, beta, alpha, trans_a, &a).unwrap(); + spadd_csr(&mut spmm_result, beta, alpha, op_a.as_ref()).unwrap(); }); prop_assert!(result.is_err(),