Replace spmm_pattern with spmm_{csr/csc}_pattern

This commit is contained in:
Andreas Longva 2021-01-19 17:16:56 +01:00
parent cb0f9a5190
commit 3eab45d81b
3 changed files with 26 additions and 14 deletions

View File

@ -2,8 +2,7 @@ use crate::csr::CsrMatrix;
use crate::csc::CscMatrix;
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg};
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_pattern,
spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense};
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_csr_pattern, spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense, spmm_csc_pattern};
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, Dim,
DMatrixSlice, DMatrix, Dynamic};
use num_traits::{Zero, One};
@ -106,9 +105,9 @@ macro_rules! impl_spmm {
}
}
impl_spmm!(CsrMatrix, spmm_pattern, spmm_csr_prealloc);
impl_spmm!(CsrMatrix, spmm_csr_pattern, spmm_csr_prealloc);
// Need to switch order of operations for CSC pattern
impl_spmm!(CscMatrix, |a, b| spmm_pattern(b, a), spmm_csc_prealloc);
impl_spmm!(CscMatrix, spmm_csc_pattern, spmm_csc_prealloc);
/// Implements Scalar * Matrix operations for *concrete* scalar types. The reason this is necessary
/// is that we are not able to implement Mul<Matrix<T>> for all T generically due to orphan rules.

View File

@ -36,7 +36,23 @@ pub fn spadd_pattern(a: &SparsityPattern,
}
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
pub fn spmm_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
///
/// Assumes that the sparsity patterns both represent CSC matrices, and the result is also
/// represented as the sparsity pattern of a CSC matrix.
pub fn spmm_csc_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
// Let C = A * B in CSC format. We note that
// C^T = B^T * A^T.
// Since the interpretation of a CSC matrix in CSR format represents the transpose of the
// matrix in CSR, we can compute C^T in *CSR format* by switching the order of a and b,
// which lets us obtain C^T in CSR format. Re-interpreting this as CSC gives us C in CSC format
spmm_csr_pattern(b, a)
}
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
///
/// Assumes that the sparsity patterns both represent CSR matrices, and the result is also
/// represented as the sparsity pattern of a CSR matrix.
pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
assert_eq!(a.minor_dim(), b.major_dim(), "a and b must have compatible dimensions");
let mut offsets = Vec::new();

View File

@ -1,8 +1,5 @@
use crate::common::{csc_strategy, csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ, PROPTEST_I32_VALUE_STRATEGY, non_zero_i32_value_strategy, value_strategy};
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spmm_csc_dense, spadd_pattern, spmm_pattern,
spadd_csr_prealloc, spadd_csc_prealloc,
spmm_csr_prealloc, spmm_csc_prealloc,
spsolve_csc_lower_triangular};
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spmm_csc_dense, spadd_pattern, spadd_csr_prealloc, spadd_csc_prealloc, spmm_csr_prealloc, spmm_csc_prealloc, spsolve_csc_lower_triangular, spmm_csr_pattern};
use nalgebra_sparse::ops::{Op};
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::csc::CscMatrix;
@ -188,7 +185,7 @@ fn spadd_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPat
}
/// Constructs pairs (a, b) where a and b have compatible dimensions for a matrix product
fn spmm_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
fn spmm_csr_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
pattern_strategy()
.prop_flat_map(|a| {
let b = sparsity_pattern(Just(a.minor_dim()), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
@ -215,11 +212,11 @@ struct SpmmCscArgs<T> {
}
fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
spmm_pattern_strategy()
spmm_csr_pattern_strategy()
.prop_flat_map(|(a_pattern, b_pattern)| {
let a_values = vec![PROPTEST_I32_VALUE_STRATEGY; a_pattern.nnz()];
let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()];
let c_pattern = spmm_pattern(&a_pattern, &b_pattern);
let c_pattern = spmm_csr_pattern(&a_pattern, &b_pattern);
let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()];
let a = a_values.prop_map(move |values|
CsrMatrix::try_from_pattern_and_values(a_pattern.clone(), values).unwrap());
@ -479,10 +476,10 @@ proptest! {
}
#[test]
fn spmm_pattern_test((a, b) in spmm_pattern_strategy())
fn spmm_csr_pattern_test((a, b) in spmm_csr_pattern_strategy())
{
// (a, b) are multiplication-wise dimensionally compatible patterns
let c_pattern = spmm_pattern(&a, &b);
let c_pattern = spmm_csr_pattern(&a, &b);
// To verify the pattern, we construct CSR matrices with positive integer entries
// corresponding to a and b, and convert them to dense matrices.