Implement CSR-CSR addition

This commit is contained in:
Andreas Longva 2020-12-10 13:30:37 +01:00
parent 921686c490
commit 41941e62c8
5 changed files with 278 additions and 4 deletions

View File

@ -0,0 +1,68 @@
use crate::csr::CsrMatrix;
use std::ops::Add;
use crate::ops::serial::{spadd_csr, spadd_build_pattern};
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
use num_traits::{Zero, One};
use std::sync::Arc;
use crate::ops::Transpose;
use crate::pattern::SparsityPattern;
impl<'a, T> Add<&'a CsrMatrix<T>> for &'a CsrMatrix<T>
where
    // TODO: Consider introducing wrapper trait for these things? It's technically a "Ring",
    // I guess...
    T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
    type Output = CsrMatrix<T>;

    /// Sum of two CSR matrices by reference: allocates a result whose sparsity pattern is
    /// the union of the operands' patterns, then accumulates both operands into it.
    fn add(self, rhs: &'a CsrMatrix<T>) -> Self::Output {
        // First determine the union pattern so the result can hold every entry of the sum.
        let mut union_pattern = SparsityPattern::new(self.nrows(), self.ncols());
        spadd_build_pattern(&mut union_pattern, self.pattern(), rhs.pattern());

        // Start from all-zero values; the pattern/value pair is consistent by construction,
        // so the fallible constructor cannot fail and unwrapping is safe.
        let zero_values = vec![T::zero(); union_pattern.nnz()];
        let mut sum = CsrMatrix::try_from_pattern_and_values(Arc::new(union_pattern), zero_values)
            .unwrap();

        // sum <- 0 * sum + 1 * self, then sum <- 1 * sum + 1 * rhs.
        spadd_csr(&mut sum, T::zero(), T::one(), Transpose(false), &self).unwrap();
        spadd_csr(&mut sum, T::one(), T::one(), Transpose(false), &rhs).unwrap();
        sum
    }
}
impl<'a, T> Add<&'a CsrMatrix<T>> for CsrMatrix<T>
where
    T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
    type Output = CsrMatrix<T>;

    /// Sum with an owned left-hand side: when both operands share the exact same
    /// (pointer-identical) sparsity pattern, the sum is accumulated in place into `self`,
    /// avoiding any allocation. Otherwise delegates to the reference-based impl.
    fn add(mut self, rhs: &'a CsrMatrix<T>) -> Self::Output {
        // Guard clause: different patterns require building a union pattern, which the
        // `&CsrMatrix + &CsrMatrix` impl already knows how to do.
        if !Arc::ptr_eq(self.pattern(), rhs.pattern()) {
            return &self + rhs;
        }
        // Identical patterns: self <- 1 * self + 1 * rhs, entry by entry, in place.
        spadd_csr(&mut self, T::one(), T::one(), Transpose(false), &rhs).unwrap();
        self
    }
}
impl<'a, T> Add<CsrMatrix<T>> for &'a CsrMatrix<T>
where
    T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
    type Output = CsrMatrix<T>;

    /// Sum with an owned right-hand side: delegates to the `owned + reference` impl with
    /// the operands swapped, so the owned matrix's storage can be reused when possible.
    // NOTE(review): this relies on `rhs + self == self + rhs`, i.e. commutativity of the
    // scalar addition — presumably fine for the intended numeric types, but `ClosedAdd`
    // does not guarantee it. Confirm this assumption holds for all supported scalars.
    fn add(self, rhs: CsrMatrix<T>) -> Self::Output {
        Add::add(rhs, self)
    }
}
impl<T> Add<CsrMatrix<T>> for CsrMatrix<T>
where
    T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
    type Output = Self;

    /// Sum of two owned matrices: borrows the right-hand side and delegates to the
    /// `owned + reference` impl, which reuses `self`'s storage when the patterns match.
    fn add(self, rhs: CsrMatrix<T>) -> Self::Output {
        let rhs_ref = &rhs;
        self + rhs_ref
    }
}

View File

