Implement spadd_build_pattern
This commit is contained in:
parent
7c68950614
commit
4420237ede
|
@ -3,7 +3,7 @@ use crate::ops::{Transpose};
|
||||||
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
|
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
|
||||||
use num_traits::{Zero, One};
|
use num_traits::{Zero, One};
|
||||||
|
|
||||||
/// Sparse-dense matrix-matrix multiplication `C = beta * C + alpha * trans(A) * trans(B)`.
|
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * trans(A) * trans(B)`.
|
||||||
pub fn spmm_csr_dense<'a, T>(c: impl Into<DMatrixSliceMut<'a, T>>,
|
pub fn spmm_csr_dense<'a, T>(c: impl Into<DMatrixSliceMut<'a, T>>,
|
||||||
beta: T,
|
beta: T,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
|
|
|
@ -32,6 +32,8 @@ macro_rules! assert_compatible_spmm_dims {
|
||||||
|
|
||||||
mod coo;
|
mod coo;
|
||||||
mod csr;
|
mod csr;
|
||||||
|
mod pattern;
|
||||||
|
|
||||||
pub use coo::*;
|
pub use coo::*;
|
||||||
pub use csr::*;
|
pub use csr::*;
|
||||||
|
pub use pattern::*;
|
|
@ -0,0 +1,77 @@
|
||||||
|
use crate::pattern::SparsityPattern;
|
||||||
|
|
||||||
|
use std::mem::swap;
|
||||||
|
use std::iter;
|
||||||
|
|
||||||
|
/// Sparse matrix addition pattern construction, `C <- A + B`.
|
||||||
|
///
|
||||||
|
/// Builds the pattern for `C`, which is able to hold the result of the sum `A + B`.
|
||||||
|
/// The patterns are assumed to have the same major and minor dimensions. In other words,
|
||||||
|
/// both patterns `A` and `B` must both stem from the same kind of compressed matrix:
|
||||||
|
/// CSR or CSC.
|
||||||
|
/// TODO: Explain that output pattern is only used to avoid allocations
|
||||||
|
pub fn spadd_build_pattern(pattern: &mut SparsityPattern,
|
||||||
|
a: &SparsityPattern,
|
||||||
|
b: &SparsityPattern)
|
||||||
|
{
|
||||||
|
// TODO: Proper error messages
|
||||||
|
assert_eq!(a.major_dim(), b.major_dim());
|
||||||
|
assert_eq!(a.minor_dim(), b.minor_dim());
|
||||||
|
|
||||||
|
let input_pattern = pattern;
|
||||||
|
let mut temp_pattern = SparsityPattern::new(a.major_dim(), b.minor_dim());
|
||||||
|
swap(input_pattern, &mut temp_pattern);
|
||||||
|
let (mut offsets, mut indices) = temp_pattern.disassemble();
|
||||||
|
|
||||||
|
offsets.clear();
|
||||||
|
offsets.reserve(a.major_dim() + 1);
|
||||||
|
indices.clear();
|
||||||
|
|
||||||
|
offsets.push(0);
|
||||||
|
|
||||||
|
for lane_idx in 0 .. a.major_dim() {
|
||||||
|
let lane_a = a.lane(lane_idx);
|
||||||
|
let lane_b = b.lane(lane_idx);
|
||||||
|
indices.extend(iterate_intersection(lane_a, lane_b));
|
||||||
|
offsets.push(indices.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Consider circumventing format checks? (requires unsafe, should benchmark first)
|
||||||
|
let mut new_pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||||
|
a.major_dim(), a.minor_dim(), offsets, indices)
|
||||||
|
.expect("Pattern must be valid by definition");
|
||||||
|
swap(input_pattern, &mut new_pattern);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Iterate over the intersection of the two sets represented by sorted slices
|
||||||
|
/// (with unique elements)
|
||||||
|
fn iterate_intersection<'a>(mut sorted_a: &'a [usize],
|
||||||
|
mut sorted_b: &'a [usize]) -> impl Iterator<Item=usize> + 'a {
|
||||||
|
// TODO: Can use a kind of simultaneous exponential search to speed things up here
|
||||||
|
iter::from_fn(move || {
|
||||||
|
if let (Some(a_item), Some(b_item)) = (sorted_a.first(), sorted_b.first()) {
|
||||||
|
let item = if a_item < b_item {
|
||||||
|
sorted_a = &sorted_a[1 ..];
|
||||||
|
a_item
|
||||||
|
} else if b_item < a_item {
|
||||||
|
sorted_b = &sorted_b[1 ..];
|
||||||
|
b_item
|
||||||
|
} else {
|
||||||
|
// Both lists contain the same element, advance both slices to avoid
|
||||||
|
// duplicate entries in the result
|
||||||
|
sorted_a = &sorted_a[1 ..];
|
||||||
|
sorted_b = &sorted_b[1 ..];
|
||||||
|
a_item
|
||||||
|
};
|
||||||
|
Some(*item)
|
||||||
|
} else if let Some(a_item) = sorted_a.first() {
|
||||||
|
sorted_a = &sorted_a[1..];
|
||||||
|
Some(*a_item)
|
||||||
|
} else if let Some(b_item) = sorted_b.first() {
|
||||||
|
sorted_b = &sorted_b[1..];
|
||||||
|
Some(*b_item)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -1,8 +1,9 @@
|
||||||
use nalgebra_sparse::coo::CooMatrix;
|
use nalgebra_sparse::coo::CooMatrix;
|
||||||
use nalgebra_sparse::ops::serial::{spmv_coo, spmm_csr_dense};
|
use nalgebra_sparse::ops::serial::{spmv_coo, spmm_csr_dense, spadd_build_pattern};
|
||||||
use nalgebra_sparse::ops::{Transpose};
|
use nalgebra_sparse::ops::{Transpose};
|
||||||
use nalgebra_sparse::csr::CsrMatrix;
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
use nalgebra_sparse::proptest::csr;
|
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
|
||||||
|
use nalgebra_sparse::pattern::SparsityPattern;
|
||||||
|
|
||||||
use nalgebra::{DVector, DMatrix, Scalar, DMatrixSliceMut, DMatrixSlice};
|
use nalgebra::{DVector, DMatrix, Scalar, DMatrixSliceMut, DMatrixSlice};
|
||||||
use nalgebra::proptest::matrix;
|
use nalgebra::proptest::matrix;
|
||||||
|
@ -10,6 +11,7 @@ use nalgebra::proptest::matrix;
|
||||||
use proptest::prelude::*;
|
use proptest::prelude::*;
|
||||||
|
|
||||||
use std::panic::catch_unwind;
|
use std::panic::catch_unwind;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spmv_coo_agrees_with_dense_gemv() {
|
fn spmv_coo_agrees_with_dense_gemv() {
|
||||||
|
@ -99,6 +101,19 @@ fn trans_strategy() -> impl Strategy<Value=Transpose> + Clone {
|
||||||
proptest::bool::ANY.prop_map(Transpose)
|
proptest::bool::ANY.prop_map(Transpose)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn pattern_strategy() -> impl Strategy<Value=SparsityPattern> {
|
||||||
|
sparsity_pattern(0 ..= 6usize, 0..= 6usize, 40)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Constructs pairs (a, b) where a and b have the same dimensions
|
||||||
|
fn spadd_build_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
|
||||||
|
pattern_strategy()
|
||||||
|
.prop_flat_map(|a| {
|
||||||
|
let b = sparsity_pattern(Just(a.major_dim()), Just(a.minor_dim()), 40);
|
||||||
|
(Just(a), b)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Helper function to help us call dense GEMM with our transposition parameters
|
/// Helper function to help us call dense GEMM with our transposition parameters
|
||||||
fn dense_gemm<'a>(c: impl Into<DMatrixSliceMut<'a, i32>>,
|
fn dense_gemm<'a>(c: impl Into<DMatrixSliceMut<'a, i32>>,
|
||||||
beta: i32,
|
beta: i32,
|
||||||
|
@ -167,4 +182,26 @@ proptest! {
|
||||||
"The SPMM kernel executed successfully despite mismatch dimensions");
|
"The SPMM kernel executed successfully despite mismatch dimensions");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spadd_build_pattern_test((c, (a, b)) in (pattern_strategy(), spadd_build_pattern_strategy()))
|
||||||
|
{
|
||||||
|
// (a, b) are dimensionally compatible patterns, whereas c is an *arbitrary* pattern
|
||||||
|
let mut pattern_result = c.clone();
|
||||||
|
spadd_build_pattern(&mut pattern_result, &a, &b);
|
||||||
|
|
||||||
|
// To verify the pattern, we construct CSR matrices with positive integer entries
|
||||||
|
// corresponding to a and b, and convert them to dense matrices.
|
||||||
|
// The sum of these dense matrices will then have non-zeros in exactly the same locations
|
||||||
|
// as the result of "adding" the sparsity patterns
|
||||||
|
let a_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(a.clone()), vec![1; a.nnz()])
|
||||||
|
.unwrap();
|
||||||
|
let a_dense = DMatrix::from(&a_csr);
|
||||||
|
let b_csr = CsrMatrix::try_from_pattern_and_values(Arc::new(b.clone()), vec![1; b.nnz()])
|
||||||
|
.unwrap();
|
||||||
|
let b_dense = DMatrix::from(&b_csr);
|
||||||
|
let c_dense = a_dense + b_dense;
|
||||||
|
let c_csr = CsrMatrix::from(&c_dense);
|
||||||
|
|
||||||
|
prop_assert_eq!(&pattern_result, &*c_csr.pattern());
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue