forked from M-Labs/nalgebra
Implement spmm_pattern
This commit is contained in:
parent
c4285d9fb3
commit
9db17f00e7
@ -1,10 +1,11 @@
|
||||
use crate::csr::CsrMatrix;
|
||||
use crate::ops::{Transpose};
|
||||
use crate::SparseEntryMut;
|
||||
use crate::ops::serial::{OperationError, OperationErrorType};
|
||||
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
|
||||
use num_traits::{Zero, One};
|
||||
use crate::ops::serial::{OperationError, OperationErrorType};
|
||||
use std::sync::Arc;
|
||||
use crate::SparseEntryMut;
|
||||
use std::borrow::Cow;
|
||||
|
||||
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * trans(A) * trans(B)`.
|
||||
pub fn spmm_csr_dense<'a, T>(c: impl Into<DMatrixSliceMut<'a, T>>,
|
||||
|
@ -32,7 +32,7 @@ pub fn spadd_build_pattern(pattern: &mut SparsityPattern,
|
||||
for lane_idx in 0 .. a.major_dim() {
|
||||
let lane_a = a.lane(lane_idx);
|
||||
let lane_b = b.lane(lane_idx);
|
||||
indices.extend(iterate_intersection(lane_a, lane_b));
|
||||
indices.extend(iterate_union(lane_a, lane_b));
|
||||
offsets.push(indices.len());
|
||||
}
|
||||
|
||||
@ -43,11 +43,44 @@ pub fn spadd_build_pattern(pattern: &mut SparsityPattern,
|
||||
swap(input_pattern, &mut new_pattern);
|
||||
}
|
||||
|
||||
/// Iterate over the intersection of the two sets represented by sorted slices
|
||||
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
|
||||
pub fn spmm_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||
// TODO: Proper error message
|
||||
assert_eq!(a.minor_dim(), b.major_dim());
|
||||
|
||||
let mut offsets = Vec::new();
|
||||
let mut indices = Vec::new();
|
||||
offsets.push(0);
|
||||
|
||||
let mut c_lane_workspace = Vec::new();
|
||||
for i in 0 .. a.major_dim() {
|
||||
let a_lane_i = a.lane(i);
|
||||
let c_lane_i_offset = *offsets.last().unwrap();
|
||||
for &k in a_lane_i {
|
||||
// We have that the set of elements in lane i in C is given by the union of all
|
||||
// B_k, where B_k is the set of indices in lane k of B. More precisely, let C_i
|
||||
// denote the set of indices in lane i in C, and similarly for A_i and B_k. Then
|
||||
// C_i = union B_k for all k in A_i
|
||||
// We incrementally compute C_i by incrementally computing the union of C_i with
|
||||
// B_k until we're through all k in A_i.
|
||||
let b_lane_k = b.lane(k);
|
||||
let c_lane_i = &indices[c_lane_i_offset..];
|
||||
c_lane_workspace.clear();
|
||||
c_lane_workspace.extend(iterate_union(c_lane_i, b_lane_k));
|
||||
indices.truncate(c_lane_i_offset);
|
||||
indices.append(&mut c_lane_workspace);
|
||||
}
|
||||
offsets.push(indices.len());
|
||||
}
|
||||
|
||||
SparsityPattern::try_from_offsets_and_indices(a.major_dim(), b.minor_dim(), offsets, indices)
|
||||
.expect("Internal error: Invalid pattern during matrix multiplication pattern construction")
|
||||
}
|
||||
|
||||
/// Iterate over the union of the two sets represented by sorted slices
|
||||
/// (with unique elements)
|
||||
fn iterate_intersection<'a>(mut sorted_a: &'a [usize],
|
||||
fn iterate_union<'a>(mut sorted_a: &'a [usize],
|
||||
mut sorted_b: &'a [usize]) -> impl Iterator<Item=usize> + 'a {
|
||||
// TODO: Can use a kind of simultaneous exponential search to speed things up here
|
||||
iter::from_fn(move || {
|
||||
if let (Some(a_item), Some(b_item)) = (sorted_a.first(), sorted_b.first()) {
|
||||
let item = if a_item < b_item {
|
||||
|
@ -2,6 +2,7 @@ use proptest::strategy::Strategy;
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::proptest::{csr, csc};
|
||||
use nalgebra_sparse::csc::CscMatrix;
|
||||
use std::ops::RangeInclusive;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! assert_panics {
|
||||
@ -24,10 +25,13 @@ macro_rules! assert_panics {
|
||||
}};
|
||||
}
|
||||
|
||||
pub const PROPTEST_MATRIX_DIM: RangeInclusive<usize> = 0..=6;
|
||||
pub const PROPTEST_MAX_NNZ: usize = 40;
|
||||
|
||||
pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
||||
csr(-5 ..= 5, 0 ..= 6usize, 0 ..= 6usize, 40)
|
||||
csr(-5 ..= 5, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
||||
}
|
||||
|
||||
pub fn csc_strategy() -> impl Strategy<Value=CscMatrix<i32>> {
|
||||
csc(-5 ..= 5, 0..=6usize, 0..=6usize, 40)
|
||||
csc(-5 ..= 5, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_build_pattern, spadd_csr};
|
||||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_build_pattern, spmm_pattern, spadd_csr};
|
||||
use nalgebra_sparse::ops::{Transpose};
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
|
||||
@ -12,7 +12,7 @@ use proptest::prelude::*;
|
||||
use std::panic::catch_unwind;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::common::csr_strategy;
|
||||
use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ};
|
||||
|
||||
/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1
|
||||
fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
||||
@ -127,6 +127,15 @@ fn spadd_build_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, Spars
|
||||
})
|
||||
}
|
||||
|
||||
/// Constructs pairs (a, b) where a and b have compatible dimensions for a matrix product
|
||||
fn spmm_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
|
||||
pattern_strategy()
|
||||
.prop_flat_map(|a| {
|
||||
let b = sparsity_pattern(Just(a.minor_dim()), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
|
||||
(Just(a), b)
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper function to help us call dense GEMM with our transposition parameters
|
||||
fn dense_gemm<'a>(c: impl Into<DMatrixSliceMut<'a, i32>>,
|
||||
beta: i32,
|
||||
@ -269,4 +278,26 @@ proptest! {
|
||||
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
|
||||
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spmm_pattern_test((a, b) in spmm_pattern_strategy())
|
||||
{
|
||||
// (a, b) are multiplication-wise dimensionally compatible patterns
|
||||
let c_pattern = spmm_pattern(&a, &b);
|
||||
|
||||
// To verify the pattern, we construct CSR matrices with positive integer entries
|
||||
// corresponding to a and b, and convert them to dense matrices.
|
||||
// The product of these dense matrices will then have non-zeros in exactly the same locations
|
||||
// as the result of "multiplying" the sparsity patterns
|
||||
let a_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(a.clone()), vec![1; a.nnz()])
|
||||
.unwrap();
|
||||
let a_dense = DMatrix::from(&a_csr);
|
||||
let b_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(b.clone()), vec![1; b.nnz()])
|
||||
.unwrap();
|
||||
let b_dense = DMatrix::from(&b_csr);
|
||||
let c_dense = a_dense * b_dense;
|
||||
let c_csr = CsrMatrix::from(&c_dense);
|
||||
|
||||
prop_assert_eq!(&c_pattern, c_csr.pattern().as_ref());
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user