Simplify transposition API in spmm_csr_dense

This commit is contained in:
Andreas Longva 2020-12-02 17:04:19 +01:00
parent 1ae03d9ebb
commit 7c68950614
4 changed files with 35 additions and 56 deletions

View File

@ -4,31 +4,11 @@ pub mod serial;
/// TODO /// TODO
#[derive(Copy, Clone, Debug, PartialEq, Eq)] #[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Transposition { pub struct Transpose(pub bool);
/// TODO
Transpose,
/// TODO
NoTranspose,
}
impl Transposition { impl Transpose {
/// TODO /// TODO
pub fn is_transpose(&self) -> bool { pub fn to_bool(&self) -> bool {
self == &Self::Transpose self.0
}
/// TODO
pub fn from_bool(transpose: bool) -> Self {
if transpose { Self::Transpose } else { Self::NoTranspose }
} }
} }
/// TODO
pub fn transpose() -> Transposition {
Transposition::Transpose
}
/// TODO
pub fn no_transpose() -> Transposition {
Transposition::NoTranspose
}

View File

@ -1,15 +1,15 @@
use crate::csr::CsrMatrix; use crate::csr::CsrMatrix;
use crate::ops::Transposition; use crate::ops::{Transpose};
use nalgebra::{DVectorSlice, Scalar, DMatrixSlice, DVectorSliceMut, ClosedAdd, ClosedMul, DMatrixSliceMut}; use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
use num_traits::{Zero, One}; use num_traits::{Zero, One};
/// Sparse-dense matrix-matrix multiplication `C = beta * C + alpha * trans(A) * trans(B)`. /// Sparse-dense matrix-matrix multiplication `C = beta * C + alpha * trans(A) * trans(B)`.
pub fn spmm_csr_dense<'a, T>(c: impl Into<DMatrixSliceMut<'a, T>>, pub fn spmm_csr_dense<'a, T>(c: impl Into<DMatrixSliceMut<'a, T>>,
beta: T, beta: T,
alpha: T, alpha: T,
trans_a: Transposition, trans_a: Transpose,
a: &CsrMatrix<T>, a: &CsrMatrix<T>,
trans_b: Transposition, trans_b: Transpose,
b: impl Into<DMatrixSlice<'a, T>>) b: impl Into<DMatrixSlice<'a, T>>)
where where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One
@ -20,16 +20,16 @@ pub fn spmm_csr_dense<'a, T>(c: impl Into<DMatrixSliceMut<'a, T>>,
fn spmm_csr_dense_<T>(mut c: DMatrixSliceMut<T>, fn spmm_csr_dense_<T>(mut c: DMatrixSliceMut<T>,
beta: T, beta: T,
alpha: T, alpha: T,
trans_a: Transposition, trans_a: Transpose,
a: &CsrMatrix<T>, a: &CsrMatrix<T>,
trans_b: Transposition, trans_b: Transpose,
b: DMatrixSlice<T>) b: DMatrixSlice<T>)
where where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One
{ {
assert_compatible_spmm_dims!(c, a, b, trans_a, trans_b); assert_compatible_spmm_dims!(c, a, b, trans_a, trans_b);
if trans_a.is_transpose() { if trans_a.to_bool() {
// In this case, we have to pre-multiply C by beta // In this case, we have to pre-multiply C by beta
c *= beta; c *= beta;
@ -38,7 +38,7 @@ where
for (&i, a_ki) in a_row_k.col_indices().iter().zip(a_row_k.values()) { for (&i, a_ki) in a_row_k.col_indices().iter().zip(a_row_k.values()) {
let gamma_ki = alpha.inlined_clone() * a_ki.inlined_clone(); let gamma_ki = alpha.inlined_clone() * a_ki.inlined_clone();
let mut c_row_i = c.row_mut(i); let mut c_row_i = c.row_mut(i);
if trans_b.is_transpose() { if trans_b.to_bool() {
let b_col_k = b.column(k); let b_col_k = b.column(k);
for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) { for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) {
*c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone(); *c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone();
@ -58,7 +58,7 @@ where
let mut dot_ij = T::zero(); let mut dot_ij = T::zero();
for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) { for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
let b_contrib = let b_contrib =
if trans_b.is_transpose() { b.index((j, k)) } else { b.index((k, j)) }; if trans_b.to_bool() { b.index((j, k)) } else { b.index((k, j)) };
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone(); dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
} }
*c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij; *c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij;

View File

@ -3,24 +3,24 @@
#[macro_use] #[macro_use]
macro_rules! assert_compatible_spmm_dims { macro_rules! assert_compatible_spmm_dims {
($c:expr, $a:expr, $b:expr, $trans_a:expr, $trans_b:expr) => { ($c:expr, $a:expr, $b:expr, $trans_a:expr, $trans_b:expr) => {
use crate::ops::Transposition::{Transpose, NoTranspose}; use crate::ops::Transpose;
match ($trans_a, $trans_b) { match ($trans_a, $trans_b) {
(NoTranspose, NoTranspose) => { (Transpose(false), Transpose(false)) => {
assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()"); assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()");
assert_eq!($c.ncols(), $b.ncols(), "C.ncols() != B.ncols()"); assert_eq!($c.ncols(), $b.ncols(), "C.ncols() != B.ncols()");
assert_eq!($a.ncols(), $b.nrows(), "A.ncols() != B.nrows()"); assert_eq!($a.ncols(), $b.nrows(), "A.ncols() != B.nrows()");
}, },
(Transpose, NoTranspose) => { (Transpose(true), Transpose(false)) => {
assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()"); assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()");
assert_eq!($c.ncols(), $b.ncols(), "C.ncols() != B.ncols()"); assert_eq!($c.ncols(), $b.ncols(), "C.ncols() != B.ncols()");
assert_eq!($a.nrows(), $b.nrows(), "A.nrows() != B.nrows()"); assert_eq!($a.nrows(), $b.nrows(), "A.nrows() != B.nrows()");
}, },
(NoTranspose, Transpose) => { (Transpose(false), Transpose(true)) => {
assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()"); assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()");
assert_eq!($c.ncols(), $b.nrows(), "C.ncols() != B.nrows()"); assert_eq!($c.ncols(), $b.nrows(), "C.ncols() != B.nrows()");
assert_eq!($a.ncols(), $b.ncols(), "A.ncols() != B.ncols()"); assert_eq!($a.ncols(), $b.ncols(), "A.ncols() != B.ncols()");
}, },
(Transpose, Transpose) => { (Transpose(true), Transpose(true)) => {
assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()"); assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()");
assert_eq!($c.ncols(), $b.nrows(), "C.ncols() != B.nrows()"); assert_eq!($c.ncols(), $b.nrows(), "C.ncols() != B.nrows()");
assert_eq!($a.nrows(), $b.ncols(), "A.nrows() != B.ncols()"); assert_eq!($a.nrows(), $b.ncols(), "A.nrows() != B.ncols()");

View File

@ -1,6 +1,6 @@
use nalgebra_sparse::coo::CooMatrix; use nalgebra_sparse::coo::CooMatrix;
use nalgebra_sparse::ops::serial::{spmv_coo, spmm_csr_dense}; use nalgebra_sparse::ops::serial::{spmv_coo, spmm_csr_dense};
use nalgebra_sparse::ops::{no_transpose, Transposition}; use nalgebra_sparse::ops::{Transpose};
use nalgebra_sparse::csr::CsrMatrix; use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::proptest::csr; use nalgebra_sparse::proptest::csr;
@ -41,9 +41,9 @@ struct SpmmCsrDenseArgs<T: Scalar> {
c: DMatrix<T>, c: DMatrix<T>,
beta: T, beta: T,
alpha: T, alpha: T,
trans_a: Transposition, trans_a: Transpose,
a: CsrMatrix<T>, a: CsrMatrix<T>,
trans_b: Transposition, trans_b: Transpose,
b: DMatrix<T>, b: DMatrix<T>,
} }
@ -61,10 +61,10 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
(c_matrix_strategy, common_dim, trans_strategy.clone(), trans_strategy.clone()) (c_matrix_strategy, common_dim, trans_strategy.clone(), trans_strategy.clone())
.prop_flat_map(move |(c, common_dim, trans_a, trans_b)| { .prop_flat_map(move |(c, common_dim, trans_a, trans_b)| {
let a_shape = let a_shape =
if trans_a.is_transpose() { (common_dim, c.nrows()) } if trans_a.to_bool() { (common_dim, c.nrows()) }
else { (c.nrows(), common_dim) }; else { (c.nrows(), common_dim) };
let b_shape = let b_shape =
if trans_b.is_transpose() { (c.ncols(), common_dim) } if trans_b.to_bool() { (c.ncols(), common_dim) }
else { (common_dim, c.ncols()) }; else { (common_dim, c.ncols()) };
let a = csr(value_strategy.clone(), Just(a_shape.0), Just(a_shape.1), max_nnz); let a = csr(value_strategy.clone(), Just(a_shape.0), Just(a_shape.1), max_nnz);
let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1); let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1);
@ -95,29 +95,28 @@ fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
matrix(-5 ..= 5, 0 ..= 6, 0 ..= 6) matrix(-5 ..= 5, 0 ..= 6, 0 ..= 6)
} }
fn trans_strategy() -> impl Strategy<Value=Transposition> + Clone { fn trans_strategy() -> impl Strategy<Value=Transpose> + Clone {
proptest::bool::ANY.prop_map(Transposition::from_bool) proptest::bool::ANY.prop_map(Transpose)
} }
/// Helper function to help us call dense GEMM with our transposition parameters /// Helper function to help us call dense GEMM with our transposition parameters
fn dense_gemm<'a>(c: impl Into<DMatrixSliceMut<'a, i32>>, fn dense_gemm<'a>(c: impl Into<DMatrixSliceMut<'a, i32>>,
beta: i32, beta: i32,
alpha: i32, alpha: i32,
trans_a: Transposition, trans_a: Transpose,
a: impl Into<DMatrixSlice<'a, i32>>, a: impl Into<DMatrixSlice<'a, i32>>,
trans_b: Transposition, trans_b: Transpose,
b: impl Into<DMatrixSlice<'a, i32>>) b: impl Into<DMatrixSlice<'a, i32>>)
{ {
let mut c = c.into(); let mut c = c.into();
let a = a.into(); let a = a.into();
let b = b.into(); let b = b.into();
use Transposition::{Transpose, NoTranspose};
match (trans_a, trans_b) { match (trans_a, trans_b) {
(NoTranspose, NoTranspose) => c.gemm(alpha, &a, &b, beta), (Transpose(false), Transpose(false)) => c.gemm(alpha, &a, &b, beta),
(Transpose, NoTranspose) => c.gemm(alpha, &a.transpose(), &b, beta), (Transpose(true), Transpose(false)) => c.gemm(alpha, &a.transpose(), &b, beta),
(NoTranspose, Transpose) => c.gemm(alpha, &a, &b.transpose(), beta), (Transpose(false), Transpose(true)) => c.gemm(alpha, &a, &b.transpose(), beta),
(Transpose, Transpose) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta) (Transpose(true), Transpose(true)) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta)
}; };
} }
@ -144,12 +143,12 @@ proptest! {
dense_strategy(), trans_strategy(), trans_strategy()) dense_strategy(), trans_strategy(), trans_strategy())
) { ) {
// We refer to `A * B` as the "product" // We refer to `A * B` as the "product"
let product_rows = if trans_a.is_transpose() { a.ncols() } else { a.nrows() }; let product_rows = if trans_a.to_bool() { a.ncols() } else { a.nrows() };
let product_cols = if trans_b.is_transpose() { b.nrows() } else { b.ncols() }; let product_cols = if trans_b.to_bool() { b.nrows() } else { b.ncols() };
// Determine the common dimension in the product // Determine the common dimension in the product
// from the perspective of a and b, respectively // from the perspective of a and b, respectively
let product_a_common = if trans_a.is_transpose() { a.nrows() } else { a.ncols() }; let product_a_common = if trans_a.to_bool() { a.nrows() } else { a.ncols() };
let product_b_common = if trans_b.is_transpose() { b.ncols() } else { b.nrows() }; let product_b_common = if trans_b.to_bool() { b.ncols() } else { b.nrows() };
let dims_are_compatible = product_rows == c.nrows() let dims_are_compatible = product_rows == c.nrows()
&& product_cols == c.ncols() && product_cols == c.ncols()