diff --git a/nalgebra-sparse/src/ops/mod.rs b/nalgebra-sparse/src/ops/mod.rs index eea08495..bf1698ec 100644 --- a/nalgebra-sparse/src/ops/mod.rs +++ b/nalgebra-sparse/src/ops/mod.rs @@ -4,31 +4,11 @@ pub mod serial; /// TODO #[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum Transposition { - /// TODO - Transpose, - /// TODO - NoTranspose, -} +pub struct Transpose(pub bool); -impl Transposition { +impl Transpose { /// TODO - pub fn is_transpose(&self) -> bool { - self == &Self::Transpose + pub fn to_bool(&self) -> bool { + self.0 } - - /// TODO - pub fn from_bool(transpose: bool) -> Self { - if transpose { Self::Transpose } else { Self::NoTranspose } - } -} - -/// TODO -pub fn transpose() -> Transposition { - Transposition::Transpose -} - -/// TODO -pub fn no_transpose() -> Transposition { - Transposition::NoTranspose } \ No newline at end of file diff --git a/nalgebra-sparse/src/ops/serial/csr.rs b/nalgebra-sparse/src/ops/serial/csr.rs index fa7ff30f..3b28ac74 100644 --- a/nalgebra-sparse/src/ops/serial/csr.rs +++ b/nalgebra-sparse/src/ops/serial/csr.rs @@ -1,15 +1,15 @@ use crate::csr::CsrMatrix; -use crate::ops::Transposition; -use nalgebra::{DVectorSlice, Scalar, DMatrixSlice, DVectorSliceMut, ClosedAdd, ClosedMul, DMatrixSliceMut}; +use crate::ops::{Transpose}; +use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut}; use num_traits::{Zero, One}; /// Sparse-dense matrix-matrix multiplication `C = beta * C + alpha * trans(A) * trans(B)`. pub fn spmm_csr_dense<'a, T>(c: impl Into>, beta: T, alpha: T, - trans_a: Transposition, + trans_a: Transpose, a: &CsrMatrix, - trans_b: Transposition, + trans_b: Transpose, b: impl Into>) where T: Scalar + ClosedAdd + ClosedMul + Zero + One @@ -20,16 +20,16 @@ pub fn spmm_csr_dense<'a, T>(c: impl Into>, fn spmm_csr_dense_(mut c: DMatrixSliceMut, beta: T, alpha: T, - trans_a: Transposition, + trans_a: Transpose, a: &CsrMatrix, - trans_b: Transposition, + trans_b: Transpose, b: DMatrixSlice) where T: Scalar + ClosedAdd + ClosedMul + Zero + One { assert_compatible_spmm_dims!(c, a, b, trans_a, trans_b); - if trans_a.is_transpose() { + if trans_a.to_bool() { // In this case, we have to pre-multiply C by beta c *= beta; @@ -38,7 +38,7 @@ where for (&i, a_ki) in a_row_k.col_indices().iter().zip(a_row_k.values()) { let gamma_ki = alpha.inlined_clone() * a_ki.inlined_clone(); let mut c_row_i = c.row_mut(i); - if trans_b.is_transpose() { + if trans_b.to_bool() { let b_col_k = b.column(k); for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) { *c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone(); @@ -58,7 +58,7 @@ where let mut dot_ij = T::zero(); for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) { let b_contrib = - if trans_b.is_transpose() { b.index((j, k)) } else { b.index((k, j)) }; + if trans_b.to_bool() { b.index((j, k)) } else { b.index((k, j)) }; dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone(); } *c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij; diff --git a/nalgebra-sparse/src/ops/serial/mod.rs b/nalgebra-sparse/src/ops/serial/mod.rs index 02e15210..a7615ec4 100644 --- a/nalgebra-sparse/src/ops/serial/mod.rs +++ b/nalgebra-sparse/src/ops/serial/mod.rs @@ -3,24 +3,24 @@ #[macro_use] macro_rules! assert_compatible_spmm_dims { ($c:expr, $a:expr, $b:expr, $trans_a:expr, $trans_b:expr) => { - use crate::ops::Transposition::{Transpose, NoTranspose}; + use crate::ops::Transpose; match ($trans_a, $trans_b) { - (NoTranspose, NoTranspose) => { + (Transpose(false), Transpose(false)) => { assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()"); assert_eq!($c.ncols(), $b.ncols(), "C.ncols() != B.ncols()"); assert_eq!($a.ncols(), $b.nrows(), "A.ncols() != B.nrows()"); }, - (Transpose, NoTranspose) => { + (Transpose(true), Transpose(false)) => { assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()"); assert_eq!($c.ncols(), $b.ncols(), "C.ncols() != B.ncols()"); assert_eq!($a.nrows(), $b.nrows(), "A.nrows() != B.nrows()"); }, - (NoTranspose, Transpose) => { + (Transpose(false), Transpose(true)) => { assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()"); assert_eq!($c.ncols(), $b.nrows(), "C.ncols() != B.nrows()"); assert_eq!($a.ncols(), $b.ncols(), "A.ncols() != B.ncols()"); }, - (Transpose, Transpose) => { + (Transpose(true), Transpose(true)) => { assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()"); assert_eq!($c.ncols(), $b.nrows(), "C.ncols() != B.nrows()"); assert_eq!($a.nrows(), $b.ncols(), "A.nrows() != B.ncols()"); diff --git a/nalgebra-sparse/tests/unit_tests/ops.rs b/nalgebra-sparse/tests/unit_tests/ops.rs index 19add876..add03a98 100644 --- a/nalgebra-sparse/tests/unit_tests/ops.rs +++ b/nalgebra-sparse/tests/unit_tests/ops.rs @@ -1,6 +1,6 @@ use nalgebra_sparse::coo::CooMatrix; use nalgebra_sparse::ops::serial::{spmv_coo, spmm_csr_dense}; -use nalgebra_sparse::ops::{no_transpose, Transposition}; +use nalgebra_sparse::ops::{Transpose}; use nalgebra_sparse::csr::CsrMatrix; use nalgebra_sparse::proptest::csr; @@ -41,9 +41,9 @@ struct SpmmCsrDenseArgs { c: DMatrix, beta: T, alpha: T, - trans_a: Transposition, + trans_a: Transpose, a: CsrMatrix, - trans_b: Transposition, + trans_b: Transpose, b: DMatrix, } @@ -61,10 +61,10 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy> (c_matrix_strategy, common_dim, trans_strategy.clone(), trans_strategy.clone()) .prop_flat_map(move |(c, common_dim, trans_a, trans_b)| { let a_shape = - if trans_a.is_transpose() { (common_dim, c.nrows()) } + if trans_a.to_bool() { (common_dim, c.nrows()) } else { (c.nrows(), common_dim) }; let b_shape = - if trans_b.is_transpose() { (c.ncols(), common_dim) } + if trans_b.to_bool() { (c.ncols(), common_dim) } else { (common_dim, c.ncols()) }; let a = csr(value_strategy.clone(), Just(a_shape.0), Just(a_shape.1), max_nnz); let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1); @@ -95,29 +95,28 @@ fn dense_strategy() -> impl Strategy> { matrix(-5 ..= 5, 0 ..= 6, 0 ..= 6) } -fn trans_strategy() -> impl Strategy + Clone { - proptest::bool::ANY.prop_map(Transposition::from_bool) +fn trans_strategy() -> impl Strategy + Clone { + proptest::bool::ANY.prop_map(Transpose) } /// Helper function to help us call dense GEMM with our transposition parameters fn dense_gemm<'a>(c: impl Into>, beta: i32, alpha: i32, - trans_a: Transposition, + trans_a: Transpose, a: impl Into>, - trans_b: Transposition, + trans_b: Transpose, b: impl Into>) { let mut c = c.into(); let a = a.into(); let b = b.into(); - use Transposition::{Transpose, NoTranspose}; match (trans_a, trans_b) { - (NoTranspose, NoTranspose) => c.gemm(alpha, &a, &b, beta), - (Transpose, NoTranspose) => c.gemm(alpha, &a.transpose(), &b, beta), - (NoTranspose, Transpose) => c.gemm(alpha, &a, &b.transpose(), beta), - (Transpose, Transpose) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta) + (Transpose(false), Transpose(false)) => c.gemm(alpha, &a, &b, beta), + (Transpose(true), Transpose(false)) => c.gemm(alpha, &a.transpose(), &b, beta), + (Transpose(false), Transpose(true)) => c.gemm(alpha, &a, &b.transpose(), beta), + (Transpose(true), Transpose(true)) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta) }; } @@ -144,12 +143,12 @@ proptest! { dense_strategy(), trans_strategy(), trans_strategy()) ) { // We refer to `A * B` as the "product" - let product_rows = if trans_a.is_transpose() { a.ncols() } else { a.nrows() }; - let product_cols = if trans_b.is_transpose() { b.nrows() } else { b.ncols() }; + let product_rows = if trans_a.to_bool() { a.ncols() } else { a.nrows() }; + let product_cols = if trans_b.to_bool() { b.nrows() } else { b.ncols() }; // Determine the common dimension in the product // from the perspective of a and b, respectively - let product_a_common = if trans_a.is_transpose() { a.nrows() } else { a.ncols() }; - let product_b_common = if trans_b.is_transpose() { b.ncols() } else { b.nrows() }; + let product_a_common = if trans_a.to_bool() { a.nrows() } else { a.ncols() }; + let product_b_common = if trans_b.to_bool() { b.ncols() } else { b.nrows() }; let dims_are_compatible = product_rows == c.nrows() && product_cols == c.ncols()