Simplify spadd_pattern API and name

This commit is contained in:
Andreas Longva 2020-12-16 17:30:48 +01:00
parent 6a100c085a
commit c6a8fcdee2
3 changed files with 15 additions and 27 deletions

View File

@ -1,12 +1,11 @@
use crate::csr::CsrMatrix;
use std::ops::{Add, Mul};
use crate::ops::serial::{spadd_csr, spadd_build_pattern, spmm_pattern, spmm_csr};
use crate::ops::serial::{spadd_csr, spadd_pattern, spmm_pattern, spmm_csr};
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
use num_traits::{Zero, One};
use std::sync::Arc;
use crate::ops::Transpose;
use crate::pattern::SparsityPattern;
impl<'a, T> Add<&'a CsrMatrix<T>> for &'a CsrMatrix<T>
where
@ -17,8 +16,7 @@ where
type Output = CsrMatrix<T>;
fn add(self, rhs: &'a CsrMatrix<T>) -> Self::Output {
let mut pattern = SparsityPattern::new(self.nrows(), self.ncols());
spadd_build_pattern(&mut pattern, self.pattern(), rhs.pattern());
let pattern = spadd_pattern(self.pattern(), rhs.pattern());
let values = vec![T::zero(); pattern.nnz()];
// We are giving data that is valid by definition, so it is safe to unwrap below
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)

View File

@ -1,6 +1,5 @@
use crate::pattern::SparsityPattern;
use std::mem::swap;
use std::iter;
/// Sparse matrix addition pattern construction, `C <- A + B`.
@ -9,21 +8,15 @@ use std::iter;
/// The patterns are assumed to have the same major and minor dimensions. In other words,
/// both patterns `A` and `B` must both stem from the same kind of compressed matrix:
/// CSR or CSC.
/// TODO: Explain that output pattern is only used to avoid allocations
pub fn spadd_build_pattern(pattern: &mut SparsityPattern,
a: &SparsityPattern,
b: &SparsityPattern)
pub fn spadd_pattern(a: &SparsityPattern,
b: &SparsityPattern) -> SparsityPattern
{
// TODO: Proper error messages
assert_eq!(a.major_dim(), b.major_dim());
assert_eq!(a.minor_dim(), b.minor_dim());
assert_eq!(a.major_dim(), b.major_dim(), "Patterns must have identical major dimensions.");
assert_eq!(a.minor_dim(), b.minor_dim(), "Patterns must have identical minor dimensions.");
let input_pattern = pattern;
let mut temp_pattern = SparsityPattern::new(a.major_dim(), b.minor_dim());
swap(input_pattern, &mut temp_pattern);
let (mut offsets, mut indices) = temp_pattern.disassemble();
offsets.clear();
let mut offsets = Vec::new();
let mut indices = Vec::new();
offsets.reserve(a.major_dim() + 1);
indices.clear();
@ -37,10 +30,9 @@ pub fn spadd_build_pattern(pattern: &mut SparsityPattern,
}
// TODO: Consider circumventing format checks? (requires unsafe, should benchmark first)
let mut new_pattern = SparsityPattern::try_from_offsets_and_indices(
SparsityPattern::try_from_offsets_and_indices(
a.major_dim(), a.minor_dim(), offsets, indices)
.expect("Pattern must be valid by definition");
swap(input_pattern, &mut new_pattern);
.expect("Internal error: Pattern must be valid by definition")
}
/// Sparse matrix multiplication pattern construction, `C <- A * B`.

View File

@ -1,6 +1,6 @@
use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
PROPTEST_I32_VALUE_STRATEGY};
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_build_pattern, spmm_pattern, spadd_csr, spmm_csr};
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_pattern, spmm_pattern, spadd_csr, spmm_csr};
use nalgebra_sparse::ops::{Transpose};
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
@ -88,8 +88,7 @@ fn spadd_csr_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
spadd_build_pattern_strategy()
.prop_flat_map(move |(a_pattern, b_pattern)| {
let mut c_pattern = SparsityPattern::new(a_pattern.major_dim(), b_pattern.major_dim());
spadd_build_pattern(&mut c_pattern, &a_pattern, &b_pattern);
let c_pattern = spadd_pattern(&a_pattern, &b_pattern);
let a_values = vec![value_strategy.clone(); a_pattern.nnz()];
let c_values = vec![value_strategy.clone(); c_pattern.nnz()];
@ -248,11 +247,10 @@ proptest! {
}
#[test]
fn spadd_build_pattern_test((c, (a, b)) in (pattern_strategy(), spadd_build_pattern_strategy()))
fn spadd_pattern_test((a, b) in spadd_build_pattern_strategy())
{
// (a, b) are dimensionally compatible patterns, whereas c is an *arbitrary* pattern
let mut pattern_result = c.clone();
spadd_build_pattern(&mut pattern_result, &a, &b);
// (a, b) are dimensionally compatible patterns
let pattern_result = spadd_pattern(&a, &b);
// To verify the pattern, we construct CSR matrices with positive integer entries
// corresponding to a and b, and convert them to dense matrices.