forked from M-Labs/nalgebra
Simplify spadd_pattern API and name
This commit is contained in:
parent
6a100c085a
commit
c6a8fcdee2
@ -1,12 +1,11 @@
|
|||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
|
|
||||||
use std::ops::{Add, Mul};
|
use std::ops::{Add, Mul};
|
||||||
use crate::ops::serial::{spadd_csr, spadd_build_pattern, spmm_pattern, spmm_csr};
|
use crate::ops::serial::{spadd_csr, spadd_pattern, spmm_pattern, spmm_csr};
|
||||||
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
||||||
use num_traits::{Zero, One};
|
use num_traits::{Zero, One};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use crate::ops::Transpose;
|
use crate::ops::Transpose;
|
||||||
use crate::pattern::SparsityPattern;
|
|
||||||
|
|
||||||
impl<'a, T> Add<&'a CsrMatrix<T>> for &'a CsrMatrix<T>
|
impl<'a, T> Add<&'a CsrMatrix<T>> for &'a CsrMatrix<T>
|
||||||
where
|
where
|
||||||
@ -17,8 +16,7 @@ where
|
|||||||
type Output = CsrMatrix<T>;
|
type Output = CsrMatrix<T>;
|
||||||
|
|
||||||
fn add(self, rhs: &'a CsrMatrix<T>) -> Self::Output {
|
fn add(self, rhs: &'a CsrMatrix<T>) -> Self::Output {
|
||||||
let mut pattern = SparsityPattern::new(self.nrows(), self.ncols());
|
let pattern = spadd_pattern(self.pattern(), rhs.pattern());
|
||||||
spadd_build_pattern(&mut pattern, self.pattern(), rhs.pattern());
|
|
||||||
let values = vec![T::zero(); pattern.nnz()];
|
let values = vec![T::zero(); pattern.nnz()];
|
||||||
// We are giving data that is valid by definition, so it is safe to unwrap below
|
// We are giving data that is valid by definition, so it is safe to unwrap below
|
||||||
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
use crate::pattern::SparsityPattern;
|
use crate::pattern::SparsityPattern;
|
||||||
|
|
||||||
use std::mem::swap;
|
|
||||||
use std::iter;
|
use std::iter;
|
||||||
|
|
||||||
/// Sparse matrix addition pattern construction, `C <- A + B`.
|
/// Sparse matrix addition pattern construction, `C <- A + B`.
|
||||||
@ -9,21 +8,15 @@ use std::iter;
|
|||||||
/// The patterns are assumed to have the same major and minor dimensions. In other words,
|
/// The patterns are assumed to have the same major and minor dimensions. In other words,
|
||||||
/// both patterns `A` and `B` must both stem from the same kind of compressed matrix:
|
/// both patterns `A` and `B` must both stem from the same kind of compressed matrix:
|
||||||
/// CSR or CSC.
|
/// CSR or CSC.
|
||||||
/// TODO: Explain that output pattern is only used to avoid allocations
|
pub fn spadd_pattern(a: &SparsityPattern,
|
||||||
pub fn spadd_build_pattern(pattern: &mut SparsityPattern,
|
b: &SparsityPattern) -> SparsityPattern
|
||||||
a: &SparsityPattern,
|
|
||||||
b: &SparsityPattern)
|
|
||||||
{
|
{
|
||||||
// TODO: Proper error messages
|
// TODO: Proper error messages
|
||||||
assert_eq!(a.major_dim(), b.major_dim());
|
assert_eq!(a.major_dim(), b.major_dim(), "Patterns must have identical major dimensions.");
|
||||||
assert_eq!(a.minor_dim(), b.minor_dim());
|
assert_eq!(a.minor_dim(), b.minor_dim(), "Patterns must have identical minor dimensions.");
|
||||||
|
|
||||||
let input_pattern = pattern;
|
let mut offsets = Vec::new();
|
||||||
let mut temp_pattern = SparsityPattern::new(a.major_dim(), b.minor_dim());
|
let mut indices = Vec::new();
|
||||||
swap(input_pattern, &mut temp_pattern);
|
|
||||||
let (mut offsets, mut indices) = temp_pattern.disassemble();
|
|
||||||
|
|
||||||
offsets.clear();
|
|
||||||
offsets.reserve(a.major_dim() + 1);
|
offsets.reserve(a.major_dim() + 1);
|
||||||
indices.clear();
|
indices.clear();
|
||||||
|
|
||||||
@ -37,10 +30,9 @@ pub fn spadd_build_pattern(pattern: &mut SparsityPattern,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Consider circumventing format checks? (requires unsafe, should benchmark first)
|
// TODO: Consider circumventing format checks? (requires unsafe, should benchmark first)
|
||||||
let mut new_pattern = SparsityPattern::try_from_offsets_and_indices(
|
SparsityPattern::try_from_offsets_and_indices(
|
||||||
a.major_dim(), a.minor_dim(), offsets, indices)
|
a.major_dim(), a.minor_dim(), offsets, indices)
|
||||||
.expect("Pattern must be valid by definition");
|
.expect("Internal error: Pattern must be valid by definition")
|
||||||
swap(input_pattern, &mut new_pattern);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
|
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
|
use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
|
||||||
PROPTEST_I32_VALUE_STRATEGY};
|
PROPTEST_I32_VALUE_STRATEGY};
|
||||||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_build_pattern, spmm_pattern, spadd_csr, spmm_csr};
|
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_pattern, spmm_pattern, spadd_csr, spmm_csr};
|
||||||
use nalgebra_sparse::ops::{Transpose};
|
use nalgebra_sparse::ops::{Transpose};
|
||||||
use nalgebra_sparse::csr::CsrMatrix;
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
|
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
|
||||||
@ -88,8 +88,7 @@ fn spadd_csr_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
|||||||
|
|
||||||
spadd_build_pattern_strategy()
|
spadd_build_pattern_strategy()
|
||||||
.prop_flat_map(move |(a_pattern, b_pattern)| {
|
.prop_flat_map(move |(a_pattern, b_pattern)| {
|
||||||
let mut c_pattern = SparsityPattern::new(a_pattern.major_dim(), b_pattern.major_dim());
|
let c_pattern = spadd_pattern(&a_pattern, &b_pattern);
|
||||||
spadd_build_pattern(&mut c_pattern, &a_pattern, &b_pattern);
|
|
||||||
|
|
||||||
let a_values = vec![value_strategy.clone(); a_pattern.nnz()];
|
let a_values = vec![value_strategy.clone(); a_pattern.nnz()];
|
||||||
let c_values = vec![value_strategy.clone(); c_pattern.nnz()];
|
let c_values = vec![value_strategy.clone(); c_pattern.nnz()];
|
||||||
@ -248,11 +247,10 @@ proptest! {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spadd_build_pattern_test((c, (a, b)) in (pattern_strategy(), spadd_build_pattern_strategy()))
|
fn spadd_pattern_test((a, b) in spadd_build_pattern_strategy())
|
||||||
{
|
{
|
||||||
// (a, b) are dimensionally compatible patterns, whereas c is an *arbitrary* pattern
|
// (a, b) are dimensionally compatible patterns
|
||||||
let mut pattern_result = c.clone();
|
let pattern_result = spadd_pattern(&a, &b);
|
||||||
spadd_build_pattern(&mut pattern_result, &a, &b);
|
|
||||||
|
|
||||||
// To verify the pattern, we construct CSR matrices with positive integer entries
|
// To verify the pattern, we construct CSR matrices with positive integer entries
|
||||||
// corresponding to a and b, and convert them to dense matrices.
|
// corresponding to a and b, and convert them to dense matrices.
|
||||||
|
Loading…
Reference in New Issue
Block a user