@ -1,5 +1,6 @@
//! TODO
mod impl_std_ops;
pub mod serial;
/// TODO
@ -11,4 +12,6 @@ impl Transpose {
pub fn to_bool(&self) -> bool {
self.0
}
}
}

View File

@ -2,6 +2,9 @@ use crate::csr::CsrMatrix;
use crate::ops::{Transpose};
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
use num_traits::{Zero, One};
use crate::ops::serial::{OperationError, OperationErrorType};
use std::sync::Arc;
use crate::SparseEntryMut;
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * trans(A) * trans(B)`.
pub fn spmm_csr_dense<'a, T>(c: impl Into<DMatrixSliceMut<'a, T>>,
@ -65,4 +68,90 @@ where
}
}
}
}
}
/// Builds the error returned by `spadd_csr` when `a` has a stored entry at a position
/// that `c`'s sparsity pattern cannot accommodate.
fn spadd_csr_unexpected_entry() -> OperationError {
    let message = "Found entry in `a` that is not present in `c`.".to_string();
    OperationError::from_type_and_message(OperationErrorType::InvalidPattern, message)
}
/// Sparse matrix addition `C <- beta * C + alpha * trans(A)`.
///
/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is
/// returned.
///
/// # Panics
///
/// Panics if the dimensions of `c` and `trans(a)` do not agree.
pub fn spadd_csr<T>(c: &mut CsrMatrix<T>,
beta: T,
alpha: T,
trans_a: Transpose,
a: &CsrMatrix<T>)
-> Result<(), OperationError>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
// Dimension compatibility: C must match A transposed or untransposed, as requested.
// TODO: Proper error messages
if trans_a.to_bool() {
assert_eq!(c.nrows(), a.ncols());
assert_eq!(c.ncols(), a.nrows());
} else {
assert_eq!(c.nrows(), a.nrows());
assert_eq!(c.ncols(), a.ncols());
}
// TODO: Change CsrMatrix::pattern() to return `&Arc` instead of `Arc`
if Arc::ptr_eq(&c.pattern(), &a.pattern()) {
// Special fast path: The two matrices have *exactly* the same sparsity pattern,
// so we only need to sum the value arrays
// NOTE(review): this branch ignores `trans_a`. With pointer-identical patterns and
// `trans_a == Transpose(true)`, it adds `a_ij` where `a_ji` is required, which is
// only correct for symmetric values — confirm callers never hit this combination.
for (c_ij, a_ij) in c.values_mut().iter_mut().zip(a.values()) {
// Clone the scalars per entry; `inlined_clone` is the Scalar-trait cheap clone.
let (alpha, beta) = (alpha.inlined_clone(), beta.inlined_clone());
*c_ij = beta * c_ij.inlined_clone() + alpha * a_ij.inlined_clone();
}
Ok(())
} else {
if trans_a.to_bool()
{
// Transposed case: first scale all of C by beta once, then scatter-add each
// entry (i, j) of A into entry (j, i) of C.
if beta != T::one() {
for c_ij in c.values_mut() {
*c_ij *= beta.inlined_clone();
}
}
for (i, a_row_i) in a.row_iter().enumerate() {
for (&j, a_val) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
let a_val = a_val.inlined_clone();
let alpha = alpha.inlined_clone();
// The target entry (j, i) must already be present in C's pattern;
// otherwise the operation cannot be represented and we report an error.
match c.index_entry_mut(j, i) {
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val }
SparseEntryMut::Zero => return Err(spadd_csr_unexpected_entry()),
}
}
}
} else {
// Untransposed case: walk the rows of C and A in lockstep and merge A's row
// into C's row.
for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) {
// Scale this row of C by beta (skipped when beta == 1, the common case).
if beta != T::one() {
for c_ij in c_row_i.values_mut() {
*c_ij *= beta.inlined_clone();
}
}
let (mut c_cols, mut c_vals) = c_row_i.cols_and_values_mut();
let (a_cols, a_vals) = (a_row_i.col_indices(), a_row_i.values());
for (a_col, a_val) in a_cols.iter().zip(a_vals) {
// TODO: Use exponential search instead of linear search.
// If C has substantially more entries in the row than A, then a linear search
// will needlessly visit many entries in C.
let (c_idx, _) = c_cols.iter()
.enumerate()
.find(|(_, c_col)| *c_col == a_col)
.ok_or_else(spadd_csr_unexpected_entry)?;
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
// Narrow the search window: column indices within a CSR row are strictly
// increasing, so subsequent entries of A can only match at or after c_idx.
c_cols = &c_cols[c_idx ..];
c_vals = &mut c_vals[c_idx ..];
}
}
}
Ok(())
}
}

View File

@ -36,4 +36,26 @@ mod pattern;
pub use coo::*;
pub use csr::*;
pub use pattern::*;
pub use pattern::*;
/// An error produced by a sparse matrix operation, carrying a category and a
/// human-readable message.
#[derive(Clone, Debug)]
pub struct OperationError {
// The category of the error; see `OperationErrorType`.
error_type: OperationErrorType,
// Human-readable description of what went wrong.
message: String
}
/// The category of an [`OperationError`].
///
/// Marked `#[non_exhaustive]` so that new categories can be added without breaking
/// downstream matches.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum OperationErrorType {
/// The sparsity pattern of the output cannot accommodate an entry required by
/// the operation.
InvalidPattern,
}
impl OperationError {
    /// Constructs an error from its category and a human-readable message.
    pub fn from_type_and_message(error_type: OperationErrorType, message: String) -> Self {
        Self { error_type, message }
    }

