Implement CSR-CSR addition
This commit is contained in:
parent
921686c490
commit
41941e62c8
|
@ -0,0 +1,68 @@
|
||||||
|
use crate::csr::CsrMatrix;
|
||||||
|
|
||||||
|
use std::ops::Add;
|
||||||
|
use crate::ops::serial::{spadd_csr, spadd_build_pattern};
|
||||||
|
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
||||||
|
use num_traits::{Zero, One};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use crate::ops::Transpose;
|
||||||
|
use crate::pattern::SparsityPattern;
|
||||||
|
|
||||||
|
impl<'a, T> Add<&'a CsrMatrix<T>> for &'a CsrMatrix<T>
|
||||||
|
where
|
||||||
|
// TODO: Consider introducing wrapper trait for these things? It's technically a "Ring",
|
||||||
|
// I guess...
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
type Output = CsrMatrix<T>;
|
||||||
|
|
||||||
|
fn add(self, rhs: &'a CsrMatrix<T>) -> Self::Output {
|
||||||
|
let mut pattern = SparsityPattern::new(self.nrows(), self.ncols());
|
||||||
|
spadd_build_pattern(&mut pattern, self.pattern(), rhs.pattern());
|
||||||
|
let values = vec![T::zero(); pattern.nnz()];
|
||||||
|
// We are giving data that is valid by definition, so it is safe to unwrap below
|
||||||
|
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||||
|
.unwrap();
|
||||||
|
spadd_csr(&mut result, T::zero(), T::one(), Transpose(false), &self).unwrap();
|
||||||
|
spadd_csr(&mut result, T::one(), T::one(), Transpose(false), &rhs).unwrap();
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Add<&'a CsrMatrix<T>> for CsrMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
type Output = CsrMatrix<T>;
|
||||||
|
|
||||||
|
fn add(mut self, rhs: &'a CsrMatrix<T>) -> Self::Output {
|
||||||
|
if Arc::ptr_eq(self.pattern(), rhs.pattern()) {
|
||||||
|
spadd_csr(&mut self, T::one(), T::one(), Transpose(false), &rhs).unwrap();
|
||||||
|
self
|
||||||
|
} else {
|
||||||
|
&self + rhs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T> Add<CsrMatrix<T>> for &'a CsrMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
type Output = CsrMatrix<T>;
|
||||||
|
|
||||||
|
fn add(self, rhs: CsrMatrix<T>) -> Self::Output {
|
||||||
|
rhs + self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Add<CsrMatrix<T>> for CsrMatrix<T>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn add(self, rhs: CsrMatrix<T>) -> Self::Output {
|
||||||
|
self + &rhs
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,5 +1,6 @@
|
||||||
//! TODO
|
//! TODO
|
||||||
|
|
||||||
|
mod impl_std_ops;
|
||||||
pub mod serial;
|
pub mod serial;
|
||||||
|
|
||||||
/// TODO
|
/// TODO
|
||||||
|
@ -12,3 +13,5 @@ impl Transpose {
|
||||||
self.0
|
self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,9 @@ use crate::csr::CsrMatrix;
|
||||||
use crate::ops::{Transpose};
|
use crate::ops::{Transpose};
|
||||||
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
|
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
|
||||||
use num_traits::{Zero, One};
|
use num_traits::{Zero, One};
|
||||||
|
use crate::ops::serial::{OperationError, OperationErrorType};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use crate::SparseEntryMut;
|
||||||
|
|
||||||
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * trans(A) * trans(B)`.
|
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * trans(A) * trans(B)`.
|
||||||
pub fn spmm_csr_dense<'a, T>(c: impl Into<DMatrixSliceMut<'a, T>>,
|
pub fn spmm_csr_dense<'a, T>(c: impl Into<DMatrixSliceMut<'a, T>>,
|
||||||
|
@ -66,3 +69,89 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn spadd_csr_unexpected_entry() -> OperationError {
|
||||||
|
OperationError::from_type_and_message(
|
||||||
|
OperationErrorType::InvalidPattern,
|
||||||
|
String::from("Found entry in `a` that is not present in `c`."))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sparse matrix addition `C <- beta * C + alpha * trans(A)`.
|
||||||
|
///
|
||||||
|
/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is
|
||||||
|
/// returned.
|
||||||
|
pub fn spadd_csr<T>(c: &mut CsrMatrix<T>,
|
||||||
|
beta: T,
|
||||||
|
alpha: T,
|
||||||
|
trans_a: Transpose,
|
||||||
|
a: &CsrMatrix<T>)
|
||||||
|
-> Result<(), OperationError>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
// TODO: Proper error messages
|
||||||
|
if trans_a.to_bool() {
|
||||||
|
assert_eq!(c.nrows(), a.ncols());
|
||||||
|
assert_eq!(c.ncols(), a.nrows());
|
||||||
|
} else {
|
||||||
|
assert_eq!(c.nrows(), a.nrows());
|
||||||
|
assert_eq!(c.ncols(), a.ncols());
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Change CsrMatrix::pattern() to return `&Arc` instead of `Arc`
|
||||||
|
if Arc::ptr_eq(&c.pattern(), &a.pattern()) {
|
||||||
|
// Special fast path: The two matrices have *exactly* the same sparsity pattern,
|
||||||
|
// so we only need to sum the value arrays
|
||||||
|
for (c_ij, a_ij) in c.values_mut().iter_mut().zip(a.values()) {
|
||||||
|
let (alpha, beta) = (alpha.inlined_clone(), beta.inlined_clone());
|
||||||
|
*c_ij = beta * c_ij.inlined_clone() + alpha * a_ij.inlined_clone();
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
if trans_a.to_bool()
|
||||||
|
{
|
||||||
|
if beta != T::one() {
|
||||||
|
for c_ij in c.values_mut() {
|
||||||
|
*c_ij *= beta.inlined_clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i, a_row_i) in a.row_iter().enumerate() {
|
||||||
|
for (&j, a_val) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
|
||||||
|
let a_val = a_val.inlined_clone();
|
||||||
|
let alpha = alpha.inlined_clone();
|
||||||
|
match c.index_entry_mut(j, i) {
|
||||||
|
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val }
|
||||||
|
SparseEntryMut::Zero => return Err(spadd_csr_unexpected_entry()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) {
|
||||||
|
if beta != T::one() {
|
||||||
|
for c_ij in c_row_i.values_mut() {
|
||||||
|
*c_ij *= beta.inlined_clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (mut c_cols, mut c_vals) = c_row_i.cols_and_values_mut();
|
||||||
|
let (a_cols, a_vals) = (a_row_i.col_indices(), a_row_i.values());
|
||||||
|
|
||||||
|
for (a_col, a_val) in a_cols.iter().zip(a_vals) {
|
||||||
|
// TODO: Use exponential search instead of linear search.
|
||||||
|
// If C has substantially more entries in the row than A, then a line search
|
||||||
|
// will needlessly visit many entries in C.
|
||||||
|
let (c_idx, _) = c_cols.iter()
|
||||||
|
.enumerate()
|
||||||
|
.find(|(_, c_col)| *c_col == a_col)
|
||||||
|
.ok_or_else(spadd_csr_unexpected_entry)?;
|
||||||
|
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
|
||||||
|
c_cols = &c_cols[c_idx ..];
|
||||||
|
c_vals = &mut c_vals[c_idx ..];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,3 +37,25 @@ mod pattern;
|
||||||
pub use coo::*;
|
pub use coo::*;
|
||||||
pub use csr::*;
|
pub use csr::*;
|
||||||
pub use pattern::*;
|
pub use pattern::*;
|
||||||
|
|
||||||
|
/// TODO
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct OperationError {
|
||||||
|
error_type: OperationErrorType,
|
||||||
|
message: String
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TODO
|
||||||
|
#[non_exhaustive]
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub enum OperationErrorType {
|
||||||
|
/// TODO
|
||||||
|
InvalidPattern,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OperationError {
|
||||||
|
/// TODO
|
||||||
|
pub fn from_type_and_message(error_type: OperationErrorType, message: String) -> Self {
|
||||||
|
Self { error_type, message }
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,5 +1,5 @@
|
||||||
use nalgebra_sparse::coo::CooMatrix;
|
use nalgebra_sparse::coo::CooMatrix;
|
||||||
use nalgebra_sparse::ops::serial::{spmv_coo, spmm_csr_dense, spadd_build_pattern};
|
use nalgebra_sparse::ops::serial::{spmv_coo, spmm_csr_dense, spadd_build_pattern, spadd_csr};
|
||||||
use nalgebra_sparse::ops::{Transpose};
|
use nalgebra_sparse::ops::{Transpose};
|
||||||
use nalgebra_sparse::csr::CsrMatrix;
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
|
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
|
||||||
|
@ -15,6 +15,15 @@ use std::sync::Arc;
|
||||||
|
|
||||||
use crate::common::csr_strategy;
|
use crate::common::csr_strategy;
|
||||||
|
|
||||||
|
/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1
|
||||||
|
fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
||||||
|
let boolean_csr = CsrMatrix::try_from_pattern_and_values(
|
||||||
|
Arc::new(pattern.clone()),
|
||||||
|
vec![1; pattern.nnz()])
|
||||||
|
.unwrap();
|
||||||
|
DMatrix::from(&boolean_csr)
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spmv_coo_agrees_with_dense_gemv() {
|
fn spmv_coo_agrees_with_dense_gemv() {
|
||||||
let x = DVector::from_column_slice(&[2, 3, 4, 5]);
|
let x = DVector::from_column_slice(&[2, 3, 4, 5]);
|
||||||
|
@ -91,6 +100,37 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct SpaddCsrArgs<T> {
|
||||||
|
c: CsrMatrix<T>,
|
||||||
|
beta: T,
|
||||||
|
alpha: T,
|
||||||
|
trans_a: Transpose,
|
||||||
|
a: CsrMatrix<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spadd_csr_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
||||||
|
let value_strategy = -5 ..= 5;
|
||||||
|
|
||||||
|
// TODO :Support transposition
|
||||||
|
spadd_build_pattern_strategy()
|
||||||
|
.prop_flat_map(move |(a_pattern, b_pattern)| {
|
||||||
|
let mut c_pattern = SparsityPattern::new(a_pattern.major_dim(), b_pattern.major_dim());
|
||||||
|
spadd_build_pattern(&mut c_pattern, &a_pattern, &b_pattern);
|
||||||
|
|
||||||
|
let a_values = vec![value_strategy.clone(); a_pattern.nnz()];
|
||||||
|
let c_values = vec![value_strategy.clone(); c_pattern.nnz()];
|
||||||
|
let alpha = value_strategy.clone();
|
||||||
|
let beta = value_strategy.clone();
|
||||||
|
(Just(c_pattern), Just(a_pattern), c_values, a_values, alpha, beta, trans_strategy())
|
||||||
|
}).prop_map(|(c_pattern, a_pattern, c_values, a_values, alpha, beta, trans_a)| {
|
||||||
|
let c = CsrMatrix::try_from_pattern_and_values(Arc::new(c_pattern), c_values).unwrap();
|
||||||
|
let a = CsrMatrix::try_from_pattern_and_values(Arc::new(a_pattern), a_values).unwrap();
|
||||||
|
|
||||||
|
let a = if trans_a.to_bool() { a.transpose() } else { a };
|
||||||
|
SpaddCsrArgs { c, beta, alpha, trans_a, a }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
|
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
|
||||||
matrix(-5 ..= 5, 0 ..= 6, 0 ..= 6)
|
matrix(-5 ..= 5, 0 ..= 6, 0 ..= 6)
|
||||||
|
@ -203,4 +243,56 @@ proptest! {
|
||||||
|
|
||||||
prop_assert_eq!(&pattern_result, c_csr.pattern().as_ref());
|
prop_assert_eq!(&pattern_result, c_csr.pattern().as_ref());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spadd_csr_test(SpaddCsrArgs { c, beta, alpha, trans_a, a } in spadd_csr_args_strategy()) {
|
||||||
|
// Test that we get the expected result by comparing to an equivalent dense operation
|
||||||
|
// (here we give in the C matrix, so the sparsity pattern is essentially fixed)
|
||||||
|
|
||||||
|
let mut c_sparse = c.clone();
|
||||||
|
spadd_csr(&mut c_sparse, beta, alpha, trans_a, &a).unwrap();
|
||||||
|
|
||||||
|
let mut c_dense = DMatrix::from(&c);
|
||||||
|
let op_a_dense = DMatrix::from(&a);
|
||||||
|
let op_a_dense = if trans_a.to_bool() { op_a_dense.transpose() } else { op_a_dense };
|
||||||
|
c_dense = beta * c_dense + alpha * &op_a_dense;
|
||||||
|
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_sparse), &c_dense);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csr_add_csr(
|
||||||
|
// a and b have the same dimensions
|
||||||
|
(a, b)
|
||||||
|
in csr_strategy()
|
||||||
|
.prop_flat_map(|a| {
|
||||||
|
let b = csr(-5 ..= 5, Just(a.nrows()), Just(a.ncols()), 40);
|
||||||
|
(Just(a), b)
|
||||||
|
}))
|
||||||
|
{
|
||||||
|
// We use the dense result as the ground truth for the arithmetic result
|
||||||
|
let c_dense = DMatrix::from(&a) + DMatrix::from(&b);
|
||||||
|
// However, it's not enough only to cover the dense result, we also need to verify the
|
||||||
|
// sparsity pattern. We can determine the exact sparsity pattern by using
|
||||||
|
// dense arithmetic with positive integer values and extracting positive entries.
|
||||||
|
let c_dense_pattern = dense_csr_pattern(a.pattern()) + dense_csr_pattern(b.pattern());
|
||||||
|
let c_pattern = CsrMatrix::from(&c_dense_pattern).pattern().clone();
|
||||||
|
|
||||||
|
// Check each combination of owned matrices and references
|
||||||
|
let c_owned_owned = a.clone() + b.clone();
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense);
|
||||||
|
prop_assert_eq!(c_owned_owned.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_owned_ref = a.clone() + &b;
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense);
|
||||||
|
prop_assert_eq!(c_owned_ref.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_ref_owned = &a + b.clone();
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense);
|
||||||
|
prop_assert_eq!(c_ref_owned.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_ref_ref = &a + &b;
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
|
||||||
|
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue