Implement spmm_csr
This commit is contained in:
parent
9db17f00e7
commit
2d534a6133
|
@ -156,3 +156,67 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn spmm_csr_unexpected_entry() -> OperationError {
|
||||||
|
OperationError::from_type_and_message(
|
||||||
|
OperationErrorType::InvalidPattern,
|
||||||
|
String::from("Found unexpected entry that is not present in `c`."))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
|
||||||
|
pub fn spmm_csr<'a, T>(
|
||||||
|
c: &mut CsrMatrix<T>,
|
||||||
|
beta: T,
|
||||||
|
alpha: T,
|
||||||
|
trans_a: Transpose,
|
||||||
|
a: &CsrMatrix<T>,
|
||||||
|
trans_b: Transpose,
|
||||||
|
b: &CsrMatrix<T>)
|
||||||
|
-> Result<(), OperationError>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
assert_compatible_spmm_dims!(c, a, b, trans_a, trans_b);
|
||||||
|
|
||||||
|
if !trans_a.to_bool() && !trans_b.to_bool() {
|
||||||
|
for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) {
|
||||||
|
for c_ij in c_row_i.values_mut() {
|
||||||
|
*c_ij = beta.inlined_clone() * c_ij.inlined_clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
|
||||||
|
let b_row_k = b.row(k);
|
||||||
|
let (mut c_row_i_cols, mut c_row_i_values) = c_row_i.cols_and_values_mut();
|
||||||
|
let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone();
|
||||||
|
for (j, b_kj) in b_row_k.col_indices().iter().zip(b_row_k.values()) {
|
||||||
|
// Determine the location in C to append the value
|
||||||
|
let (c_local_idx, _) = c_row_i_cols.iter()
|
||||||
|
.enumerate()
|
||||||
|
.find(|(_, c_col)| *c_col == j)
|
||||||
|
.ok_or_else(spmm_csr_unexpected_entry)?;
|
||||||
|
|
||||||
|
c_row_i_values[c_local_idx] += alpha_aik.inlined_clone() * b_kj.inlined_clone();
|
||||||
|
c_row_i_cols = &c_row_i_cols[c_local_idx ..];
|
||||||
|
c_row_i_values = &mut c_row_i_values[c_local_idx ..];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
// Currently we handle transposition by explicitly precomputing transposed matrices
|
||||||
|
// and calling the operation again without transposition
|
||||||
|
// TODO: At least use workspaces to allow control of allocations. Maybe
|
||||||
|
// consider implementing certain patterns (like A^T * B) explicitly
|
||||||
|
let (a, b) = {
|
||||||
|
use Cow::*;
|
||||||
|
match (trans_a, trans_b) {
|
||||||
|
(Transpose(false), Transpose(false)) => unreachable!(),
|
||||||
|
(Transpose(true), Transpose(false)) => (Owned(a.transpose()), Borrowed(b)),
|
||||||
|
(Transpose(false), Transpose(true)) => (Borrowed(a), Owned(b.transpose())),
|
||||||
|
(Transpose(true), Transpose(true)) => (Owned(a.transpose()), Owned(b.transpose()))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
spmm_csr(c, beta, alpha, Transpose(false), a.as_ref(), Transpose(false), b.as_ref())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -45,31 +45,41 @@ pub fn spadd_build_pattern(pattern: &mut SparsityPattern,
|
||||||
|
|
||||||
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
|
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
|
||||||
pub fn spmm_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
pub fn spmm_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||||
// TODO: Proper error message
|
assert_eq!(a.minor_dim(), b.major_dim(), "a and b must have compatible dimensions");
|
||||||
assert_eq!(a.minor_dim(), b.major_dim());
|
|
||||||
|
|
||||||
let mut offsets = Vec::new();
|
let mut offsets = Vec::new();
|
||||||
let mut indices = Vec::new();
|
let mut indices = Vec::new();
|
||||||
offsets.push(0);
|
offsets.push(0);
|
||||||
|
|
||||||
let mut c_lane_workspace = Vec::new();
|
// Keep a vector of whether we have visited a particular minor index when working
|
||||||
|
// on a major lane
|
||||||
|
// TODO: Consider using a bitvec or similar here to reduce pressure on memory
|
||||||
|
// (would cut memory use to 1/8, which might help reduce cache misses)
|
||||||
|
let mut visited = vec![false; b.minor_dim()];
|
||||||
|
|
||||||
for i in 0 .. a.major_dim() {
|
for i in 0 .. a.major_dim() {
|
||||||
let a_lane_i = a.lane(i);
|
let a_lane_i = a.lane(i);
|
||||||
let c_lane_i_offset = *offsets.last().unwrap();
|
let c_lane_i_offset = *offsets.last().unwrap();
|
||||||
for &k in a_lane_i {
|
for &k in a_lane_i {
|
||||||
// We have that the set of elements in lane i in C is given by the union of all
|
|
||||||
// B_k, where B_k is the set of indices in lane k of B. More precisely, let C_i
|
|
||||||
// denote the set of indices in lane i in C, and similarly for A_i and B_k. Then
|
|
||||||
// C_i = union B_k for all k in A_i
|
|
||||||
// We incrementally compute C_i by incrementally computing the union of C_i with
|
|
||||||
// B_k until we're through all k in A_i.
|
|
||||||
let b_lane_k = b.lane(k);
|
let b_lane_k = b.lane(k);
|
||||||
let c_lane_i = &indices[c_lane_i_offset..];
|
|
||||||
c_lane_workspace.clear();
|
for &j in b_lane_k {
|
||||||
c_lane_workspace.extend(iterate_union(c_lane_i, b_lane_k));
|
let have_visited_j = &mut visited[j];
|
||||||
indices.truncate(c_lane_i_offset);
|
if !*have_visited_j {
|
||||||
indices.append(&mut c_lane_workspace);
|
indices.push(j);
|
||||||
|
*have_visited_j = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let c_lane_i = &mut indices[c_lane_i_offset ..];
|
||||||
|
c_lane_i.sort_unstable();
|
||||||
|
|
||||||
|
// Reset visits so that visited[j] == false for all j for the next major lane
|
||||||
|
for j in c_lane_i {
|
||||||
|
visited[*j] = false;
|
||||||
|
}
|
||||||
|
|
||||||
offsets.push(indices.len());
|
offsets.push(indices.len());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ macro_rules! assert_panics {
|
||||||
|
|
||||||
pub const PROPTEST_MATRIX_DIM: RangeInclusive<usize> = 0..=6;
|
pub const PROPTEST_MATRIX_DIM: RangeInclusive<usize> = 0..=6;
|
||||||
pub const PROPTEST_MAX_NNZ: usize = 40;
|
pub const PROPTEST_MAX_NNZ: usize = 40;
|
||||||
|
pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5 ..= 5;
|
||||||
|
|
||||||
pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
||||||
csr(-5 ..= 5, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
csr(-5 ..= 5, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_build_pattern, spmm_pattern, spadd_csr};
|
use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY};
|
||||||
|
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_build_pattern, spmm_pattern, spadd_csr, spmm_csr};
|
||||||
use nalgebra_sparse::ops::{Transpose};
|
use nalgebra_sparse::ops::{Transpose};
|
||||||
use nalgebra_sparse::csr::CsrMatrix;
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
|
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
|
||||||
|
@ -12,8 +14,6 @@ use proptest::prelude::*;
|
||||||
use std::panic::catch_unwind;
|
use std::panic::catch_unwind;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ};
|
|
||||||
|
|
||||||
/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1
|
/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1
|
||||||
fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
||||||
let boolean_csr = CsrMatrix::try_from_pattern_and_values(
|
let boolean_csr = CsrMatrix::try_from_pattern_and_values(
|
||||||
|
@ -37,11 +37,11 @@ struct SpmmCsrDenseArgs<T: Scalar> {
|
||||||
/// Returns matrices C, A and B with compatible dimensions such that it can be used
|
/// Returns matrices C, A and B with compatible dimensions such that it can be used
|
||||||
/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`.
|
/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`.
|
||||||
fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>> {
|
fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>> {
|
||||||
let max_nnz = 40;
|
let max_nnz = PROPTEST_MAX_NNZ;
|
||||||
let value_strategy = -5 ..= 5;
|
let value_strategy = PROPTEST_I32_VALUE_STRATEGY;
|
||||||
let c_rows = 0 ..= 6usize;
|
let c_rows = PROPTEST_MATRIX_DIM;
|
||||||
let c_cols = 0 ..= 6usize;
|
let c_cols = PROPTEST_MATRIX_DIM;
|
||||||
let common_dim = 0 ..= 6usize;
|
let common_dim = PROPTEST_MATRIX_DIM;
|
||||||
let trans_strategy = trans_strategy();
|
let trans_strategy = trans_strategy();
|
||||||
let c_matrix_strategy = matrix(value_strategy.clone(), c_rows, c_cols);
|
let c_matrix_strategy = matrix(value_strategy.clone(), c_rows, c_cols);
|
||||||
|
|
||||||
|
@ -84,9 +84,8 @@ struct SpaddCsrArgs<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn spadd_csr_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
fn spadd_csr_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
||||||
let value_strategy = -5 ..= 5;
|
let value_strategy = PROPTEST_I32_VALUE_STRATEGY;
|
||||||
|
|
||||||
// TODO :Support transposition
|
|
||||||
spadd_build_pattern_strategy()
|
spadd_build_pattern_strategy()
|
||||||
.prop_flat_map(move |(a_pattern, b_pattern)| {
|
.prop_flat_map(move |(a_pattern, b_pattern)| {
|
||||||
let mut c_pattern = SparsityPattern::new(a_pattern.major_dim(), b_pattern.major_dim());
|
let mut c_pattern = SparsityPattern::new(a_pattern.major_dim(), b_pattern.major_dim());
|
||||||
|
@ -107,7 +106,7 @@ fn spadd_csr_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
|
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
|
||||||
matrix(-5 ..= 5, 0 ..= 6, 0 ..= 6)
|
matrix(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn trans_strategy() -> impl Strategy<Value=Transpose> + Clone {
|
fn trans_strategy() -> impl Strategy<Value=Transpose> + Clone {
|
||||||
|
@ -115,14 +114,14 @@ fn trans_strategy() -> impl Strategy<Value=Transpose> + Clone {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn pattern_strategy() -> impl Strategy<Value=SparsityPattern> {
|
fn pattern_strategy() -> impl Strategy<Value=SparsityPattern> {
|
||||||
sparsity_pattern(0 ..= 6usize, 0..= 6usize, 40)
|
sparsity_pattern(PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Constructs pairs (a, b) where a and b have the same dimensions
|
/// Constructs pairs (a, b) where a and b have the same dimensions
|
||||||
fn spadd_build_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
|
fn spadd_build_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
|
||||||
pattern_strategy()
|
pattern_strategy()
|
||||||
.prop_flat_map(|a| {
|
.prop_flat_map(|a| {
|
||||||
let b = sparsity_pattern(Just(a.major_dim()), Just(a.minor_dim()), 40);
|
let b = sparsity_pattern(Just(a.major_dim()), Just(a.minor_dim()), PROPTEST_MAX_NNZ);
|
||||||
(Just(a), b)
|
(Just(a), b)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -136,6 +135,50 @@ fn spmm_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPatt
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Bundle of arguments for a single `spmm_csr` call,
/// `C <- beta * C + alpha * op(A) * op(B)`, as produced by `spmm_csr_args_strategy`.
#[derive(Debug)]
struct SpmmCsrArgs<T> {
    /// Matrix `C` that the operation updates in place.
    c: CsrMatrix<T>,
    /// Scalar applied to the existing values of `c`.
    beta: T,
    /// Scalar applied to the product `op(A) * op(B)`.
    alpha: T,
    /// Whether `a` is to be transposed in the product.
    trans_a: Transpose,
    /// Left operand `A`.
    a: CsrMatrix<T>,
    /// Whether `b` is to be transposed in the product.
    trans_b: Transpose,
    /// Right operand `B`.
    b: CsrMatrix<T>
}
|
||||||
|
|
||||||
|
fn spmm_csr_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
||||||
|
spmm_pattern_strategy()
|
||||||
|
.prop_flat_map(|(a_pattern, b_pattern)| {
|
||||||
|
let a_values = vec![PROPTEST_I32_VALUE_STRATEGY; a_pattern.nnz()];
|
||||||
|
let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()];
|
||||||
|
let c_pattern = spmm_pattern(&a_pattern, &b_pattern);
|
||||||
|
let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()];
|
||||||
|
let a_pattern = Arc::new(a_pattern);
|
||||||
|
let b_pattern = Arc::new(b_pattern);
|
||||||
|
let c_pattern = Arc::new(c_pattern);
|
||||||
|
let a = a_values.prop_map(move |values|
|
||||||
|
CsrMatrix::try_from_pattern_and_values(Arc::clone(&a_pattern), values).unwrap());
|
||||||
|
let b = b_values.prop_map(move |values|
|
||||||
|
CsrMatrix::try_from_pattern_and_values(Arc::clone(&b_pattern), values).unwrap());
|
||||||
|
let c = c_values.prop_map(move |values|
|
||||||
|
CsrMatrix::try_from_pattern_and_values(Arc::clone(&c_pattern), values).unwrap());
|
||||||
|
let alpha = PROPTEST_I32_VALUE_STRATEGY;
|
||||||
|
let beta = PROPTEST_I32_VALUE_STRATEGY;
|
||||||
|
(c, beta, alpha, trans_strategy(), a, trans_strategy(), b)
|
||||||
|
})
|
||||||
|
.prop_map(|(c, beta, alpha, trans_a, a, trans_b, b)| {
|
||||||
|
SpmmCsrArgs::<i32> {
|
||||||
|
c,
|
||||||
|
beta,
|
||||||
|
alpha,
|
||||||
|
trans_a,
|
||||||
|
a: if trans_a.to_bool() { a.transpose() } else { a },
|
||||||
|
trans_b,
|
||||||
|
b: if trans_b.to_bool() { b.transpose() } else { b }
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Helper function to help us call dense GEMM with our transposition parameters
|
/// Helper function to help us call dense GEMM with our transposition parameters
|
||||||
fn dense_gemm<'a>(c: impl Into<DMatrixSliceMut<'a, i32>>,
|
fn dense_gemm<'a>(c: impl Into<DMatrixSliceMut<'a, i32>>,
|
||||||
beta: i32,
|
beta: i32,
|
||||||
|
@ -300,4 +343,59 @@ proptest! {
|
||||||
|
|
||||||
prop_assert_eq!(&c_pattern, c_csr.pattern().as_ref());
|
prop_assert_eq!(&c_pattern, c_csr.pattern().as_ref());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spmm_csr_test(SpmmCsrArgs { c, beta, alpha, trans_a, a, trans_b, b }
|
||||||
|
in spmm_csr_args_strategy()
|
||||||
|
) {
|
||||||
|
// Test that we get the expected result by comparing to an equivalent dense operation
|
||||||
|
// (here we give in the C matrix, so the sparsity pattern is essentially fixed)
|
||||||
|
let mut c_sparse = c.clone();
|
||||||
|
spmm_csr(&mut c_sparse, beta, alpha, trans_a, &a, trans_b, &b).unwrap();
|
||||||
|
|
||||||
|
let mut c_dense = DMatrix::from(&c);
|
||||||
|
let op_a_dense = DMatrix::from(&a);
|
||||||
|
let op_a_dense = if trans_a.to_bool() { op_a_dense.transpose() } else { op_a_dense };
|
||||||
|
let op_b_dense = DMatrix::from(&b);
|
||||||
|
let op_b_dense = if trans_b.to_bool() { op_b_dense.transpose() } else { op_b_dense };
|
||||||
|
c_dense = beta * c_dense + alpha * &op_a_dense * op_b_dense;
|
||||||
|
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_sparse), &c_dense);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spmm_csr_panics_on_dim_mismatch(
|
||||||
|
(alpha, beta, c, a, b, trans_a, trans_b)
|
||||||
|
in (PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
csr_strategy(),
|
||||||
|
csr_strategy(),
|
||||||
|
csr_strategy(),
|
||||||
|
trans_strategy(),
|
||||||
|
trans_strategy())
|
||||||
|
) {
|
||||||
|
// We refer to `A * B` as the "product"
|
||||||
|
let product_rows = if trans_a.to_bool() { a.ncols() } else { a.nrows() };
|
||||||
|
let product_cols = if trans_b.to_bool() { b.nrows() } else { b.ncols() };
|
||||||
|
// Determine the common dimension in the product
|
||||||
|
// from the perspective of a and b, respectively
|
||||||
|
let product_a_common = if trans_a.to_bool() { a.nrows() } else { a.ncols() };
|
||||||
|
let product_b_common = if trans_b.to_bool() { b.ncols() } else { b.nrows() };
|
||||||
|
|
||||||
|
let dims_are_compatible = product_rows == c.nrows()
|
||||||
|
&& product_cols == c.ncols()
|
||||||
|
&& product_a_common == product_b_common;
|
||||||
|
|
||||||
|
// If the dimensions randomly happen to be compatible, then of course we need to
|
||||||
|
// skip the test, so we assume that they are not.
|
||||||
|
prop_assume!(!dims_are_compatible);
|
||||||
|
|
||||||
|
let result = catch_unwind(|| {
|
||||||
|
let mut spmm_result = c.clone();
|
||||||
|
spmm_csr(&mut spmm_result, beta, alpha, trans_a, &a, trans_b, &b).unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
prop_assert!(result.is_err(),
|
||||||
|
"The SPMM kernel executed successfully despite mismatch dimensions");
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue