forked from M-Labs/nalgebra
Implement spmm_csr_dense
This commit is contained in:
parent
95ee65fa8e
commit
1ae03d9ebb
34
nalgebra-sparse/src/ops/mod.rs
Normal file
34
nalgebra-sparse/src/ops/mod.rs
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
//! TODO
|
||||||
|
|
||||||
|
pub mod serial;
|
||||||
|
|
||||||
|
/// TODO
|
||||||
|
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||||
|
pub enum Transposition {
|
||||||
|
/// TODO
|
||||||
|
Transpose,
|
||||||
|
/// TODO
|
||||||
|
NoTranspose,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Transposition {
|
||||||
|
/// TODO
|
||||||
|
pub fn is_transpose(&self) -> bool {
|
||||||
|
self == &Self::Transpose
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TODO
|
||||||
|
pub fn from_bool(transpose: bool) -> Self {
|
||||||
|
if transpose { Self::Transpose } else { Self::NoTranspose }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TODO
|
||||||
|
pub fn transpose() -> Transposition {
|
||||||
|
Transposition::Transpose
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TODO
|
||||||
|
pub fn no_transpose() -> Transposition {
|
||||||
|
Transposition::NoTranspose
|
||||||
|
}
|
@ -12,6 +12,8 @@ use num_traits::{One, Zero};
|
|||||||
///
|
///
|
||||||
/// If `beta == 0`, the elements in `y` are never read.
|
/// If `beta == 0`, the elements in `y` are never read.
|
||||||
///
|
///
|
||||||
|
/// TODO: Rethink this function
|
||||||
|
///
|
||||||
/// Panics
|
/// Panics
|
||||||
/// ------
|
/// ------
|
||||||
///
|
///
|
68
nalgebra-sparse/src/ops/serial/csr.rs
Normal file
68
nalgebra-sparse/src/ops/serial/csr.rs
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
use crate::csr::CsrMatrix;
|
||||||
|
use crate::ops::Transposition;
|
||||||
|
use nalgebra::{DVectorSlice, Scalar, DMatrixSlice, DVectorSliceMut, ClosedAdd, ClosedMul, DMatrixSliceMut};
|
||||||
|
use num_traits::{Zero, One};
|
||||||
|
|
||||||
|
/// Sparse-dense matrix-matrix multiplication `C = beta * C + alpha * trans(A) * trans(B)`.
|
||||||
|
pub fn spmm_csr_dense<'a, T>(c: impl Into<DMatrixSliceMut<'a, T>>,
|
||||||
|
beta: T,
|
||||||
|
alpha: T,
|
||||||
|
trans_a: Transposition,
|
||||||
|
a: &CsrMatrix<T>,
|
||||||
|
trans_b: Transposition,
|
||||||
|
b: impl Into<DMatrixSlice<'a, T>>)
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
spmm_csr_dense_(c.into(), beta, alpha, trans_a, a, trans_b, b.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spmm_csr_dense_<T>(mut c: DMatrixSliceMut<T>,
|
||||||
|
beta: T,
|
||||||
|
alpha: T,
|
||||||
|
trans_a: Transposition,
|
||||||
|
a: &CsrMatrix<T>,
|
||||||
|
trans_b: Transposition,
|
||||||
|
b: DMatrixSlice<T>)
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
assert_compatible_spmm_dims!(c, a, b, trans_a, trans_b);
|
||||||
|
|
||||||
|
if trans_a.is_transpose() {
|
||||||
|
// In this case, we have to pre-multiply C by beta
|
||||||
|
c *= beta;
|
||||||
|
|
||||||
|
for k in 0..a.nrows() {
|
||||||
|
let a_row_k = a.row(k);
|
||||||
|
for (&i, a_ki) in a_row_k.col_indices().iter().zip(a_row_k.values()) {
|
||||||
|
let gamma_ki = alpha.inlined_clone() * a_ki.inlined_clone();
|
||||||
|
let mut c_row_i = c.row_mut(i);
|
||||||
|
if trans_b.is_transpose() {
|
||||||
|
let b_col_k = b.column(k);
|
||||||
|
for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) {
|
||||||
|
*c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let b_row_k = b.row(k);
|
||||||
|
for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) {
|
||||||
|
*c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for j in 0..c.ncols() {
|
||||||
|
let mut c_col_j = c.column_mut(j);
|
||||||
|
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.row_iter()) {
|
||||||
|
let mut dot_ij = T::zero();
|
||||||
|
for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
|
||||||
|
let b_contrib =
|
||||||
|
if trans_b.is_transpose() { b.index((j, k)) } else { b.index((k, j)) };
|
||||||
|
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
|
||||||
|
}
|
||||||
|
*c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
37
nalgebra-sparse/src/ops/serial/mod.rs
Normal file
37
nalgebra-sparse/src/ops/serial/mod.rs
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
//! TODO
|
||||||
|
|
||||||
|
#[macro_use]
|
||||||
|
macro_rules! assert_compatible_spmm_dims {
|
||||||
|
($c:expr, $a:expr, $b:expr, $trans_a:expr, $trans_b:expr) => {
|
||||||
|
use crate::ops::Transposition::{Transpose, NoTranspose};
|
||||||
|
match ($trans_a, $trans_b) {
|
||||||
|
(NoTranspose, NoTranspose) => {
|
||||||
|
assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()");
|
||||||
|
assert_eq!($c.ncols(), $b.ncols(), "C.ncols() != B.ncols()");
|
||||||
|
assert_eq!($a.ncols(), $b.nrows(), "A.ncols() != B.nrows()");
|
||||||
|
},
|
||||||
|
(Transpose, NoTranspose) => {
|
||||||
|
assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()");
|
||||||
|
assert_eq!($c.ncols(), $b.ncols(), "C.ncols() != B.ncols()");
|
||||||
|
assert_eq!($a.nrows(), $b.nrows(), "A.nrows() != B.nrows()");
|
||||||
|
},
|
||||||
|
(NoTranspose, Transpose) => {
|
||||||
|
assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()");
|
||||||
|
assert_eq!($c.ncols(), $b.nrows(), "C.ncols() != B.nrows()");
|
||||||
|
assert_eq!($a.ncols(), $b.ncols(), "A.ncols() != B.ncols()");
|
||||||
|
},
|
||||||
|
(Transpose, Transpose) => {
|
||||||
|
assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()");
|
||||||
|
assert_eq!($c.ncols(), $b.nrows(), "C.ncols() != B.nrows()");
|
||||||
|
assert_eq!($a.nrows(), $b.ncols(), "A.nrows() != B.ncols()");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mod coo;
|
||||||
|
mod csr;
|
||||||
|
|
||||||
|
pub use coo::*;
|
||||||
|
pub use csr::*;
|
@ -1,6 +1,15 @@
|
|||||||
use nalgebra_sparse::coo::CooMatrix;
|
use nalgebra_sparse::coo::CooMatrix;
|
||||||
use nalgebra_sparse::ops::spmv_coo;
|
use nalgebra_sparse::ops::serial::{spmv_coo, spmm_csr_dense};
|
||||||
use nalgebra::{DVector, DMatrix};
|
use nalgebra_sparse::ops::{no_transpose, Transposition};
|
||||||
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
|
use nalgebra_sparse::proptest::csr;
|
||||||
|
|
||||||
|
use nalgebra::{DVector, DMatrix, Scalar, DMatrixSliceMut, DMatrixSlice};
|
||||||
|
use nalgebra::proptest::matrix;
|
||||||
|
|
||||||
|
use proptest::prelude::*;
|
||||||
|
|
||||||
|
use std::panic::catch_unwind;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spmv_coo_agrees_with_dense_gemv() {
|
fn spmv_coo_agrees_with_dense_gemv() {
|
||||||
@ -26,3 +35,137 @@ fn spmv_coo_agrees_with_dense_gemv() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct SpmmCsrDenseArgs<T: Scalar> {
|
||||||
|
c: DMatrix<T>,
|
||||||
|
beta: T,
|
||||||
|
alpha: T,
|
||||||
|
trans_a: Transposition,
|
||||||
|
a: CsrMatrix<T>,
|
||||||
|
trans_b: Transposition,
|
||||||
|
b: DMatrix<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns matrices C, A and B with compatible dimensions such that it can be used
|
||||||
|
/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`.
|
||||||
|
fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>> {
|
||||||
|
let max_nnz = 40;
|
||||||
|
let value_strategy = -5 ..= 5;
|
||||||
|
let c_rows = 0 ..= 6usize;
|
||||||
|
let c_cols = 0 ..= 6usize;
|
||||||
|
let common_dim = 0 ..= 6usize;
|
||||||
|
let trans_strategy = trans_strategy();
|
||||||
|
let c_matrix_strategy = matrix(value_strategy.clone(), c_rows, c_cols);
|
||||||
|
|
||||||
|
(c_matrix_strategy, common_dim, trans_strategy.clone(), trans_strategy.clone())
|
||||||
|
.prop_flat_map(move |(c, common_dim, trans_a, trans_b)| {
|
||||||
|
let a_shape =
|
||||||
|
if trans_a.is_transpose() { (common_dim, c.nrows()) }
|
||||||
|
else { (c.nrows(), common_dim) };
|
||||||
|
let b_shape =
|
||||||
|
if trans_b.is_transpose() { (c.ncols(), common_dim) }
|
||||||
|
else { (common_dim, c.ncols()) };
|
||||||
|
let a = csr(value_strategy.clone(), Just(a_shape.0), Just(a_shape.1), max_nnz);
|
||||||
|
let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1);
|
||||||
|
|
||||||
|
// We use the same values for alpha, beta parameters as for matrix elements
|
||||||
|
let alpha = value_strategy.clone();
|
||||||
|
let beta = value_strategy.clone();
|
||||||
|
|
||||||
|
(Just(c), beta, alpha, Just(trans_a), a, Just(trans_b), b)
|
||||||
|
}).prop_map(|(c, beta, alpha, trans_a, a, trans_b, b)| {
|
||||||
|
SpmmCsrDenseArgs {
|
||||||
|
c,
|
||||||
|
beta,
|
||||||
|
alpha,
|
||||||
|
trans_a,
|
||||||
|
a,
|
||||||
|
trans_b,
|
||||||
|
b,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
||||||
|
csr(-5 ..= 5, 0 ..= 6usize, 0 ..= 6usize, 40)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
|
||||||
|
matrix(-5 ..= 5, 0 ..= 6, 0 ..= 6)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn trans_strategy() -> impl Strategy<Value=Transposition> + Clone {
|
||||||
|
proptest::bool::ANY.prop_map(Transposition::from_bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to help us call dense GEMM with our transposition parameters
|
||||||
|
fn dense_gemm<'a>(c: impl Into<DMatrixSliceMut<'a, i32>>,
|
||||||
|
beta: i32,
|
||||||
|
alpha: i32,
|
||||||
|
trans_a: Transposition,
|
||||||
|
a: impl Into<DMatrixSlice<'a, i32>>,
|
||||||
|
trans_b: Transposition,
|
||||||
|
b: impl Into<DMatrixSlice<'a, i32>>)
|
||||||
|
{
|
||||||
|
let mut c = c.into();
|
||||||
|
let a = a.into();
|
||||||
|
let b = b.into();
|
||||||
|
|
||||||
|
use Transposition::{Transpose, NoTranspose};
|
||||||
|
match (trans_a, trans_b) {
|
||||||
|
(NoTranspose, NoTranspose) => c.gemm(alpha, &a, &b, beta),
|
||||||
|
(Transpose, NoTranspose) => c.gemm(alpha, &a.transpose(), &b, beta),
|
||||||
|
(NoTranspose, Transpose) => c.gemm(alpha, &a, &b.transpose(), beta),
|
||||||
|
(Transpose, Transpose) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
proptest! {
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spmm_csr_dense_agrees_with_dense_result(
|
||||||
|
SpmmCsrDenseArgs { c, beta, alpha, trans_a, a, trans_b, b }
|
||||||
|
in spmm_csr_dense_args_strategy()
|
||||||
|
) {
|
||||||
|
let mut spmm_result = c.clone();
|
||||||
|
spmm_csr_dense(&mut spmm_result, beta, alpha, trans_a, &a, trans_b, &b);
|
||||||
|
|
||||||
|
let mut gemm_result = c.clone();
|
||||||
|
dense_gemm(&mut gemm_result, beta, alpha, trans_a, &DMatrix::from(&a), trans_b, &b);
|
||||||
|
|
||||||
|
prop_assert_eq!(spmm_result, gemm_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spmm_csr_dense_panics_on_dim_mismatch(
|
||||||
|
(alpha, beta, c, a, b, trans_a, trans_b)
|
||||||
|
in (-5 ..= 5, -5 ..= 5, dense_strategy(), csr_strategy(),
|
||||||
|
dense_strategy(), trans_strategy(), trans_strategy())
|
||||||
|
) {
|
||||||
|
// We refer to `A * B` as the "product"
|
||||||
|
let product_rows = if trans_a.is_transpose() { a.ncols() } else { a.nrows() };
|
||||||
|
let product_cols = if trans_b.is_transpose() { b.nrows() } else { b.ncols() };
|
||||||
|
// Determine the common dimension in the product
|
||||||
|
// from the perspective of a and b, respectively
|
||||||
|
let product_a_common = if trans_a.is_transpose() { a.nrows() } else { a.ncols() };
|
||||||
|
let product_b_common = if trans_b.is_transpose() { b.ncols() } else { b.nrows() };
|
||||||
|
|
||||||
|
let dims_are_compatible = product_rows == c.nrows()
|
||||||
|
&& product_cols == c.ncols()
|
||||||
|
&& product_a_common == product_b_common;
|
||||||
|
|
||||||
|
// If the dimensions randomly happen to be compatible, then of course we need to
|
||||||
|
// skip the test, so we assume that they are not.
|
||||||
|
prop_assume!(!dims_are_compatible);
|
||||||
|
|
||||||
|
let result = catch_unwind(|| {
|
||||||
|
let mut spmm_result = c.clone();
|
||||||
|
spmm_csr_dense(&mut spmm_result, beta, alpha, trans_a, &a, trans_b, &b);
|
||||||
|
});
|
||||||
|
|
||||||
|
prop_assert!(result.is_err(),
|
||||||
|
"The SPMM kernel executed successfully despite mismatch dimensions");
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user