diff --git a/nalgebra-sparse/src/ops/impl_std_ops.rs b/nalgebra-sparse/src/ops/impl_std_ops.rs
index 1b337397..b8841bf2 100644
--- a/nalgebra-sparse/src/ops/impl_std_ops.rs
+++ b/nalgebra-sparse/src/ops/impl_std_ops.rs
@@ -2,8 +2,7 @@ use crate::csr::CsrMatrix;
 use crate::csc::CscMatrix;
 
 use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg};
-use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_pattern,
-                         spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense};
+use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_csr_pattern, spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense, spmm_csc_pattern};
 use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, Dim,
                DMatrixSlice, DMatrix, Dynamic};
 use num_traits::{Zero, One};
@@ -106,9 +105,9 @@ macro_rules! impl_spmm {
     }
 }
 
-impl_spmm!(CsrMatrix, spmm_pattern, spmm_csr_prealloc);
+impl_spmm!(CsrMatrix, spmm_csr_pattern, spmm_csr_prealloc);
 // Need to switch order of operations for CSC pattern
-impl_spmm!(CscMatrix, |a, b| spmm_pattern(b, a), spmm_csc_prealloc);
+impl_spmm!(CscMatrix, spmm_csc_pattern, spmm_csc_prealloc);
 
 /// Implements Scalar * Matrix operations for *concrete* scalar types. The reason this is necessary
 /// is that we are not able to implement Mul<Matrix<T>> for all T generically due to orphan rules.
diff --git a/nalgebra-sparse/src/ops/serial/pattern.rs b/nalgebra-sparse/src/ops/serial/pattern.rs
index 39b8a1c1..276100f3 100644
--- a/nalgebra-sparse/src/ops/serial/pattern.rs
+++ b/nalgebra-sparse/src/ops/serial/pattern.rs
@@ -36,7 +36,23 @@ pub fn spadd_pattern(a: &SparsityPattern,
 }
 
 /// Sparse matrix multiplication pattern construction, `C <- A * B`.
-pub fn spmm_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
+///
+/// Assumes that the sparsity patterns both represent CSC matrices, and the result is also
+/// represented as the sparsity pattern of a CSC matrix.
+pub fn spmm_csc_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
+    // Let C = A * B in CSC format. We note that
+    //  C^T = B^T * A^T.
+    // Since the interpretation of a CSC matrix in CSR format represents the transpose of the
+    // matrix in CSR, we can compute C^T in *CSR format* by switching the order of a and b,
+    // which lets us obtain C^T in CSR format. Re-interpreting this as CSC gives us C in CSC format
+    spmm_csr_pattern(b, a)
+}
+
+/// Sparse matrix multiplication pattern construction, `C <- A * B`.
+///
+/// Assumes that the sparsity patterns both represent CSR matrices, and the result is also
+/// represented as the sparsity pattern of a CSR matrix.
+pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
     assert_eq!(a.minor_dim(), b.major_dim(), "a and b must have compatible dimensions");
 
     let mut offsets = Vec::new();
diff --git a/nalgebra-sparse/tests/unit_tests/ops.rs b/nalgebra-sparse/tests/unit_tests/ops.rs
index 5a9aafe7..cef71378 100644
--- a/nalgebra-sparse/tests/unit_tests/ops.rs
+++ b/nalgebra-sparse/tests/unit_tests/ops.rs
@@ -1,8 +1,5 @@
 use crate::common::{csc_strategy, csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ, PROPTEST_I32_VALUE_STRATEGY, non_zero_i32_value_strategy, value_strategy};
-use nalgebra_sparse::ops::serial::{spmm_csr_dense, spmm_csc_dense, spadd_pattern, spmm_pattern,
-                                   spadd_csr_prealloc, spadd_csc_prealloc,
-                                   spmm_csr_prealloc, spmm_csc_prealloc,
-                                   spsolve_csc_lower_triangular};
+use nalgebra_sparse::ops::serial::{spmm_csr_dense, spmm_csc_dense, spadd_pattern, spadd_csr_prealloc, spadd_csc_prealloc, spmm_csr_prealloc, spmm_csc_prealloc, spsolve_csc_lower_triangular, spmm_csr_pattern};
 use nalgebra_sparse::ops::{Op};
 use nalgebra_sparse::csr::CsrMatrix;
 use nalgebra_sparse::csc::CscMatrix;
@@ -188,7 +185,7 @@ fn spadd_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
         })
 }
 
-fn spmm_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
+fn spmm_csr_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
     pattern_strategy()
         .prop_flat_map(|a| {
             let b = sparsity_pattern(Just(a.minor_dim()), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
@@ -215,11 +212,11 @@ struct SpmmCscArgs<T> {
 }
 
 fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
-    spmm_pattern_strategy()
+    spmm_csr_pattern_strategy()
         .prop_flat_map(|(a_pattern, b_pattern)| {
             let a_values = vec![PROPTEST_I32_VALUE_STRATEGY; a_pattern.nnz()];
             let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()];
-            let c_pattern = spmm_pattern(&a_pattern, &b_pattern);
+            let c_pattern = spmm_csr_pattern(&a_pattern, &b_pattern);
             let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()];
             let a = a_values.prop_map(move |values|
                 CsrMatrix::try_from_pattern_and_values(a_pattern.clone(), values).unwrap());
@@ -479,10 +476,10 @@ proptest! {
     }
 
     #[test]
-    fn spmm_pattern_test((a, b) in spmm_pattern_strategy())
+    fn spmm_csr_pattern_test((a, b) in spmm_csr_pattern_strategy())
     {
         // (a, b) are multiplication-wise dimensionally compatible patterns
-        let c_pattern = spmm_pattern(&a, &b);
+        let c_pattern = spmm_csr_pattern(&a, &b);
 
         // To verify the pattern, we construct CSR matrices with positive integer entries
         // corresponding to a and b, and convert them to dense matrices.
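
Illustration (not part of the patch): the comment added to spmm_csc_pattern argues that the CSC pattern of C = A * B can be obtained by running the CSR pattern kernel on the swapped operands, since reinterpreting CSC storage as CSR yields the transpose and C^T = B^T * A^T. The minimal sketch below assumes the post-patch exports spmm_csr_pattern and spmm_csc_pattern from nalgebra_sparse::ops::serial together with the existing SparsityPattern::try_from_offsets_and_indices constructor; the concrete 2x3 and 3x2 patterns are made up for the example.

use nalgebra_sparse::ops::serial::{spmm_csc_pattern, spmm_csr_pattern};
use nalgebra_sparse::pattern::SparsityPattern;

fn main() {
    // A is 2x3 with nonzeros at (0, 0), (0, 2), (1, 1), stored as a CSC pattern:
    // major dimension = columns (3), minor dimension = rows (2).
    let a = SparsityPattern::try_from_offsets_and_indices(3, 2, vec![0, 1, 2, 3], vec![0, 1, 0])
        .unwrap();
    // B is 3x2 with nonzeros at (0, 1) and (2, 0), also stored as a CSC pattern.
    let b = SparsityPattern::try_from_offsets_and_indices(2, 3, vec![0, 1, 2], vec![2, 0])
        .unwrap();

    // The CSC pattern of C = A * B ...
    let c = spmm_csc_pattern(&a, &b);
    // ... equals the CSR pattern of C^T = B^T * A^T, i.e. the CSR kernel on swapped operands.
    assert_eq!(c, spmm_csr_pattern(&b, &a));

    // C is 2x2 with nonzeros at (0, 0) and (0, 1) only:
    // C(0,0) comes from A(0,2)*B(2,0) and C(0,1) from A(0,0)*B(0,1).
    assert_eq!(c.major_dim(), 2); // columns of C in the CSC interpretation
    assert_eq!(c.minor_dim(), 2); // rows of C
    assert_eq!(c.nnz(), 2);
}

Because spmm_csc_pattern is defined as spmm_csr_pattern(b, a), the first assertion holds by construction; reusing the CSR kernel this way is what lets the patch avoid a second, column-major implementation of the pattern construction.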