Simplify spadd_pattern API and name
This commit is contained in:
parent
6a100c085a
commit
c6a8fcdee2
|
@ -1,12 +1,11 @@
|
|||
use crate::csr::CsrMatrix;
|
||||
|
||||
use std::ops::{Add, Mul};
|
||||
use crate::ops::serial::{spadd_csr, spadd_build_pattern, spmm_pattern, spmm_csr};
|
||||
use crate::ops::serial::{spadd_csr, spadd_pattern, spmm_pattern, spmm_csr};
|
||||
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
||||
use num_traits::{Zero, One};
|
||||
use std::sync::Arc;
|
||||
use crate::ops::Transpose;
|
||||
use crate::pattern::SparsityPattern;
|
||||
|
||||
impl<'a, T> Add<&'a CsrMatrix<T>> for &'a CsrMatrix<T>
|
||||
where
|
||||
|
@ -17,8 +16,7 @@ where
|
|||
type Output = CsrMatrix<T>;
|
||||
|
||||
fn add(self, rhs: &'a CsrMatrix<T>) -> Self::Output {
|
||||
let mut pattern = SparsityPattern::new(self.nrows(), self.ncols());
|
||||
spadd_build_pattern(&mut pattern, self.pattern(), rhs.pattern());
|
||||
let pattern = spadd_pattern(self.pattern(), rhs.pattern());
|
||||
let values = vec![T::zero(); pattern.nnz()];
|
||||
// We are giving data that is valid by definition, so it is safe to unwrap below
|
||||
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
use crate::pattern::SparsityPattern;
|
||||
|
||||
use std::mem::swap;
|
||||
use std::iter;
|
||||
|
||||
/// Sparse matrix addition pattern construction, `C <- A + B`.
|
||||
|
@ -9,21 +8,15 @@ use std::iter;
|
|||
/// The patterns are assumed to have the same major and minor dimensions. In other words,
|
||||
/// both patterns `A` and `B` must both stem from the same kind of compressed matrix:
|
||||
/// CSR or CSC.
|
||||
/// TODO: Explain that output pattern is only used to avoid allocations
|
||||
pub fn spadd_build_pattern(pattern: &mut SparsityPattern,
|
||||
a: &SparsityPattern,
|
||||
b: &SparsityPattern)
|
||||
pub fn spadd_pattern(a: &SparsityPattern,
|
||||
b: &SparsityPattern) -> SparsityPattern
|
||||
{
|
||||
// TODO: Proper error messages
|
||||
assert_eq!(a.major_dim(), b.major_dim());
|
||||
assert_eq!(a.minor_dim(), b.minor_dim());
|
||||
assert_eq!(a.major_dim(), b.major_dim(), "Patterns must have identical major dimensions.");
|
||||
assert_eq!(a.minor_dim(), b.minor_dim(), "Patterns must have identical minor dimensions.");
|
||||
|
||||
let input_pattern = pattern;
|
||||
let mut temp_pattern = SparsityPattern::new(a.major_dim(), b.minor_dim());
|
||||
swap(input_pattern, &mut temp_pattern);
|
||||
let (mut offsets, mut indices) = temp_pattern.disassemble();
|
||||
|
||||
offsets.clear();
|
||||
let mut offsets = Vec::new();
|
||||
let mut indices = Vec::new();
|
||||
offsets.reserve(a.major_dim() + 1);
|
||||
indices.clear();
|
||||
|
||||
|
@ -37,10 +30,9 @@ pub fn spadd_build_pattern(pattern: &mut SparsityPattern,
|
|||
}
|
||||
|
||||
// TODO: Consider circumventing format checks? (requires unsafe, should benchmark first)
|
||||
let mut new_pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||
SparsityPattern::try_from_offsets_and_indices(
|
||||
a.major_dim(), a.minor_dim(), offsets, indices)
|
||||
.expect("Pattern must be valid by definition");
|
||||
swap(input_pattern, &mut new_pattern);
|
||||
.expect("Internal error: Pattern must be valid by definition")
|
||||
}
|
||||
|
||||
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
|
||||
PROPTEST_I32_VALUE_STRATEGY};
|
||||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_build_pattern, spmm_pattern, spadd_csr, spmm_csr};
|
||||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_pattern, spmm_pattern, spadd_csr, spmm_csr};
|
||||
use nalgebra_sparse::ops::{Transpose};
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
|
||||
|
@ -88,8 +88,7 @@ fn spadd_csr_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
|||
|
||||
spadd_build_pattern_strategy()
|
||||
.prop_flat_map(move |(a_pattern, b_pattern)| {
|
||||
let mut c_pattern = SparsityPattern::new(a_pattern.major_dim(), b_pattern.major_dim());
|
||||
spadd_build_pattern(&mut c_pattern, &a_pattern, &b_pattern);
|
||||
let c_pattern = spadd_pattern(&a_pattern, &b_pattern);
|
||||
|
||||
let a_values = vec![value_strategy.clone(); a_pattern.nnz()];
|
||||
let c_values = vec![value_strategy.clone(); c_pattern.nnz()];
|
||||
|
@ -248,11 +247,10 @@ proptest! {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn spadd_build_pattern_test((c, (a, b)) in (pattern_strategy(), spadd_build_pattern_strategy()))
|
||||
fn spadd_pattern_test((a, b) in spadd_build_pattern_strategy())
|
||||
{
|
||||
// (a, b) are dimensionally compatible patterns, whereas c is an *arbitrary* pattern
|
||||
let mut pattern_result = c.clone();
|
||||
spadd_build_pattern(&mut pattern_result, &a, &b);
|
||||
// (a, b) are dimensionally compatible patterns
|
||||
let pattern_result = spadd_pattern(&a, &b);
|
||||
|
||||
// To verify the pattern, we construct CSR matrices with positive integer entries
|
||||
// corresponding to a and b, and convert them to dense matrices.
|
||||
|
|
Loading…
Reference in New Issue