Implement spmm_csr
This commit is contained in:
parent
9db17f00e7
commit
2d534a6133
|
@ -156,3 +156,67 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn spmm_csr_unexpected_entry() -> OperationError {
|
||||
OperationError::from_type_and_message(
|
||||
OperationErrorType::InvalidPattern,
|
||||
String::from("Found unexpected entry that is not present in `c`."))
|
||||
}
|
||||
|
||||
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
|
||||
pub fn spmm_csr<'a, T>(
|
||||
c: &mut CsrMatrix<T>,
|
||||
beta: T,
|
||||
alpha: T,
|
||||
trans_a: Transpose,
|
||||
a: &CsrMatrix<T>,
|
||||
trans_b: Transpose,
|
||||
b: &CsrMatrix<T>)
|
||||
-> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
{
|
||||
assert_compatible_spmm_dims!(c, a, b, trans_a, trans_b);
|
||||
|
||||
if !trans_a.to_bool() && !trans_b.to_bool() {
|
||||
for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) {
|
||||
for c_ij in c_row_i.values_mut() {
|
||||
*c_ij = beta.inlined_clone() * c_ij.inlined_clone();
|
||||
}
|
||||
|
||||
for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
|
||||
let b_row_k = b.row(k);
|
||||
let (mut c_row_i_cols, mut c_row_i_values) = c_row_i.cols_and_values_mut();
|
||||
let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone();
|
||||
for (j, b_kj) in b_row_k.col_indices().iter().zip(b_row_k.values()) {
|
||||
// Determine the location in C to append the value
|
||||
let (c_local_idx, _) = c_row_i_cols.iter()
|
||||
.enumerate()
|
||||
.find(|(_, c_col)| *c_col == j)
|
||||
.ok_or_else(spmm_csr_unexpected_entry)?;
|
||||
|
||||
c_row_i_values[c_local_idx] += alpha_aik.inlined_clone() * b_kj.inlined_clone();
|
||||
c_row_i_cols = &c_row_i_cols[c_local_idx ..];
|
||||
c_row_i_values = &mut c_row_i_values[c_local_idx ..];
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
} else {
|
||||
// Currently we handle transposition by explicitly precomputing transposed matrices
|
||||
// and calling the operation again without transposition
|
||||
// TODO: At least use workspaces to allow control of allocations. Maybe
|
||||
// consider implementing certain patterns (like A^T * B) explicitly
|
||||
let (a, b) = {
|
||||
use Cow::*;
|
||||
match (trans_a, trans_b) {
|
||||
(Transpose(false), Transpose(false)) => unreachable!(),
|
||||
(Transpose(true), Transpose(false)) => (Owned(a.transpose()), Borrowed(b)),
|
||||
(Transpose(false), Transpose(true)) => (Borrowed(a), Owned(b.transpose())),
|
||||
(Transpose(true), Transpose(true)) => (Owned(a.transpose()), Owned(b.transpose()))
|
||||
}
|
||||
};
|
||||
|
||||
spmm_csr(c, beta, alpha, Transpose(false), a.as_ref(), Transpose(false), b.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -45,31 +45,41 @@ pub fn spadd_build_pattern(pattern: &mut SparsityPattern,
|
|||
|
||||
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
|
||||
pub fn spmm_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||
// TODO: Proper error message
|
||||
assert_eq!(a.minor_dim(), b.major_dim());
|
||||
assert_eq!(a.minor_dim(), b.major_dim(), "a and b must have compatible dimensions");
|
||||
|
||||
let mut offsets = Vec::new();
|
||||
let mut indices = Vec::new();
|
||||
offsets.push(0);
|
||||
|
||||
let mut c_lane_workspace = Vec::new();
|
||||
// Keep a vector of whether we have visited a particular minor index when working
|
||||
// on a major lane
|
||||
// TODO: Consider using a bitvec or similar here to reduce pressure on memory
|
||||
// (would cut memory use to 1/8, which might help reduce cache misses)
|
||||
let mut visited = vec![false; b.minor_dim()];
|
||||
|
||||
for i in 0 .. a.major_dim() {
|
||||
let a_lane_i = a.lane(i);
|
||||
let c_lane_i_offset = *offsets.last().unwrap();
|
||||
for &k in a_lane_i {
|
||||
// We have that the set of elements in lane i in C is given by the union of all
|
||||
// B_k, where B_k is the set of indices in lane k of B. More precisely, let C_i
|
||||
// denote the set of indices in lane i in C, and similarly for A_i and B_k. Then
|
||||
// C_i = union B_k for all k in A_i
|
||||
// We incrementally compute C_i by incrementally computing the union of C_i with
|
||||
// B_k until we're through all k in A_i.
|
||||
let b_lane_k = b.lane(k);
|
||||
let c_lane_i = &indices[c_lane_i_offset..];
|
||||
c_lane_workspace.clear();
|
||||
c_lane_workspace.extend(iterate_union(c_lane_i, b_lane_k));
|
||||
indices.truncate(c_lane_i_offset);
|
||||
indices.append(&mut c_lane_workspace);
|
||||
|
||||
for &j in b_lane_k {
|
||||
let have_visited_j = &mut visited[j];
|
||||
if !*have_visited_j {
|
||||
indices.push(j);
|
||||
*have_visited_j = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let c_lane_i = &mut indices[c_lane_i_offset ..];
|
||||
c_lane_i.sort_unstable();
|
||||
|
||||
// Reset visits so that visited[j] == false for all j for the next major lane
|
||||
for j in c_lane_i {
|
||||
visited[*j] = false;
|
||||
}
|
||||
|
||||
offsets.push(indices.len());
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ macro_rules! assert_panics {
|
|||
|
||||
pub const PROPTEST_MATRIX_DIM: RangeInclusive<usize> = 0..=6;
|
||||
pub const PROPTEST_MAX_NNZ: usize = 40;
|
||||
pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5 ..= 5;
|
||||
|
||||
pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
||||
csr(-5 ..= 5, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_build_pattern, spmm_pattern, spadd_csr};
|
||||
use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
|
||||
PROPTEST_I32_VALUE_STRATEGY};
|
||||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_build_pattern, spmm_pattern, spadd_csr, spmm_csr};
|
||||
use nalgebra_sparse::ops::{Transpose};
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
|
||||
|
@ -12,8 +14,6 @@ use proptest::prelude::*;
|
|||
use std::panic::catch_unwind;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ};
|
||||
|
||||
/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1
|
||||
fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
||||
let boolean_csr = CsrMatrix::try_from_pattern_and_values(
|
||||
|
@ -37,11 +37,11 @@ struct SpmmCsrDenseArgs<T: Scalar> {
|
|||
/// Returns matrices C, A and B with compatible dimensions such that it can be used
|
||||
/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`.
|
||||
fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>> {
|
||||
let max_nnz = 40;
|
||||
let value_strategy = -5 ..= 5;
|
||||
let c_rows = 0 ..= 6usize;
|
||||
let c_cols = 0 ..= 6usize;
|
||||
let common_dim = 0 ..= 6usize;
|
||||
let max_nnz = PROPTEST_MAX_NNZ;
|
||||
let value_strategy = PROPTEST_I32_VALUE_STRATEGY;
|
||||
let c_rows = PROPTEST_MATRIX_DIM;
|
||||
let c_cols = PROPTEST_MATRIX_DIM;
|
||||
let common_dim = PROPTEST_MATRIX_DIM;
|
||||
let trans_strategy = trans_strategy();
|
||||
let c_matrix_strategy = matrix(value_strategy.clone(), c_rows, c_cols);
|
||||
|
||||
|
@ -84,9 +84,8 @@ struct SpaddCsrArgs<T> {
|
|||
}
|
||||
|
||||
fn spadd_csr_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
||||
let value_strategy = -5 ..= 5;
|
||||
let value_strategy = PROPTEST_I32_VALUE_STRATEGY;
|
||||
|
||||
// TODO :Support transposition
|
||||
spadd_build_pattern_strategy()
|
||||
.prop_flat_map(move |(a_pattern, b_pattern)| {
|
||||
let mut c_pattern = SparsityPattern::new(a_pattern.major_dim(), b_pattern.major_dim());
|
||||
|
@ -107,7 +106,7 @@ fn spadd_csr_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
|||
}
|
||||
|
||||
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
|
||||
matrix(-5 ..= 5, 0 ..= 6, 0 ..= 6)
|
||||
matrix(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM)
|
||||
}
|
||||
|
||||
fn trans_strategy() -> impl Strategy<Value=Transpose> + Clone {
|
||||
|
@ -115,14 +114,14 @@ fn trans_strategy() -> impl Strategy<Value=Transpose> + Clone {
|
|||
}
|
||||
|
||||
fn pattern_strategy() -> impl Strategy<Value=SparsityPattern> {
|
||||
sparsity_pattern(0 ..= 6usize, 0..= 6usize, 40)
|
||||
sparsity_pattern(PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
||||
}
|
||||
|
||||
/// Constructs pairs (a, b) where a and b have the same dimensions
|
||||
fn spadd_build_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
|
||||
pattern_strategy()
|
||||
.prop_flat_map(|a| {
|
||||
let b = sparsity_pattern(Just(a.major_dim()), Just(a.minor_dim()), 40);
|
||||
let b = sparsity_pattern(Just(a.major_dim()), Just(a.minor_dim()), PROPTEST_MAX_NNZ);
|
||||
(Just(a), b)
|
||||
})
|
||||
}
|
||||
|
@ -136,6 +135,50 @@ fn spmm_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPatt
|
|||
})
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SpmmCsrArgs<T> {
|
||||
c: CsrMatrix<T>,
|
||||
beta: T,
|
||||
alpha: T,
|
||||
trans_a: Transpose,
|
||||
a: CsrMatrix<T>,
|
||||
trans_b: Transpose,
|
||||
b: CsrMatrix<T>
|
||||
}
|
||||
|
||||
fn spmm_csr_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
||||
spmm_pattern_strategy()
|
||||
.prop_flat_map(|(a_pattern, b_pattern)| {
|
||||
let a_values = vec![PROPTEST_I32_VALUE_STRATEGY; a_pattern.nnz()];
|
||||
let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()];
|
||||
let c_pattern = spmm_pattern(&a_pattern, &b_pattern);
|
||||
let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()];
|
||||
let a_pattern = Arc::new(a_pattern);
|
||||
let b_pattern = Arc::new(b_pattern);
|
||||
let c_pattern = Arc::new(c_pattern);
|
||||
let a = a_values.prop_map(move |values|
|
||||
CsrMatrix::try_from_pattern_and_values(Arc::clone(&a_pattern), values).unwrap());
|
||||
let b = b_values.prop_map(move |values|
|
||||
CsrMatrix::try_from_pattern_and_values(Arc::clone(&b_pattern), values).unwrap());
|
||||
let c = c_values.prop_map(move |values|
|
||||
CsrMatrix::try_from_pattern_and_values(Arc::clone(&c_pattern), values).unwrap());
|
||||
let alpha = PROPTEST_I32_VALUE_STRATEGY;
|
||||
let beta = PROPTEST_I32_VALUE_STRATEGY;
|
||||
(c, beta, alpha, trans_strategy(), a, trans_strategy(), b)
|
||||
})
|
||||
.prop_map(|(c, beta, alpha, trans_a, a, trans_b, b)| {
|
||||
SpmmCsrArgs::<i32> {
|
||||
c,
|
||||
beta,
|
||||
alpha,
|
||||
trans_a,
|
||||
a: if trans_a.to_bool() { a.transpose() } else { a },
|
||||
trans_b,
|
||||
b: if trans_b.to_bool() { b.transpose() } else { b }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper function to help us call dense GEMM with our transposition parameters
|
||||
fn dense_gemm<'a>(c: impl Into<DMatrixSliceMut<'a, i32>>,
|
||||
beta: i32,
|
||||
|
@ -300,4 +343,59 @@ proptest! {
|
|||
|
||||
prop_assert_eq!(&c_pattern, c_csr.pattern().as_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spmm_csr_test(SpmmCsrArgs { c, beta, alpha, trans_a, a, trans_b, b }
|
||||
in spmm_csr_args_strategy()
|
||||
) {
|
||||
// Test that we get the expected result by comparing to an equivalent dense operation
|
||||
// (here we give in the C matrix, so the sparsity pattern is essentially fixed)
|
||||
let mut c_sparse = c.clone();
|
||||
spmm_csr(&mut c_sparse, beta, alpha, trans_a, &a, trans_b, &b).unwrap();
|
||||
|
||||
let mut c_dense = DMatrix::from(&c);
|
||||
let op_a_dense = DMatrix::from(&a);
|
||||
let op_a_dense = if trans_a.to_bool() { op_a_dense.transpose() } else { op_a_dense };
|
||||
let op_b_dense = DMatrix::from(&b);
|
||||
let op_b_dense = if trans_b.to_bool() { op_b_dense.transpose() } else { op_b_dense };
|
||||
c_dense = beta * c_dense + alpha * &op_a_dense * op_b_dense;
|
||||
|
||||
prop_assert_eq!(&DMatrix::from(&c_sparse), &c_dense);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spmm_csr_panics_on_dim_mismatch(
|
||||
(alpha, beta, c, a, b, trans_a, trans_b)
|
||||
in (PROPTEST_I32_VALUE_STRATEGY,
|
||||
PROPTEST_I32_VALUE_STRATEGY,
|
||||
csr_strategy(),
|
||||
csr_strategy(),
|
||||
csr_strategy(),
|
||||
trans_strategy(),
|
||||
trans_strategy())
|
||||
) {
|
||||
// We refer to `A * B` as the "product"
|
||||
let product_rows = if trans_a.to_bool() { a.ncols() } else { a.nrows() };
|
||||
let product_cols = if trans_b.to_bool() { b.nrows() } else { b.ncols() };
|
||||
// Determine the common dimension in the product
|
||||
// from the perspective of a and b, respectively
|
||||
let product_a_common = if trans_a.to_bool() { a.nrows() } else { a.ncols() };
|
||||
let product_b_common = if trans_b.to_bool() { b.ncols() } else { b.nrows() };
|
||||
|
||||
let dims_are_compatible = product_rows == c.nrows()
|
||||
&& product_cols == c.ncols()
|
||||
&& product_a_common == product_b_common;
|
||||
|
||||
// If the dimensions randomly happen to be compatible, then of course we need to
|
||||
// skip the test, so we assume that they are not.
|
||||
prop_assume!(!dims_are_compatible);
|
||||
|
||||
let result = catch_unwind(|| {
|
||||
let mut spmm_result = c.clone();
|
||||
spmm_csr(&mut spmm_result, beta, alpha, trans_a, &a, trans_b, &b).unwrap();
|
||||
});
|
||||
|
||||
prop_assert!(result.is_err(),
|
||||
"The SPMM kernel executed successfully despite mismatch dimensions");
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue