From 1ae03d9ebbef551f414db250f767e9affef3c5bd Mon Sep 17 00:00:00 2001 From: Andreas Longva Date: Wed, 2 Dec 2020 16:56:22 +0100 Subject: [PATCH] Implement spmm_csr_dense --- nalgebra-sparse/src/ops/mod.rs | 34 ++++ .../src/{ops.rs => ops/serial/coo.rs} | 2 + nalgebra-sparse/src/ops/serial/csr.rs | 68 ++++++++ nalgebra-sparse/src/ops/serial/mod.rs | 37 +++++ nalgebra-sparse/tests/unit_tests/ops.rs | 147 +++++++++++++++++- 5 files changed, 286 insertions(+), 2 deletions(-) create mode 100644 nalgebra-sparse/src/ops/mod.rs rename nalgebra-sparse/src/{ops.rs => ops/serial/coo.rs} (98%) create mode 100644 nalgebra-sparse/src/ops/serial/csr.rs create mode 100644 nalgebra-sparse/src/ops/serial/mod.rs diff --git a/nalgebra-sparse/src/ops/mod.rs b/nalgebra-sparse/src/ops/mod.rs new file mode 100644 index 00000000..eea08495 --- /dev/null +++ b/nalgebra-sparse/src/ops/mod.rs @@ -0,0 +1,34 @@ +//! TODO + +pub mod serial; + +/// TODO +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum Transposition { + /// TODO + Transpose, + /// TODO + NoTranspose, +} + +impl Transposition { + /// TODO + pub fn is_transpose(&self) -> bool { + self == &Self::Transpose + } + + /// TODO + pub fn from_bool(transpose: bool) -> Self { + if transpose { Self::Transpose } else { Self::NoTranspose } + } +} + +/// TODO +pub fn transpose() -> Transposition { + Transposition::Transpose +} + +/// TODO +pub fn no_transpose() -> Transposition { + Transposition::NoTranspose +} \ No newline at end of file diff --git a/nalgebra-sparse/src/ops.rs b/nalgebra-sparse/src/ops/serial/coo.rs similarity index 98% rename from nalgebra-sparse/src/ops.rs rename to nalgebra-sparse/src/ops/serial/coo.rs index bc3b9399..322c6914 100644 --- a/nalgebra-sparse/src/ops.rs +++ b/nalgebra-sparse/src/ops/serial/coo.rs @@ -12,6 +12,8 @@ use num_traits::{One, Zero}; /// /// If `beta == 0`, the elements in `y` are never read. /// +/// TODO: Rethink this function +/// /// Panics /// ------ /// diff --git a/nalgebra-sparse/src/ops/serial/csr.rs b/nalgebra-sparse/src/ops/serial/csr.rs new file mode 100644 index 00000000..fa7ff30f --- /dev/null +++ b/nalgebra-sparse/src/ops/serial/csr.rs @@ -0,0 +1,68 @@ +use crate::csr::CsrMatrix; +use crate::ops::Transposition; +use nalgebra::{DVectorSlice, Scalar, DMatrixSlice, DVectorSliceMut, ClosedAdd, ClosedMul, DMatrixSliceMut}; +use num_traits::{Zero, One}; + +/// Sparse-dense matrix-matrix multiplication `C = beta * C + alpha * trans(A) * trans(B)`. +pub fn spmm_csr_dense<'a, T>(c: impl Into>, + beta: T, + alpha: T, + trans_a: Transposition, + a: &CsrMatrix, + trans_b: Transposition, + b: impl Into>) + where + T: Scalar + ClosedAdd + ClosedMul + Zero + One +{ + spmm_csr_dense_(c.into(), beta, alpha, trans_a, a, trans_b, b.into()) +} + +fn spmm_csr_dense_(mut c: DMatrixSliceMut, + beta: T, + alpha: T, + trans_a: Transposition, + a: &CsrMatrix, + trans_b: Transposition, + b: DMatrixSlice) +where + T: Scalar + ClosedAdd + ClosedMul + Zero + One +{ + assert_compatible_spmm_dims!(c, a, b, trans_a, trans_b); + + if trans_a.is_transpose() { + // In this case, we have to pre-multiply C by beta + c *= beta; + + for k in 0..a.nrows() { + let a_row_k = a.row(k); + for (&i, a_ki) in a_row_k.col_indices().iter().zip(a_row_k.values()) { + let gamma_ki = alpha.inlined_clone() * a_ki.inlined_clone(); + let mut c_row_i = c.row_mut(i); + if trans_b.is_transpose() { + let b_col_k = b.column(k); + for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) { + *c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone(); + } + } else { + let b_row_k = b.row(k); + for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) { + *c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone(); + } + } + } + } + } else { + for j in 0..c.ncols() { + let mut c_col_j = c.column_mut(j); + for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.row_iter()) { + let mut dot_ij = T::zero(); + for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) { + let b_contrib = + if trans_b.is_transpose() { b.index((j, k)) } else { b.index((k, j)) }; + dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone(); + } + *c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij; + } + } + } +} \ No newline at end of file diff --git a/nalgebra-sparse/src/ops/serial/mod.rs b/nalgebra-sparse/src/ops/serial/mod.rs new file mode 100644 index 00000000..02e15210 --- /dev/null +++ b/nalgebra-sparse/src/ops/serial/mod.rs @@ -0,0 +1,37 @@ +//! TODO + +#[macro_use] +macro_rules! assert_compatible_spmm_dims { + ($c:expr, $a:expr, $b:expr, $trans_a:expr, $trans_b:expr) => { + use crate::ops::Transposition::{Transpose, NoTranspose}; + match ($trans_a, $trans_b) { + (NoTranspose, NoTranspose) => { + assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()"); + assert_eq!($c.ncols(), $b.ncols(), "C.ncols() != B.ncols()"); + assert_eq!($a.ncols(), $b.nrows(), "A.ncols() != B.nrows()"); + }, + (Transpose, NoTranspose) => { + assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()"); + assert_eq!($c.ncols(), $b.ncols(), "C.ncols() != B.ncols()"); + assert_eq!($a.nrows(), $b.nrows(), "A.nrows() != B.nrows()"); + }, + (NoTranspose, Transpose) => { + assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()"); + assert_eq!($c.ncols(), $b.nrows(), "C.ncols() != B.nrows()"); + assert_eq!($a.ncols(), $b.ncols(), "A.ncols() != B.ncols()"); + }, + (Transpose, Transpose) => { + assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()"); + assert_eq!($c.ncols(), $b.nrows(), "C.ncols() != B.nrows()"); + assert_eq!($a.nrows(), $b.ncols(), "A.nrows() != B.ncols()"); + } + } + + } +} + +mod coo; +mod csr; + +pub use coo::*; +pub use csr::*; \ No newline at end of file diff --git a/nalgebra-sparse/tests/unit_tests/ops.rs b/nalgebra-sparse/tests/unit_tests/ops.rs index 01df60b6..19add876 100644 --- a/nalgebra-sparse/tests/unit_tests/ops.rs +++ b/nalgebra-sparse/tests/unit_tests/ops.rs @@ -1,6 +1,15 @@ use nalgebra_sparse::coo::CooMatrix; -use nalgebra_sparse::ops::spmv_coo; -use nalgebra::{DVector, DMatrix}; +use nalgebra_sparse::ops::serial::{spmv_coo, spmm_csr_dense}; +use nalgebra_sparse::ops::{no_transpose, Transposition}; +use nalgebra_sparse::csr::CsrMatrix; +use nalgebra_sparse::proptest::csr; + +use nalgebra::{DVector, DMatrix, Scalar, DMatrixSliceMut, DMatrixSlice}; +use nalgebra::proptest::matrix; + +use proptest::prelude::*; + +use std::panic::catch_unwind; #[test] fn spmv_coo_agrees_with_dense_gemv() { @@ -25,4 +34,138 @@ fn spmv_coo_agrees_with_dense_gemv() { assert_eq!(y, y_dense); } } +} + +#[derive(Debug)] +struct SpmmCsrDenseArgs { + c: DMatrix, + beta: T, + alpha: T, + trans_a: Transposition, + a: CsrMatrix, + trans_b: Transposition, + b: DMatrix, +} + +/// Returns matrices C, A and B with compatible dimensions such that it can be used +/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`. +fn spmm_csr_dense_args_strategy() -> impl Strategy> { + let max_nnz = 40; + let value_strategy = -5 ..= 5; + let c_rows = 0 ..= 6usize; + let c_cols = 0 ..= 6usize; + let common_dim = 0 ..= 6usize; + let trans_strategy = trans_strategy(); + let c_matrix_strategy = matrix(value_strategy.clone(), c_rows, c_cols); + + (c_matrix_strategy, common_dim, trans_strategy.clone(), trans_strategy.clone()) + .prop_flat_map(move |(c, common_dim, trans_a, trans_b)| { + let a_shape = + if trans_a.is_transpose() { (common_dim, c.nrows()) } + else { (c.nrows(), common_dim) }; + let b_shape = + if trans_b.is_transpose() { (c.ncols(), common_dim) } + else { (common_dim, c.ncols()) }; + let a = csr(value_strategy.clone(), Just(a_shape.0), Just(a_shape.1), max_nnz); + let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1); + + // We use the same values for alpha, beta parameters as for matrix elements + let alpha = value_strategy.clone(); + let beta = value_strategy.clone(); + + (Just(c), beta, alpha, Just(trans_a), a, Just(trans_b), b) + }).prop_map(|(c, beta, alpha, trans_a, a, trans_b, b)| { + SpmmCsrDenseArgs { + c, + beta, + alpha, + trans_a, + a, + trans_b, + b, + } + }) +} + +fn csr_strategy() -> impl Strategy> { + csr(-5 ..= 5, 0 ..= 6usize, 0 ..= 6usize, 40) +} + +fn dense_strategy() -> impl Strategy> { + matrix(-5 ..= 5, 0 ..= 6, 0 ..= 6) +} + +fn trans_strategy() -> impl Strategy + Clone { + proptest::bool::ANY.prop_map(Transposition::from_bool) +} + +/// Helper function to help us call dense GEMM with our transposition parameters +fn dense_gemm<'a>(c: impl Into>, + beta: i32, + alpha: i32, + trans_a: Transposition, + a: impl Into>, + trans_b: Transposition, + b: impl Into>) +{ + let mut c = c.into(); + let a = a.into(); + let b = b.into(); + + use Transposition::{Transpose, NoTranspose}; + match (trans_a, trans_b) { + (NoTranspose, NoTranspose) => c.gemm(alpha, &a, &b, beta), + (Transpose, NoTranspose) => c.gemm(alpha, &a.transpose(), &b, beta), + (NoTranspose, Transpose) => c.gemm(alpha, &a, &b.transpose(), beta), + (Transpose, Transpose) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta) + }; +} + +proptest! { + + #[test] + fn spmm_csr_dense_agrees_with_dense_result( + SpmmCsrDenseArgs { c, beta, alpha, trans_a, a, trans_b, b } + in spmm_csr_dense_args_strategy() + ) { + let mut spmm_result = c.clone(); + spmm_csr_dense(&mut spmm_result, beta, alpha, trans_a, &a, trans_b, &b); + + let mut gemm_result = c.clone(); + dense_gemm(&mut gemm_result, beta, alpha, trans_a, &DMatrix::from(&a), trans_b, &b); + + prop_assert_eq!(spmm_result, gemm_result); + } + + #[test] + fn spmm_csr_dense_panics_on_dim_mismatch( + (alpha, beta, c, a, b, trans_a, trans_b) + in (-5 ..= 5, -5 ..= 5, dense_strategy(), csr_strategy(), + dense_strategy(), trans_strategy(), trans_strategy()) + ) { + // We refer to `A * B` as the "product" + let product_rows = if trans_a.is_transpose() { a.ncols() } else { a.nrows() }; + let product_cols = if trans_b.is_transpose() { b.nrows() } else { b.ncols() }; + // Determine the common dimension in the product + // from the perspective of a and b, respectively + let product_a_common = if trans_a.is_transpose() { a.nrows() } else { a.ncols() }; + let product_b_common = if trans_b.is_transpose() { b.ncols() } else { b.nrows() }; + + let dims_are_compatible = product_rows == c.nrows() + && product_cols == c.ncols() + && product_a_common == product_b_common; + + // If the dimensions randomly happen to be compatible, then of course we need to + // skip the test, so we assume that they are not. + prop_assume!(!dims_are_compatible); + + let result = catch_unwind(|| { + let mut spmm_result = c.clone(); + spmm_csr_dense(&mut spmm_result, beta, alpha, trans_a, &a, trans_b, &b); + }); + + prop_assert!(result.is_err(), + "The SPMM kernel executed successfully despite mismatch dimensions"); + } + } \ No newline at end of file