Reorder parameters in ops to intuitive order

This commit is contained in:
Andreas Longva 2020-12-21 15:42:32 +01:00
parent 061024ab1f
commit 4af3fcbdd3
3 changed files with 25 additions and 25 deletions

View File

@ -21,8 +21,8 @@ where
// We are giving data that is valid by definition, so it is safe to unwrap below // We are giving data that is valid by definition, so it is safe to unwrap below
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values) let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
.unwrap(); .unwrap();
spadd_csr(&mut result, T::zero(), T::one(), Op::NoOp(&self)).unwrap(); spadd_csr(T::zero(), &mut result, T::one(), Op::NoOp(&self)).unwrap();
spadd_csr(&mut result, T::one(), T::one(), Op::NoOp(&rhs)).unwrap(); spadd_csr(T::one(), &mut result, T::one(), Op::NoOp(&rhs)).unwrap();
result result
} }
} }
@ -35,7 +35,7 @@ where
fn add(mut self, rhs: &'a CsrMatrix<T>) -> Self::Output { fn add(mut self, rhs: &'a CsrMatrix<T>) -> Self::Output {
if Arc::ptr_eq(self.pattern(), rhs.pattern()) { if Arc::ptr_eq(self.pattern(), rhs.pattern()) {
spadd_csr(&mut self, T::one(), T::one(), Op::NoOp(rhs)).unwrap(); spadd_csr(T::one(), &mut self, T::one(), Op::NoOp(rhs)).unwrap();
self self
} else { } else {
&self + rhs &self + rhs
@ -90,8 +90,8 @@ impl_matrix_mul!(<'a>(a: &'a CsrMatrix<T>, b: &'a CsrMatrix<T>) -> CsrMatrix<T>
let values = vec![T::zero(); pattern.nnz()]; let values = vec![T::zero(); pattern.nnz()];
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values) let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
.unwrap(); .unwrap();
spmm_csr(&mut result, spmm_csr(T::zero(),
T::zero(), &mut result,
T::one(), T::one(),
Op::NoOp(a), Op::NoOp(a),
Op::NoOp(b)) Op::NoOp(b))

View File

@ -8,8 +8,8 @@ use std::sync::Arc;
use std::borrow::Cow; use std::borrow::Cow;
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`. /// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
pub fn spmm_csr_dense<'a, T>(c: impl Into<DMatrixSliceMut<'a, T>>, pub fn spmm_csr_dense<'a, T>(beta: T,
beta: T, c: impl Into<DMatrixSliceMut<'a, T>>,
alpha: T, alpha: T,
a: Op<&CsrMatrix<T>>, a: Op<&CsrMatrix<T>>,
b: Op<impl Into<DMatrixSlice<'a, T>>>) b: Op<impl Into<DMatrixSlice<'a, T>>>)
@ -17,11 +17,11 @@ pub fn spmm_csr_dense<'a, T>(c: impl Into<DMatrixSliceMut<'a, T>>,
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One
{ {
let b = b.convert(); let b = b.convert();
spmm_csr_dense_(c.into(), beta, alpha, a, b) spmm_csr_dense_(beta, c.into(), alpha, a, b)
} }
fn spmm_csr_dense_<T>(mut c: DMatrixSliceMut<T>, fn spmm_csr_dense_<T>(beta: T,
beta: T, mut c: DMatrixSliceMut<T>,
alpha: T, alpha: T,
a: Op<&CsrMatrix<T>>, a: Op<&CsrMatrix<T>>,
b: Op<DMatrixSlice<T>>) b: Op<DMatrixSlice<T>>)
@ -87,8 +87,8 @@ fn spadd_csr_unexpected_entry() -> OperationError {
/// ///
/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is /// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is
/// returned. /// returned.
pub fn spadd_csr<T>(c: &mut CsrMatrix<T>, pub fn spadd_csr<T>(beta: T,
beta: T, c: &mut CsrMatrix<T>,
alpha: T, alpha: T,
a: Op<&CsrMatrix<T>>) a: Op<&CsrMatrix<T>>)
-> Result<(), OperationError> -> Result<(), OperationError>
@ -161,9 +161,9 @@ fn spmm_csr_unexpected_entry() -> OperationError {
} }
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`. /// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
pub fn spmm_csr<'a, T>( pub fn spmm_csr<T>(
c: &mut CsrMatrix<T>,
beta: T, beta: T,
c: &mut CsrMatrix<T>,
alpha: T, alpha: T,
a: Op<&CsrMatrix<T>>, a: Op<&CsrMatrix<T>>,
b: Op<&CsrMatrix<T>>) b: Op<&CsrMatrix<T>>)
@ -218,7 +218,7 @@ where
} }
}; };
spmm_csr(c, beta, alpha, NoOp(a.as_ref()), NoOp(b.as_ref())) spmm_csr(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref()))
} }
} }
} }

