diff --git a/nalgebra-sparse/src/ops/impl_std_ops.rs b/nalgebra-sparse/src/ops/impl_std_ops.rs new file mode 100644 index 00000000..e69de29b diff --git a/nalgebra-sparse/src/ops/serial/csr.rs b/nalgebra-sparse/src/ops/serial/csr.rs index 3b28ac74..b77d1112 100644 --- a/nalgebra-sparse/src/ops/serial/csr.rs +++ b/nalgebra-sparse/src/ops/serial/csr.rs @@ -3,7 +3,7 @@ use crate::ops::{Transpose}; use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut}; use num_traits::{Zero, One}; -/// Sparse-dense matrix-matrix multiplication `C = beta * C + alpha * trans(A) * trans(B)`. +/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * trans(A) * trans(B)`. pub fn spmm_csr_dense<'a, T>(c: impl Into>, beta: T, alpha: T, diff --git a/nalgebra-sparse/src/ops/serial/mod.rs b/nalgebra-sparse/src/ops/serial/mod.rs index a7615ec4..bb40419f 100644 --- a/nalgebra-sparse/src/ops/serial/mod.rs +++ b/nalgebra-sparse/src/ops/serial/mod.rs @@ -32,6 +32,8 @@ macro_rules! assert_compatible_spmm_dims { mod coo; mod csr; +mod pattern; pub use coo::*; -pub use csr::*; \ No newline at end of file +pub use csr::*; +pub use pattern::*; \ No newline at end of file diff --git a/nalgebra-sparse/src/ops/serial/pattern.rs b/nalgebra-sparse/src/ops/serial/pattern.rs new file mode 100644 index 00000000..6507ad1b --- /dev/null +++ b/nalgebra-sparse/src/ops/serial/pattern.rs @@ -0,0 +1,77 @@ +use crate::pattern::SparsityPattern; + +use std::mem::swap; +use std::iter; + +/// Sparse matrix addition pattern construction, `C <- A + B`. +/// +/// Builds the pattern for `C`, which is able to hold the result of the sum `A + B`. +/// The patterns are assumed to have the same major and minor dimensions. In other words, +/// both patterns `A` and `B` must both stem from the same kind of compressed matrix: +/// CSR or CSC. +/// TODO: Explain that output pattern is only used to avoid allocations +pub fn spadd_build_pattern(pattern: &mut SparsityPattern, + a: &SparsityPattern, + b: &SparsityPattern) +{ + // TODO: Proper error messages + assert_eq!(a.major_dim(), b.major_dim()); + assert_eq!(a.minor_dim(), b.minor_dim()); + + let input_pattern = pattern; + let mut temp_pattern = SparsityPattern::new(a.major_dim(), b.minor_dim()); + swap(input_pattern, &mut temp_pattern); + let (mut offsets, mut indices) = temp_pattern.disassemble(); + + offsets.clear(); + offsets.reserve(a.major_dim() + 1); + indices.clear(); + + offsets.push(0); + + for lane_idx in 0 .. a.major_dim() { + let lane_a = a.lane(lane_idx); + let lane_b = b.lane(lane_idx); + indices.extend(iterate_intersection(lane_a, lane_b)); + offsets.push(indices.len()); + } + + // TODO: Consider circumventing format checks? (requires unsafe, should benchmark first) + let mut new_pattern = SparsityPattern::try_from_offsets_and_indices( + a.major_dim(), a.minor_dim(), offsets, indices) + .expect("Pattern must be valid by definition"); + swap(input_pattern, &mut new_pattern); +} + +/// Iterate over the intersection of the two sets represented by sorted slices +/// (with unique elements) +fn iterate_intersection<'a>(mut sorted_a: &'a [usize], + mut sorted_b: &'a [usize]) -> impl Iterator + 'a { + // TODO: Can use a kind of simultaneous exponential search to speed things up here + iter::from_fn(move || { + if let (Some(a_item), Some(b_item)) = (sorted_a.first(), sorted_b.first()) { + let item = if a_item < b_item { + sorted_a = &sorted_a[1 ..]; + a_item + } else if b_item < a_item { + sorted_b = &sorted_b[1 ..]; + b_item + } else { + // Both lists contain the same element, advance both slices to avoid + // duplicate entries in the result + sorted_a = &sorted_a[1 ..]; + sorted_b = &sorted_b[1 ..]; + a_item + }; + Some(*item) + } else if let Some(a_item) = sorted_a.first() { + sorted_a = &sorted_a[1..]; + Some(*a_item) + } else if let Some(b_item) = sorted_b.first() { + sorted_b = &sorted_b[1..]; + Some(*b_item) + } else { + None + } + }) +} \ No newline at end of file diff --git a/nalgebra-sparse/tests/unit_tests/ops.rs b/nalgebra-sparse/tests/unit_tests/ops.rs index add03a98..4fb9c232 100644 --- a/nalgebra-sparse/tests/unit_tests/ops.rs +++ b/nalgebra-sparse/tests/unit_tests/ops.rs @@ -1,8 +1,9 @@ use nalgebra_sparse::coo::CooMatrix; -use nalgebra_sparse::ops::serial::{spmv_coo, spmm_csr_dense}; +use nalgebra_sparse::ops::serial::{spmv_coo, spmm_csr_dense, spadd_build_pattern}; use nalgebra_sparse::ops::{Transpose}; use nalgebra_sparse::csr::CsrMatrix; -use nalgebra_sparse::proptest::csr; +use nalgebra_sparse::proptest::{csr, sparsity_pattern}; +use nalgebra_sparse::pattern::SparsityPattern; use nalgebra::{DVector, DMatrix, Scalar, DMatrixSliceMut, DMatrixSlice}; use nalgebra::proptest::matrix; @@ -10,6 +11,7 @@ use nalgebra::proptest::matrix; use proptest::prelude::*; use std::panic::catch_unwind; +use std::sync::Arc; #[test] fn spmv_coo_agrees_with_dense_gemv() { @@ -99,6 +101,19 @@ fn trans_strategy() -> impl Strategy + Clone { proptest::bool::ANY.prop_map(Transpose) } +fn pattern_strategy() -> impl Strategy { + sparsity_pattern(0 ..= 6usize, 0..= 6usize, 40) +} + +/// Constructs pairs (a, b) where a and b have the same dimensions +fn spadd_build_pattern_strategy() -> impl Strategy { + pattern_strategy() + .prop_flat_map(|a| { + let b = sparsity_pattern(Just(a.major_dim()), Just(a.minor_dim()), 40); + (Just(a), b) + }) +} + /// Helper function to help us call dense GEMM with our transposition parameters fn dense_gemm<'a>(c: impl Into>, beta: i32, @@ -167,4 +182,26 @@ proptest! { "The SPMM kernel executed successfully despite mismatch dimensions"); } + #[test] + fn spadd_build_pattern_test((c, (a, b)) in (pattern_strategy(), spadd_build_pattern_strategy())) + { + // (a, b) are dimensionally compatible patterns, whereas c is an *arbitrary* pattern + let mut pattern_result = c.clone(); + spadd_build_pattern(&mut pattern_result, &a, &b); + + // To verify the pattern, we construct CSR matrices with positive integer entries + // corresponding to a and b, and convert them to dense matrices. + // The sum of these dense matrices will then have non-zeros in exactly the same locations + // as the result of "adding" the sparsity patterns + let a_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(a.clone()), vec![1; a.nnz()]) + .unwrap(); + let a_dense = DMatrix::from(&a_csr); + let b_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(b.clone()), vec![1; b.nnz()]) + .unwrap(); + let b_dense = DMatrix::from(&b_csr); + let c_dense = a_dense + b_dense; + let c_csr = CsrMatrix::from(&c_dense); + + prop_assert_eq!(&pattern_result, &*c_csr.pattern()); + } } \ No newline at end of file