Replace spmm_pattern with spmm_{csr/csc}_pattern
This commit is contained in:
parent
cb0f9a5190
commit
3eab45d81b
|
@ -2,8 +2,7 @@ use crate::csr::CsrMatrix;
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
|
|
||||||
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg};
|
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Sub, Neg};
|
||||||
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_pattern,
|
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern, spmm_csr_pattern, spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense, spmm_csc_pattern};
|
||||||
spmm_csr_prealloc, spmm_csc_prealloc, spmm_csc_dense, spmm_csr_dense};
|
|
||||||
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, Dim,
|
use nalgebra::{ClosedAdd, ClosedMul, ClosedSub, ClosedDiv, Scalar, Matrix, Dim,
|
||||||
DMatrixSlice, DMatrix, Dynamic};
|
DMatrixSlice, DMatrix, Dynamic};
|
||||||
use num_traits::{Zero, One};
|
use num_traits::{Zero, One};
|
||||||
|
@ -106,9 +105,9 @@ macro_rules! impl_spmm {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_spmm!(CsrMatrix, spmm_pattern, spmm_csr_prealloc);
|
impl_spmm!(CsrMatrix, spmm_csr_pattern, spmm_csr_prealloc);
|
||||||
// Need to switch order of operations for CSC pattern
|
// Need to switch order of operations for CSC pattern
|
||||||
impl_spmm!(CscMatrix, |a, b| spmm_pattern(b, a), spmm_csc_prealloc);
|
impl_spmm!(CscMatrix, spmm_csc_pattern, spmm_csc_prealloc);
|
||||||
|
|
||||||
/// Implements Scalar * Matrix operations for *concrete* scalar types. The reason this is necessary
|
/// Implements Scalar * Matrix operations for *concrete* scalar types. The reason this is necessary
|
||||||
/// is that we are not able to implement Mul<Matrix<T>> for all T generically due to orphan rules.
|
/// is that we are not able to implement Mul<Matrix<T>> for all T generically due to orphan rules.
|
||||||
|
|
|
@ -36,7 +36,23 @@ pub fn spadd_pattern(a: &SparsityPattern,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
|
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
|
||||||
pub fn spmm_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
///
|
||||||
|
/// Assumes that the sparsity patterns both represent CSC matrices, and the result is also
|
||||||
|
/// represented as the sparsity pattern of a CSC matrix.
|
||||||
|
pub fn spmm_csc_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||||
|
// Let C = A * B in CSC format. We note that
|
||||||
|
// C^T = B^T * A^T.
|
||||||
|
// Since the interpretation of a CSC matrix in CSR format represents the transpose of the
|
||||||
|
// matrix in CSR, we can compute C^T in *CSR format* by switching the order of a and b,
|
||||||
|
// which lets us obtain C^T in CSR format. Re-interpreting this as CSC gives us C in CSC format
|
||||||
|
spmm_csr_pattern(b, a)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
|
||||||
|
///
|
||||||
|
/// Assumes that the sparsity patterns both represent CSR matrices, and the result is also
|
||||||
|
/// represented as the sparsity pattern of a CSR matrix.
|
||||||
|
pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||||
assert_eq!(a.minor_dim(), b.major_dim(), "a and b must have compatible dimensions");
|
assert_eq!(a.minor_dim(), b.major_dim(), "a and b must have compatible dimensions");
|
||||||
|
|
||||||
let mut offsets = Vec::new();
|
let mut offsets = Vec::new();
|
||||||
|
|
|
@ -1,8 +1,5 @@
|
||||||
use crate::common::{csc_strategy, csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ, PROPTEST_I32_VALUE_STRATEGY, non_zero_i32_value_strategy, value_strategy};
|
use crate::common::{csc_strategy, csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ, PROPTEST_I32_VALUE_STRATEGY, non_zero_i32_value_strategy, value_strategy};
|
||||||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spmm_csc_dense, spadd_pattern, spmm_pattern,
|
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spmm_csc_dense, spadd_pattern, spadd_csr_prealloc, spadd_csc_prealloc, spmm_csr_prealloc, spmm_csc_prealloc, spsolve_csc_lower_triangular, spmm_csr_pattern};
|
||||||
spadd_csr_prealloc, spadd_csc_prealloc,
|
|
||||||
spmm_csr_prealloc, spmm_csc_prealloc,
|
|
||||||
spsolve_csc_lower_triangular};
|
|
||||||
use nalgebra_sparse::ops::{Op};
|
use nalgebra_sparse::ops::{Op};
|
||||||
use nalgebra_sparse::csr::CsrMatrix;
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
use nalgebra_sparse::csc::CscMatrix;
|
use nalgebra_sparse::csc::CscMatrix;
|
||||||
|
@ -188,7 +185,7 @@ fn spadd_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPat
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Constructs pairs (a, b) where a and b have compatible dimensions for a matrix product
|
/// Constructs pairs (a, b) where a and b have compatible dimensions for a matrix product
|
||||||
fn spmm_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
|
fn spmm_csr_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
|
||||||
pattern_strategy()
|
pattern_strategy()
|
||||||
.prop_flat_map(|a| {
|
.prop_flat_map(|a| {
|
||||||
let b = sparsity_pattern(Just(a.minor_dim()), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
|
let b = sparsity_pattern(Just(a.minor_dim()), PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ);
|
||||||
|
@ -215,11 +212,11 @@ struct SpmmCscArgs<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
||||||
spmm_pattern_strategy()
|
spmm_csr_pattern_strategy()
|
||||||
.prop_flat_map(|(a_pattern, b_pattern)| {
|
.prop_flat_map(|(a_pattern, b_pattern)| {
|
||||||
let a_values = vec![PROPTEST_I32_VALUE_STRATEGY; a_pattern.nnz()];
|
let a_values = vec![PROPTEST_I32_VALUE_STRATEGY; a_pattern.nnz()];
|
||||||
let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()];
|
let b_values = vec![PROPTEST_I32_VALUE_STRATEGY; b_pattern.nnz()];
|
||||||
let c_pattern = spmm_pattern(&a_pattern, &b_pattern);
|
let c_pattern = spmm_csr_pattern(&a_pattern, &b_pattern);
|
||||||
let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()];
|
let c_values = vec![PROPTEST_I32_VALUE_STRATEGY; c_pattern.nnz()];
|
||||||
let a = a_values.prop_map(move |values|
|
let a = a_values.prop_map(move |values|
|
||||||
CsrMatrix::try_from_pattern_and_values(a_pattern.clone(), values).unwrap());
|
CsrMatrix::try_from_pattern_and_values(a_pattern.clone(), values).unwrap());
|
||||||
|
@ -479,10 +476,10 @@ proptest! {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spmm_pattern_test((a, b) in spmm_pattern_strategy())
|
fn spmm_csr_pattern_test((a, b) in spmm_csr_pattern_strategy())
|
||||||
{
|
{
|
||||||
// (a, b) are multiplication-wise dimensionally compatible patterns
|
// (a, b) are multiplication-wise dimensionally compatible patterns
|
||||||
let c_pattern = spmm_pattern(&a, &b);
|
let c_pattern = spmm_csr_pattern(&a, &b);
|
||||||
|
|
||||||
// To verify the pattern, we construct CSR matrices with positive integer entries
|
// To verify the pattern, we construct CSR matrices with positive integer entries
|
||||||
// corresponding to a and b, and convert them to dense matrices.
|
// corresponding to a and b, and convert them to dense matrices.
|
||||||
|
|
Loading…
Reference in New Issue