View File

@ -181,9 +181,9 @@ fn spmm_csr_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
}) })
} }
/// Helper function to help us call dense GEMM with our transposition parameters /// Helper function to help us call dense GEMM with our `Op` type
fn dense_gemm<'a>(c: impl Into<DMatrixSliceMut<'a, i32>>, fn dense_gemm<'a>(beta: i32,
beta: i32, c: impl Into<DMatrixSliceMut<'a, i32>>,
alpha: i32, alpha: i32,
a: Op<impl Into<DMatrixSlice<'a, i32>>>, a: Op<impl Into<DMatrixSlice<'a, i32>>>,
b: Op<impl Into<DMatrixSlice<'a, i32>>>) b: Op<impl Into<DMatrixSlice<'a, i32>>>)
@ -209,11 +209,11 @@ proptest! {
in spmm_csr_dense_args_strategy() in spmm_csr_dense_args_strategy()
) { ) {
let mut spmm_result = c.clone(); let mut spmm_result = c.clone();
spmm_csr_dense(&mut spmm_result, beta, alpha, a.as_ref(), b.as_ref()); spmm_csr_dense(beta, &mut spmm_result, alpha, a.as_ref(), b.as_ref());
let mut gemm_result = c.clone(); let mut gemm_result = c.clone();
let a_dense = a.map_same_op(|a| DMatrix::from(&a)); let a_dense = a.map_same_op(|a| DMatrix::from(&a));
dense_gemm(&mut gemm_result, beta, alpha, a_dense.as_ref(), b.as_ref()); dense_gemm(beta, &mut gemm_result, alpha, a_dense.as_ref(), b.as_ref());
prop_assert_eq!(spmm_result, gemm_result); prop_assert_eq!(spmm_result, gemm_result);
} }
@ -257,7 +257,7 @@ proptest! {
let result = catch_unwind(|| { let result = catch_unwind(|| {
let mut spmm_result = c.clone(); let mut spmm_result = c.clone();
spmm_csr_dense(&mut spmm_result, beta, alpha, a.as_ref(), b.as_ref()); spmm_csr_dense(beta, &mut spmm_result, alpha, a.as_ref(), b.as_ref());
}); });
prop_assert!(result.is_err(), prop_assert!(result.is_err(),
@ -292,7 +292,7 @@ proptest! {
// (here we give in the C matrix, so the sparsity pattern is essentially fixed) // (here we give in the C matrix, so the sparsity pattern is essentially fixed)
let mut c_sparse = c.clone(); let mut c_sparse = c.clone();
spadd_csr(&mut c_sparse, beta, alpha, a.as_ref()).unwrap(); spadd_csr(beta, &mut c_sparse, alpha, a.as_ref()).unwrap();
let mut c_dense = DMatrix::from(&c); let mut c_dense = DMatrix::from(&c);
let op_a_dense = match a { let op_a_dense = match a {
@ -369,7 +369,7 @@ proptest! {
// Test that we get the expected result by comparing to an equivalent dense operation // Test that we get the expected result by comparing to an equivalent dense operation
// (here we give in the C matrix, so the sparsity pattern is essentially fixed) // (here we give in the C matrix, so the sparsity pattern is essentially fixed)
let mut c_sparse = c.clone(); let mut c_sparse = c.clone();
spmm_csr(&mut c_sparse, beta, alpha, a.as_ref(), b.as_ref()).unwrap(); spmm_csr(beta, &mut c_sparse, alpha, a.as_ref(), b.as_ref()).unwrap();
let mut c_dense = DMatrix::from(&c); let mut c_dense = DMatrix::from(&c);
let op_a_dense = match a { let op_a_dense = match a {
@ -424,7 +424,7 @@ proptest! {
let result = catch_unwind(|| { let result = catch_unwind(|| {
let mut spmm_result = c.clone(); let mut spmm_result = c.clone();
spmm_csr(&mut spmm_result, beta, alpha, a.as_ref(), b.as_ref()).unwrap(); spmm_csr(beta, &mut spmm_result, alpha, a.as_ref(), b.as_ref()).unwrap();
}); });
prop_assert!(result.is_err(), prop_assert!(result.is_err(),
@ -456,7 +456,7 @@ proptest! {
let result = catch_unwind(|| { let result = catch_unwind(|| {
let mut spmm_result = c.clone(); let mut spmm_result = c.clone();
spadd_csr(&mut spmm_result, beta, alpha, op_a.as_ref()).unwrap(); spadd_csr(beta, &mut spmm_result, alpha, op_a.as_ref()).unwrap();
}); });
prop_assert!(result.is_err(), prop_assert!(result.is_err(),