Implement spmm_pattern

This commit is contained in:
Andreas Longva 2020-12-14 16:55:06 +01:00
parent c4285d9fb3
commit 9db17f00e7
4 changed files with 80 additions and 11 deletions

View File

@ -1,10 +1,11 @@
use crate::csr::CsrMatrix;
use crate::ops::{Transpose};
use crate::SparseEntryMut;
use crate::ops::serial::{OperationError, OperationErrorType};
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
use num_traits::{Zero, One};
use crate::ops::serial::{OperationError, OperationErrorType};
use std::sync::Arc;
use crate::SparseEntryMut;
use std::borrow::Cow;
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * trans(A) * trans(B)`.
pub fn spmm_csr_dense<'a, T>(c: impl Into<DMatrixSliceMut<'a, T>>,

View File

@ -32,7 +32,7 @@ pub fn spadd_build_pattern(pattern: &mut SparsityPattern,
for lane_idx in 0 .. a.major_dim() {
let lane_a = a.lane(lane_idx);
let lane_b = b.lane(lane_idx);
indices.extend(iterate_intersection(lane_a, lane_b));
indices.extend(iterate_union(lane_a, lane_b));
offsets.push(indices.len());
}
@ -43,11 +43,44 @@ pub fn spadd_build_pattern(pattern: &mut SparsityPattern,
swap(input_pattern, &mut new_pattern);
}
/// Iterate over the intersection of the two sets represented by sorted slices
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
///
/// Builds the sparsity pattern of the product of two matrices whose patterns are `a` and `b`.
/// Lane `i` of the result is the union of the lanes `k` of `b`, taken over every index `k`
/// stored in lane `i` of `a`. The result has `a.major_dim()` lanes and minor dimension
/// `b.minor_dim()`.
///
/// # Panics
///
/// Panics if the minor dimension of `a` differs from the major dimension of `b`.
pub fn spmm_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
    assert_eq!(
        a.minor_dim(),
        b.major_dim(),
        "a and b must have compatible dimensions for multiplication \
         (minor dimension of a must equal major dimension of b)"
    );

    // One offset per lane of the result, plus the leading zero.
    let mut offsets = Vec::with_capacity(a.major_dim() + 1);
    let mut indices = Vec::new();
    offsets.push(0);

    // Scratch buffer reused across all lanes so that each union does not allocate anew.
    let mut c_lane_workspace = Vec::new();
    for i in 0..a.major_dim() {
        let a_lane_i = a.lane(i);
        // Lane i of C starts where the previously finished lanes end.
        let c_lane_i_offset = indices.len();
        for &k in a_lane_i {
            // We have that the set of elements in lane i in C is given by the union of all
            // B_k, where B_k is the set of indices in lane k of B. More precisely, let C_i
            // denote the set of indices in lane i in C, and similarly for A_i and B_k. Then
            //  C_i = union B_k for all k in A_i
            // We incrementally compute C_i by incrementally computing the union of C_i with
            // B_k until we're through all k in A_i.
            let b_lane_k = b.lane(k);
            let c_lane_i = &indices[c_lane_i_offset..];
            c_lane_workspace.clear();
            c_lane_workspace.extend(iterate_union(c_lane_i, b_lane_k));
            // Replace the partial lane with the freshly computed union.
            indices.truncate(c_lane_i_offset);
            indices.append(&mut c_lane_workspace);
        }
        offsets.push(indices.len());
    }

    SparsityPattern::try_from_offsets_and_indices(a.major_dim(), b.minor_dim(), offsets, indices)
        .expect("Internal error: Invalid pattern during matrix multiplication pattern construction")
}
/// Iterate over the union of the two sets represented by sorted slices
/// (with unique elements)
fn iterate_intersection<'a>(mut sorted_a: &'a [usize],
mut sorted_b: &'a [usize]) -> impl Iterator<Item=usize> + 'a {
// TODO: Can use a kind of simultaneous exponential search to speed things up here
fn iterate_union<'a>(mut sorted_a: &'a [usize],
mut sorted_b: &'a [usize]) -> impl Iterator<Item=usize> + 'a {
iter::from_fn(move || {
if let (Some(a_item), Some(b_item)) = (sorted_a.first(), sorted_b.first()) {
let item = if a_item < b_item {

View File

@ -2,6 +2,7 @@ use proptest::strategy::Strategy;
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::proptest::{csr, csc};
use nalgebra_sparse::csc::CscMatrix;
use std::ops::RangeInclusive;
#[macro_export]
macro_rules! assert_panics {
@ -24,10 +25,13 @@ macro_rules! assert_panics {
}};
}
/// Range handed to the `csr`/`csc` generators for both of their dimension arguments.
pub const PROPTEST_MATRIX_DIM: RangeInclusive<usize> = 0..=6;
/// Value handed to the `csr`/`csc` generators as their last argument.
// NOTE(review): presumably an upper bound on the number of non-zeros — confirm against the
// `nalgebra_sparse::proptest` generator signatures.
pub const PROPTEST_MAX_NNZ: usize = 40;
/// Proptest strategy yielding `CsrMatrix<i32>` values, built by calling `csr` with the value
/// range `-5 ..= 5`, the same dimension range for both dimension arguments, and a final
/// argument of 40 (`PROPTEST_MAX_NNZ`).
// NOTE(review): this span is taken from a diff view; the call with literal arguments is
// presumably the removed line and the call using the shared constants its replacement — confirm.
pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
csr(-5 ..= 5, 0 ..= 6usize, 0 ..= 6usize, 40)
csr(-5 ..= 5, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
}
/// Proptest strategy yielding `CscMatrix<i32>` values, built by calling `csc` with the value
/// range `-5 ..= 5`, the same dimension range for both dimension arguments, and a final
/// argument of 40 (`PROPTEST_MAX_NNZ`).
// NOTE(review): this span is taken from a diff view; the call with literal arguments is
// presumably the removed line and the call using the shared constants its replacement — confirm.
pub fn csc_strategy() -> impl Strategy<Value=CscMatrix<i32>> {
csc(-5 ..= 5, 0..=6usize, 0..=6usize, 40)
csc(-5 ..= 5, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
}

View File

@ -1,4 +1,4 @@
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_build_pattern, spadd_csr};
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_build_pattern, spmm_pattern, spadd_csr};
use nalgebra_sparse::ops::{Transpose};
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
@ -12,7 +12,7 @@ use proptest::prelude::*;
use std::panic::catch_unwind;
use std::sync::Arc;
use crate::common::csr_strategy;
use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ};
/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1 entries
fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
@ -127,6 +127,15 @@ fn spadd_build_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, Spars
})
}
/// Generates pattern pairs `(a, b)` that can be multiplied as `a * b`: the major dimension
/// of `b` is pinned to the minor dimension of `a`.
fn spmm_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
    pattern_strategy().prop_flat_map(|lhs| {
        // The inner dimension is fixed by the left operand; everything else about the right
        // operand is drawn from the shared proptest limits.
        let inner_dim = lhs.minor_dim();
        let rhs = sparsity_pattern(Just(inner_dim), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
        (Just(lhs), rhs)
    })
}
/// Helper function to help us call dense GEMM with our transposition parameters
fn dense_gemm<'a>(c: impl Into<DMatrixSliceMut<'a, i32>>,
beta: i32,
@ -269,4 +278,26 @@ proptest! {
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
}
#[test]
fn spmm_pattern_test((a, b) in spmm_pattern_strategy())
{
    // The strategy only produces pairs whose dimensions are compatible for `a * b`.
    let c_pattern = spmm_pattern(&a, &b);

    // Reference computation: fill each pattern with ones, convert to a dense matrix and
    // multiply. All entries are positive, so no cancellation can occur and the dense product
    // is non-zero in exactly the positions the pattern product must contain.
    let dense_ones = |pattern: &SparsityPattern| -> DMatrix<i32> {
        let ones = vec![1; pattern.nnz()];
        let csr = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern.clone()), ones)
            .unwrap();
        DMatrix::from(&csr)
    };

    let c_dense = dense_ones(&a) * dense_ones(&b);
    let c_csr = CsrMatrix::from(&c_dense);

    prop_assert_eq!(&c_pattern, c_csr.pattern().as_ref());
}
}