    /// The category of this error.
    ///
    /// Without this accessor the `error_type` field is write-only: callers could
    /// construct the error but never inspect or match on it.
    pub fn error_type(&self) -> &OperationErrorType {
        &self.error_type
    }

    /// The human-readable message describing the error.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

View File

@ -1,5 +1,5 @@
use nalgebra_sparse::coo::CooMatrix;
use nalgebra_sparse::ops::serial::{spmv_coo, spmm_csr_dense, spadd_build_pattern};
use nalgebra_sparse::ops::serial::{spmv_coo, spmm_csr_dense, spadd_build_pattern, spadd_csr};
use nalgebra_sparse::ops::{Transpose};
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
@ -15,6 +15,15 @@ use std::sync::Arc;
use crate::common::csr_strategy;
/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1
fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
    // Materialize the pattern as a CSR matrix whose stored values are all 1, then convert
    // to dense: stored positions become 1 and unstored positions become 0.
    let ones = vec![1; pattern.nnz()];
    let boolean_csr =
        CsrMatrix::try_from_pattern_and_values(Arc::new(pattern.clone()), ones).unwrap();
    DMatrix::from(&boolean_csr)
}
#[test]
fn spmv_coo_agrees_with_dense_gemv() {
let x = DVector::from_column_slice(&[2, 3, 4, 5]);
@ -91,6 +100,37 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
})
}
/// Arguments for a `spadd_csr` call: `c <- beta * c + alpha * trans_a(a)`.
#[derive(Debug)]
struct SpaddCsrArgs<T> {
// Output/accumulator matrix; its pattern must accommodate all entries of trans_a(a).
c: CsrMatrix<T>,
// Scale factor applied to the existing entries of `c`.
beta: T,
// Scale factor applied to the entries of `a`.
alpha: T,
// Whether `a` is to be treated as transposed by the operation.
trans_a: Transpose,
a: CsrMatrix<T>,
}
/// Proptest strategy producing valid argument sets for `spadd_csr`, with `c`'s pattern
/// built as the union of two generated patterns so it can always hold the result.
fn spadd_csr_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
let value_strategy = -5 ..= 5;
// TODO: Support transposition
// NOTE(review): `trans_strategy()` and the `a.transpose()` below suggest transposition
// is already exercised — confirm whether this TODO is stale.
spadd_build_pattern_strategy()
.prop_flat_map(move |(a_pattern, b_pattern)| {
// NOTE(review): the second argument here is `b_pattern.major_dim()`, used as the
// minor dimension of `c_pattern` — presumably `minor_dim()` was intended; this only
// coincides when the dimensions agree. Verify against `spadd_build_pattern`, which
// may overwrite the dimensions anyway.
let mut c_pattern = SparsityPattern::new(a_pattern.major_dim(), b_pattern.major_dim());
spadd_build_pattern(&mut c_pattern, &a_pattern, &b_pattern);
// One scalar strategy per stored entry of each matrix.
let a_values = vec![value_strategy.clone(); a_pattern.nnz()];
let c_values = vec![value_strategy.clone(); c_pattern.nnz()];
let alpha = value_strategy.clone();
let beta = value_strategy.clone();
(Just(c_pattern), Just(a_pattern), c_values, a_values, alpha, beta, trans_strategy())
}).prop_map(|(c_pattern, a_pattern, c_values, a_values, alpha, beta, trans_a)| {
// Patterns and value vectors are generated with matching nnz, so these cannot fail.
let c = CsrMatrix::try_from_pattern_and_values(Arc::new(c_pattern), c_values).unwrap();
let a = CsrMatrix::try_from_pattern_and_values(Arc::new(a_pattern), a_values).unwrap();
// When the transposed operation is requested, transpose `a` so that trans_a(a) has
// the dimensions the pattern construction above assumed.
let a = if trans_a.to_bool() { a.transpose() } else { a };
SpaddCsrArgs { c, beta, alpha, trans_a, a }
})
}
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
matrix(-5 ..= 5, 0 ..= 6, 0 ..= 6)
@ -203,4 +243,56 @@ proptest! {
prop_assert_eq!(&pattern_result, c_csr.pattern().as_ref());
}
#[test]
fn spadd_csr_test(SpaddCsrArgs { c, beta, alpha, trans_a, a } in spadd_csr_args_strategy()) {
// Test that we get the expected result by comparing to an equivalent dense operation
// (here we give in the C matrix, so the sparsity pattern is essentially fixed)
let mut c_sparse = c.clone();
spadd_csr(&mut c_sparse, beta, alpha, trans_a, &a).unwrap();
// Replay the same computation densely: c <- beta * c + alpha * trans_a(a).
let mut c_dense = DMatrix::from(&c);
let op_a_dense = DMatrix::from(&a);
let op_a_dense = if trans_a.to_bool() { op_a_dense.transpose() } else { op_a_dense };
c_dense = beta * c_dense + alpha * &op_a_dense;
// The sparse result, converted to dense, must match the dense reference exactly
// (integer arithmetic, so equality is exact).
prop_assert_eq!(&DMatrix::from(&c_sparse), &c_dense);
}
#[test]
fn csr_add_csr(
// a and b have the same dimensions
(a, b)
in csr_strategy()
.prop_flat_map(|a| {
let b = csr(-5 ..= 5, Just(a.nrows()), Just(a.ncols()), 40);
(Just(a), b)
}))
{
// We use the dense result as the ground truth for the arithmetic result
let c_dense = DMatrix::from(&a) + DMatrix::from(&b);
// However, it's not enough only to cover the dense result, we also need to verify the
// sparsity pattern. We can determine the exact sparsity pattern by using
// dense arithmetic with positive integer values and extracting positive entries.
let c_dense_pattern = dense_csr_pattern(a.pattern()) + dense_csr_pattern(b.pattern());
let c_pattern = CsrMatrix::from(&c_dense_pattern).pattern().clone();
// Check each combination of owned matrices and references: all four Add impls must
// produce the same values AND the same (union) sparsity pattern.
let c_owned_owned = a.clone() + b.clone();
prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense);
prop_assert_eq!(c_owned_owned.pattern(), &c_pattern);
let c_owned_ref = a.clone() + &b;
prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense);
prop_assert_eq!(c_owned_ref.pattern(), &c_pattern);
let c_ref_owned = &a + b.clone();
prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense);
prop_assert_eq!(c_ref_owned.pattern(), &c_pattern);
let c_ref_ref = &a + &b;
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
}
}