From 9db17f00e76ad901cc89bcfc01f31f57a5dc1845 Mon Sep 17 00:00:00 2001 From: Andreas Longva Date: Mon, 14 Dec 2020 16:55:06 +0100 Subject: [PATCH] Implement spmm_pattern --- nalgebra-sparse/src/ops/serial/csr.rs | 5 +-- nalgebra-sparse/src/ops/serial/pattern.rs | 43 ++++++++++++++++++++--- nalgebra-sparse/tests/common/mod.rs | 8 +++-- nalgebra-sparse/tests/unit_tests/ops.rs | 35 ++++++++++++++++-- 4 files changed, 80 insertions(+), 11 deletions(-) diff --git a/nalgebra-sparse/src/ops/serial/csr.rs b/nalgebra-sparse/src/ops/serial/csr.rs index 42ef6121..16ef3b04 100644 --- a/nalgebra-sparse/src/ops/serial/csr.rs +++ b/nalgebra-sparse/src/ops/serial/csr.rs @@ -1,10 +1,11 @@ use crate::csr::CsrMatrix; use crate::ops::{Transpose}; +use crate::SparseEntryMut; +use crate::ops::serial::{OperationError, OperationErrorType}; use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut}; use num_traits::{Zero, One}; -use crate::ops::serial::{OperationError, OperationErrorType}; use std::sync::Arc; -use crate::SparseEntryMut; +use std::borrow::Cow; /// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * trans(A) * trans(B)`. pub fn spmm_csr_dense<'a, T>(c: impl Into>, diff --git a/nalgebra-sparse/src/ops/serial/pattern.rs b/nalgebra-sparse/src/ops/serial/pattern.rs index 6507ad1b..c8d4a586 100644 --- a/nalgebra-sparse/src/ops/serial/pattern.rs +++ b/nalgebra-sparse/src/ops/serial/pattern.rs @@ -32,7 +32,7 @@ pub fn spadd_build_pattern(pattern: &mut SparsityPattern, for lane_idx in 0 .. a.major_dim() { let lane_a = a.lane(lane_idx); let lane_b = b.lane(lane_idx); - indices.extend(iterate_intersection(lane_a, lane_b)); + indices.extend(iterate_union(lane_a, lane_b)); offsets.push(indices.len()); } @@ -43,11 +43,44 @@ pub fn spadd_build_pattern(pattern: &mut SparsityPattern, swap(input_pattern, &mut new_pattern); } -/// Iterate over the intersection of the two sets represented by sorted slices +/// Sparse matrix multiplication pattern construction, `C <- A * B`. +pub fn spmm_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern { + // TODO: Proper error message + assert_eq!(a.minor_dim(), b.major_dim()); + + let mut offsets = Vec::new(); + let mut indices = Vec::new(); + offsets.push(0); + + let mut c_lane_workspace = Vec::new(); + for i in 0 .. a.major_dim() { + let a_lane_i = a.lane(i); + let c_lane_i_offset = *offsets.last().unwrap(); + for &k in a_lane_i { + // We have that the set of elements in lane i in C is given by the union of all + // B_k, where B_k is the set of indices in lane k of B. More precisely, let C_i + // denote the set of indices in lane i in C, and similarly for A_i and B_k. Then + // C_i = union B_k for all k in A_i + // We incrementally compute C_i by incrementally computing the union of C_i with + // B_k until we're through all k in A_i. + let b_lane_k = b.lane(k); + let c_lane_i = &indices[c_lane_i_offset..]; + c_lane_workspace.clear(); + c_lane_workspace.extend(iterate_union(c_lane_i, b_lane_k)); + indices.truncate(c_lane_i_offset); + indices.append(&mut c_lane_workspace); + } + offsets.push(indices.len()); + } + + SparsityPattern::try_from_offsets_and_indices(a.major_dim(), b.minor_dim(), offsets, indices) + .expect("Internal error: Invalid pattern during matrix multiplication pattern construction") +} + +/// Iterate over the union of the two sets represented by sorted slices /// (with unique elements) -fn iterate_intersection<'a>(mut sorted_a: &'a [usize], - mut sorted_b: &'a [usize]) -> impl Iterator + 'a { - // TODO: Can use a kind of simultaneous exponential search to speed things up here +fn iterate_union<'a>(mut sorted_a: &'a [usize], + mut sorted_b: &'a [usize]) -> impl Iterator + 'a { iter::from_fn(move || { if let (Some(a_item), Some(b_item)) = (sorted_a.first(), sorted_b.first()) { let item = if a_item < b_item { diff --git a/nalgebra-sparse/tests/common/mod.rs b/nalgebra-sparse/tests/common/mod.rs index 6e730b7d..21751e50 100644 --- a/nalgebra-sparse/tests/common/mod.rs +++ b/nalgebra-sparse/tests/common/mod.rs @@ -2,6 +2,7 @@ use proptest::strategy::Strategy; use nalgebra_sparse::csr::CsrMatrix; use nalgebra_sparse::proptest::{csr, csc}; use nalgebra_sparse::csc::CscMatrix; +use std::ops::RangeInclusive; #[macro_export] macro_rules! assert_panics { @@ -24,10 +25,13 @@ macro_rules! assert_panics { }}; } +pub const PROPTEST_MATRIX_DIM: RangeInclusive = 0..=6; +pub const PROPTEST_MAX_NNZ: usize = 40; + pub fn csr_strategy() -> impl Strategy> { - csr(-5 ..= 5, 0 ..= 6usize, 0 ..= 6usize, 40) + csr(-5 ..= 5, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ) } pub fn csc_strategy() -> impl Strategy> { - csc(-5 ..= 5, 0..=6usize, 0..=6usize, 40) + csc(-5 ..= 5, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ) } diff --git a/nalgebra-sparse/tests/unit_tests/ops.rs b/nalgebra-sparse/tests/unit_tests/ops.rs index a7f82b9a..2a46ee32 100644 --- a/nalgebra-sparse/tests/unit_tests/ops.rs +++ b/nalgebra-sparse/tests/unit_tests/ops.rs @@ -1,4 +1,4 @@ -use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_build_pattern, spadd_csr}; +use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_build_pattern, spmm_pattern, spadd_csr}; use nalgebra_sparse::ops::{Transpose}; use nalgebra_sparse::csr::CsrMatrix; use nalgebra_sparse::proptest::{csr, sparsity_pattern}; @@ -12,7 +12,7 @@ use proptest::prelude::*; use std::panic::catch_unwind; use std::sync::Arc; -use crate::common::csr_strategy; +use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ}; /// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1 fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix { @@ -127,6 +127,15 @@ fn spadd_build_pattern_strategy() -> impl Strategy impl Strategy { + pattern_strategy() + .prop_flat_map(|a| { + let b = sparsity_pattern(Just(a.minor_dim()), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ); + (Just(a), b) + }) +} + /// Helper function to help us call dense GEMM with our transposition parameters fn dense_gemm<'a>(c: impl Into>, beta: i32, @@ -269,4 +278,26 @@ proptest! { prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense); prop_assert_eq!(c_ref_ref.pattern(), &c_pattern); } + + #[test] + fn spmm_pattern_test((a, b) in spmm_pattern_strategy()) + { + // (a, b) are multiplication-wise dimensionally compatible patterns + let c_pattern = spmm_pattern(&a, &b); + + // To verify the pattern, we construct CSR matrices with positive integer entries + // corresponding to a and b, and convert them to dense matrices. + // The product of these dense matrices will then have non-zeros in exactly the same locations + // as the result of "multiplying" the sparsity patterns + let a_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(a.clone()), vec![1; a.nnz()]) + .unwrap(); + let a_dense = DMatrix::from(&a_csr); + let b_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(b.clone()), vec![1; b.nnz()]) + .unwrap(); + let b_dense = DMatrix::from(&b_csr); + let c_dense = a_dense * b_dense; + let c_csr = CsrMatrix::from(&c_dense); + + prop_assert_eq!(&c_pattern, c_csr.pattern().as_ref()); + } } \ No newline at end of file