From 4af3fcbdd3e93cfa93a7d842ad40b8761e830452 Mon Sep 17 00:00:00 2001 From: Andreas Longva Date: Mon, 21 Dec 2020 15:42:32 +0100 Subject: [PATCH] Reorder parameters in ops to intuitive order --- nalgebra-sparse/src/ops/impl_std_ops.rs | 10 +++++----- nalgebra-sparse/src/ops/serial/csr.rs | 20 ++++++++++---------- nalgebra-sparse/tests/unit_tests/ops.rs | 20 ++++++++++---------- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/nalgebra-sparse/src/ops/impl_std_ops.rs b/nalgebra-sparse/src/ops/impl_std_ops.rs index c181464d..92973e34 100644 --- a/nalgebra-sparse/src/ops/impl_std_ops.rs +++ b/nalgebra-sparse/src/ops/impl_std_ops.rs @@ -21,8 +21,8 @@ where // We are giving data that is valid by definition, so it is safe to unwrap below let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values) .unwrap(); - spadd_csr(&mut result, T::zero(), T::one(), Op::NoOp(&self)).unwrap(); - spadd_csr(&mut result, T::one(), T::one(), Op::NoOp(&rhs)).unwrap(); + spadd_csr(T::zero(), &mut result, T::one(), Op::NoOp(&self)).unwrap(); + spadd_csr(T::one(), &mut result, T::one(), Op::NoOp(&rhs)).unwrap(); result } } @@ -35,7 +35,7 @@ where fn add(mut self, rhs: &'a CsrMatrix) -> Self::Output { if Arc::ptr_eq(self.pattern(), rhs.pattern()) { - spadd_csr(&mut self, T::one(), T::one(), Op::NoOp(rhs)).unwrap(); + spadd_csr(T::one(), &mut self, T::one(), Op::NoOp(rhs)).unwrap(); self } else { &self + rhs @@ -90,8 +90,8 @@ impl_matrix_mul!(<'a>(a: &'a CsrMatrix, b: &'a CsrMatrix) -> CsrMatrix let values = vec![T::zero(); pattern.nnz()]; let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values) .unwrap(); - spmm_csr(&mut result, - T::zero(), + spmm_csr(T::zero(), + &mut result, T::one(), Op::NoOp(a), Op::NoOp(b)) diff --git a/nalgebra-sparse/src/ops/serial/csr.rs b/nalgebra-sparse/src/ops/serial/csr.rs index e1b5a1c5..88284114 100644 --- a/nalgebra-sparse/src/ops/serial/csr.rs +++ b/nalgebra-sparse/src/ops/serial/csr.rs @@ -8,8 +8,8 @@ use std::sync::Arc; use std::borrow::Cow; /// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`. -pub fn spmm_csr_dense<'a, T>(c: impl Into>, - beta: T, +pub fn spmm_csr_dense<'a, T>(beta: T, + c: impl Into>, alpha: T, a: Op<&CsrMatrix>, b: Op>>) @@ -17,11 +17,11 @@ pub fn spmm_csr_dense<'a, T>(c: impl Into>, T: Scalar + ClosedAdd + ClosedMul + Zero + One { let b = b.convert(); - spmm_csr_dense_(c.into(), beta, alpha, a, b) + spmm_csr_dense_(beta, c.into(), alpha, a, b) } -fn spmm_csr_dense_(mut c: DMatrixSliceMut, - beta: T, +fn spmm_csr_dense_(beta: T, + mut c: DMatrixSliceMut, alpha: T, a: Op<&CsrMatrix>, b: Op>) @@ -87,8 +87,8 @@ fn spadd_csr_unexpected_entry() -> OperationError { /// /// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is /// returned. -pub fn spadd_csr(c: &mut CsrMatrix, - beta: T, +pub fn spadd_csr(beta: T, + c: &mut CsrMatrix, alpha: T, a: Op<&CsrMatrix>) -> Result<(), OperationError> @@ -161,9 +161,9 @@ fn spmm_csr_unexpected_entry() -> OperationError { } /// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`. -pub fn spmm_csr<'a, T>( - c: &mut CsrMatrix, +pub fn spmm_csr( beta: T, + c: &mut CsrMatrix, alpha: T, a: Op<&CsrMatrix>, b: Op<&CsrMatrix>) @@ -218,7 +218,7 @@ where } }; - spmm_csr(c, beta, alpha, NoOp(a.as_ref()), NoOp(b.as_ref())) + spmm_csr(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref())) } } } diff --git a/nalgebra-sparse/tests/unit_tests/ops.rs b/nalgebra-sparse/tests/unit_tests/ops.rs index 55953491..c341739d 100644 --- a/nalgebra-sparse/tests/unit_tests/ops.rs +++ b/nalgebra-sparse/tests/unit_tests/ops.rs @@ -181,9 +181,9 @@ fn spmm_csr_args_strategy() -> impl Strategy> { }) } -/// Helper function to help us call dense GEMM with our transposition parameters -fn dense_gemm<'a>(c: impl Into>, - beta: i32, +/// Helper function to help us call dense GEMM with our `Op` type +fn dense_gemm<'a>(beta: i32, + c: impl Into>, alpha: i32, a: Op>>, b: Op>>) @@ -209,11 +209,11 @@ proptest! { in spmm_csr_dense_args_strategy() ) { let mut spmm_result = c.clone(); - spmm_csr_dense(&mut spmm_result, beta, alpha, a.as_ref(), b.as_ref()); + spmm_csr_dense(beta, &mut spmm_result, alpha, a.as_ref(), b.as_ref()); let mut gemm_result = c.clone(); let a_dense = a.map_same_op(|a| DMatrix::from(&a)); - dense_gemm(&mut gemm_result, beta, alpha, a_dense.as_ref(), b.as_ref()); + dense_gemm(beta, &mut gemm_result, alpha, a_dense.as_ref(), b.as_ref()); prop_assert_eq!(spmm_result, gemm_result); } @@ -257,7 +257,7 @@ proptest! { let result = catch_unwind(|| { let mut spmm_result = c.clone(); - spmm_csr_dense(&mut spmm_result, beta, alpha, a.as_ref(), b.as_ref()); + spmm_csr_dense(beta, &mut spmm_result, alpha, a.as_ref(), b.as_ref()); }); prop_assert!(result.is_err(), @@ -292,7 +292,7 @@ proptest! { // (here we give in the C matrix, so the sparsity pattern is essentially fixed) let mut c_sparse = c.clone(); - spadd_csr(&mut c_sparse, beta, alpha, a.as_ref()).unwrap(); + spadd_csr(beta, &mut c_sparse, alpha, a.as_ref()).unwrap(); let mut c_dense = DMatrix::from(&c); let op_a_dense = match a { @@ -369,7 +369,7 @@ proptest! { // Test that we get the expected result by comparing to an equivalent dense operation // (here we give in the C matrix, so the sparsity pattern is essentially fixed) let mut c_sparse = c.clone(); - spmm_csr(&mut c_sparse, beta, alpha, a.as_ref(), b.as_ref()).unwrap(); + spmm_csr(beta, &mut c_sparse, alpha, a.as_ref(), b.as_ref()).unwrap(); let mut c_dense = DMatrix::from(&c); let op_a_dense = match a { @@ -424,7 +424,7 @@ proptest! { let result = catch_unwind(|| { let mut spmm_result = c.clone(); - spmm_csr(&mut spmm_result, beta, alpha, a.as_ref(), b.as_ref()).unwrap(); + spmm_csr(beta, &mut spmm_result, alpha, a.as_ref(), b.as_ref()).unwrap(); }); prop_assert!(result.is_err(), @@ -456,7 +456,7 @@ proptest! { let result = catch_unwind(|| { let mut spmm_result = c.clone(); - spadd_csr(&mut spmm_result, beta, alpha, op_a.as_ref()).unwrap(); + spadd_csr(beta, &mut spmm_result, alpha, op_a.as_ref()).unwrap(); }); prop_assert!(result.is_err(),