From c6a8fcdee2b71dce1768728eb7974ff7a6633597 Mon Sep 17 00:00:00 2001 From: Andreas Longva Date: Wed, 16 Dec 2020 17:30:48 +0100 Subject: [PATCH] Simplify spadd_pattern API and name --- nalgebra-sparse/src/ops/impl_std_ops.rs | 6 ++---- nalgebra-sparse/src/ops/serial/pattern.rs | 24 ++++++++--------------- nalgebra-sparse/tests/unit_tests/ops.rs | 12 +++++------- 3 files changed, 15 insertions(+), 27 deletions(-) diff --git a/nalgebra-sparse/src/ops/impl_std_ops.rs b/nalgebra-sparse/src/ops/impl_std_ops.rs index 17f357a6..34a7bcf5 100644 --- a/nalgebra-sparse/src/ops/impl_std_ops.rs +++ b/nalgebra-sparse/src/ops/impl_std_ops.rs @@ -1,12 +1,11 @@ use crate::csr::CsrMatrix; use std::ops::{Add, Mul}; -use crate::ops::serial::{spadd_csr, spadd_build_pattern, spmm_pattern, spmm_csr}; +use crate::ops::serial::{spadd_csr, spadd_pattern, spmm_pattern, spmm_csr}; use nalgebra::{ClosedAdd, ClosedMul, Scalar}; use num_traits::{Zero, One}; use std::sync::Arc; use crate::ops::Transpose; -use crate::pattern::SparsityPattern; impl<'a, T> Add<&'a CsrMatrix> for &'a CsrMatrix where @@ -17,8 +16,7 @@ where type Output = CsrMatrix; fn add(self, rhs: &'a CsrMatrix) -> Self::Output { - let mut pattern = SparsityPattern::new(self.nrows(), self.ncols()); - spadd_build_pattern(&mut pattern, self.pattern(), rhs.pattern()); + let pattern = spadd_pattern(self.pattern(), rhs.pattern()); let values = vec![T::zero(); pattern.nnz()]; // We are giving data that is valid by definition, so it is safe to unwrap below let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values) diff --git a/nalgebra-sparse/src/ops/serial/pattern.rs b/nalgebra-sparse/src/ops/serial/pattern.rs index 2e442cc0..39b8a1c1 100644 --- a/nalgebra-sparse/src/ops/serial/pattern.rs +++ b/nalgebra-sparse/src/ops/serial/pattern.rs @@ -1,6 +1,5 @@ use crate::pattern::SparsityPattern; -use std::mem::swap; use std::iter; /// Sparse matrix addition pattern construction, `C <- A + B`. @@ -9,21 +8,15 @@ use std::iter; /// The patterns are assumed to have the same major and minor dimensions. In other words, /// both patterns `A` and `B` must both stem from the same kind of compressed matrix: /// CSR or CSC. -/// TODO: Explain that output pattern is only used to avoid allocations -pub fn spadd_build_pattern(pattern: &mut SparsityPattern, - a: &SparsityPattern, - b: &SparsityPattern) +pub fn spadd_pattern(a: &SparsityPattern, + b: &SparsityPattern) -> SparsityPattern { // TODO: Proper error messages - assert_eq!(a.major_dim(), b.major_dim()); - assert_eq!(a.minor_dim(), b.minor_dim()); + assert_eq!(a.major_dim(), b.major_dim(), "Patterns must have identical major dimensions."); + assert_eq!(a.minor_dim(), b.minor_dim(), "Patterns must have identical minor dimensions."); - let input_pattern = pattern; - let mut temp_pattern = SparsityPattern::new(a.major_dim(), b.minor_dim()); - swap(input_pattern, &mut temp_pattern); - let (mut offsets, mut indices) = temp_pattern.disassemble(); - - offsets.clear(); + let mut offsets = Vec::new(); + let mut indices = Vec::new(); offsets.reserve(a.major_dim() + 1); indices.clear(); @@ -37,10 +30,9 @@ pub fn spadd_build_pattern(pattern: &mut SparsityPattern, } // TODO: Consider circumventing format checks? (requires unsafe, should benchmark first) - let mut new_pattern = SparsityPattern::try_from_offsets_and_indices( + SparsityPattern::try_from_offsets_and_indices( a.major_dim(), a.minor_dim(), offsets, indices) - .expect("Pattern must be valid by definition"); - swap(input_pattern, &mut new_pattern); + .expect("Internal error: Pattern must be valid by definition") } /// Sparse matrix multiplication pattern construction, `C <- A * B`. diff --git a/nalgebra-sparse/tests/unit_tests/ops.rs b/nalgebra-sparse/tests/unit_tests/ops.rs index d5f2bba8..16482171 100644 --- a/nalgebra-sparse/tests/unit_tests/ops.rs +++ b/nalgebra-sparse/tests/unit_tests/ops.rs @@ -1,6 +1,6 @@ use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ, PROPTEST_I32_VALUE_STRATEGY}; -use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_build_pattern, spmm_pattern, spadd_csr, spmm_csr}; +use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_pattern, spmm_pattern, spadd_csr, spmm_csr}; use nalgebra_sparse::ops::{Transpose}; use nalgebra_sparse::csr::CsrMatrix; use nalgebra_sparse::proptest::{csr, sparsity_pattern}; @@ -88,8 +88,7 @@ fn spadd_csr_args_strategy() -> impl Strategy> { spadd_build_pattern_strategy() .prop_flat_map(move |(a_pattern, b_pattern)| { - let mut c_pattern = SparsityPattern::new(a_pattern.major_dim(), b_pattern.major_dim()); - spadd_build_pattern(&mut c_pattern, &a_pattern, &b_pattern); + let c_pattern = spadd_pattern(&a_pattern, &b_pattern); let a_values = vec![value_strategy.clone(); a_pattern.nnz()]; let c_values = vec![value_strategy.clone(); c_pattern.nnz()]; @@ -248,11 +247,10 @@ proptest! { } #[test] - fn spadd_build_pattern_test((c, (a, b)) in (pattern_strategy(), spadd_build_pattern_strategy())) + fn spadd_pattern_test((a, b) in spadd_build_pattern_strategy()) { - // (a, b) are dimensionally compatible patterns, whereas c is an *arbitrary* pattern - let mut pattern_result = c.clone(); - spadd_build_pattern(&mut pattern_result, &a, &b); + // (a, b) are dimensionally compatible patterns + let pattern_result = spadd_pattern(&a, &b); // To verify the pattern, we construct CSR matrices with positive integer entries // corresponding to a and b, and convert them to dense matrices.