Replace spmm_pattern with spmm_{csr/csc}_pattern
This commit is contained in:
parent
cb0f9a5190
commit
3eab45d81b
|
@ -2,8 +2,7 @@ use crate::csr::CsrMatrix;
|
|||
use crate::csc::CscMatrix;
|
||||
|
||||
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg};
|
||||
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_pattern,
|
||||
spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense};
|
||||
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_csr_pattern, spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense, spmm_csc_pattern};
|
||||
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, Dim,
|
||||
DMatrixSlice, DMatrix, Dynamic};
|
||||
use num_traits::{Zero, One};
|
||||
|
@ -106,9 +105,9 @@ macro_rules! impl_spmm {
|
|||
}
|
||||
}
|
||||
|
||||
impl_spmm!(CsrMatrix, spmm_pattern, spmm_csr_prealloc);
|
||||
impl_spmm!(CsrMatrix, spmm_csr_pattern, spmm_csr_prealloc);
|
||||
// Need to switch order of operations for CSC pattern
|
||||
impl_spmm!(CscMatrix, |a, b| spmm_pattern(b, a), spmm_csc_prealloc);
|
||||
impl_spmm!(CscMatrix, spmm_csc_pattern, spmm_csc_prealloc);
|
||||
|
||||
/// Implements Scalar * Matrix operations for *concrete* scalar types. The reason this is necessary
|
||||
/// is that we are not able to implement Mul<Matrix<T>> for all T generically due to orphan rules.
|
||||
|
|
|
@ -36,7 +36,23 @@ pub fn spadd_pattern(a: &SparsityPattern,
|
|||
}
|
||||
|
||||
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
|
||||
pub fn spmm_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||
///
|
||||
/// Assumes that the sparsity patterns both represent CSC matrices, and the result is also
|
||||
/// represented as the sparsity pattern of a CSC matrix.
|
||||
pub fn spmm_csc_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||
// Let C = A * B in CSC format. We note that
|
||||
// C^T = B^T * A^T.
|
||||
// Since the interpretation of a CSC matrix in CSR format represents the transpose of the
|
||||
// matrix in CSR, we can compute C^T in *CSR format* by switching the order of a and b,
|
||||
// which lets us obtain C^T in CSR format. Re-interpreting this as CSC gives us C in CSC format
|
||||
spmm_csr_pattern(b, a)
|
||||
}
|
||||
|
||||
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
|
||||
///
|
||||
/// Assumes that the sparsity patterns both represent CSR matrices, and the result is also
|
||||
/// represented as the sparsity pattern of a CSR matrix.
|
||||
pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||
assert_eq!(a.minor_dim(), b.major_dim(), "a and b must have compatible dimensions");
|
||||
|
||||
let mut offsets = Vec::new();
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
use crate::common::{csc_strategy, csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ, PROPTEST_I32_VALUE_STRATEGY, non_zero_i32_value_strategy, value_strategy};
|
||||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spmm_csc_dense, spadd_pattern, spmm_pattern,
|
||||
spadd_csr_prealloc, spadd_csc_prealloc,
|
||||
spmm_csr_prealloc, spmm_csc_prealloc,
|
||||
spsolve_csc_lower_triangular};
|
||||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spmm_csc_dense, spadd_pattern, spadd_csr_prealloc, spadd_csc_prealloc, spmm_csr_prealloc, spmm_csc_prealloc, spsolve_csc_lower_triangular, spmm_csr_pattern};
|
||||
use nalgebra_sparse::ops::{Op};
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::csc::CscMatrix;
|
||||
|
@ -188,7 +185,7 @@ fn spadd_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPat
|
|||
}
|
||||
|
||||
/// Constructs pairs (a, b) where a and b have compatible dimensions for a matrix product
|
||||
fn spmm_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
|
||||
fn spmm_csr_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
|
||||
pattern_strategy()
|
||||
.prop_flat_map(|a| {
|
||||
let b = sparsity_pattern(Just(a.minor_dim()), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
|
||||
|
@ -215,11 +212,11 @@ struct SpmmCscArgs<T> {
|
|||
}
|
||||
|
||||
fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
||||
spmm_pattern_strategy()
|
||||
spmm_csr_pattern_strategy()
|
||||
.prop_flat_map(|(a_pattern, b_pattern)| {
|
||||
let a_values = vec![PROPTEST_I32_VALUE_STRATEGY; a_pattern.nnz()];
|
||||
let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()];
|
||||
let c_pattern = spmm_pattern(&a_pattern, &b_pattern);
|
||||
let c_pattern = spmm_csr_pattern(&a_pattern, &b_pattern);
|
||||
let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()];
|
||||
let a = a_values.prop_map(move |values|
|
||||
CsrMatrix::try_from_pattern_and_values(a_pattern.clone(), values).unwrap());
|
||||
|
@ -479,10 +476,10 @@ proptest! {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn spmm_pattern_test((a, b) in spmm_pattern_strategy())
|
||||
fn spmm_csr_pattern_test((a, b) in spmm_csr_pattern_strategy())
|
||||
{
|
||||
// (a, b) are multiplication-wise dimensionally compatible patterns
|
||||
let c_pattern = spmm_pattern(&a, &b);
|
||||
let c_pattern = spmm_csr_pattern(&a, &b);
|
||||
|
||||
// To verify the pattern, we construct CSR matrices with positive integer entries
|
||||
// corresponding to a and b, and convert them to dense matrices.
|
||||
|
|
Loading…
Reference in New Issue