forked from M-Labs/nalgebra
Implement CSR-CSR addition
This commit is contained in:
parent
921686c490
commit
41941e62c8
@ -0,0 +1,68 @@
|
||||
use crate::csr::CsrMatrix;
|
||||
|
||||
use std::ops::Add;
|
||||
use crate::ops::serial::{spadd_csr, spadd_build_pattern};
|
||||
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
||||
use num_traits::{Zero, One};
|
||||
use std::sync::Arc;
|
||||
use crate::ops::Transpose;
|
||||
use crate::pattern::SparsityPattern;
|
||||
|
||||
impl<'a, T> Add<&'a CsrMatrix<T>> for &'a CsrMatrix<T>
|
||||
where
|
||||
// TODO: Consider introducing wrapper trait for these things? It's technically a "Ring",
|
||||
// I guess...
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
{
|
||||
type Output = CsrMatrix<T>;
|
||||
|
||||
fn add(self, rhs: &'a CsrMatrix<T>) -> Self::Output {
|
||||
let mut pattern = SparsityPattern::new(self.nrows(), self.ncols());
|
||||
spadd_build_pattern(&mut pattern, self.pattern(), rhs.pattern());
|
||||
let values = vec![T::zero(); pattern.nnz()];
|
||||
// We are giving data that is valid by definition, so it is safe to unwrap below
|
||||
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||
.unwrap();
|
||||
spadd_csr(&mut result, T::zero(), T::one(), Transpose(false), &self).unwrap();
|
||||
spadd_csr(&mut result, T::one(), T::one(), Transpose(false), &rhs).unwrap();
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Add<&'a CsrMatrix<T>> for CsrMatrix<T>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
{
|
||||
type Output = CsrMatrix<T>;
|
||||
|
||||
fn add(mut self, rhs: &'a CsrMatrix<T>) -> Self::Output {
|
||||
if Arc::ptr_eq(self.pattern(), rhs.pattern()) {
|
||||
spadd_csr(&mut self, T::one(), T::one(), Transpose(false), &rhs).unwrap();
|
||||
self
|
||||
} else {
|
||||
&self + rhs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Add<CsrMatrix<T>> for &'a CsrMatrix<T>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
{
|
||||
type Output = CsrMatrix<T>;
|
||||
|
||||
fn add(self, rhs: CsrMatrix<T>) -> Self::Output {
|
||||
rhs + self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Add<CsrMatrix<T>> for CsrMatrix<T>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: CsrMatrix<T>) -> Self::Output {
|
||||
self + &rhs
|
||||
}
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
//! TODO
|
||||
|
||||
mod impl_std_ops;
|
||||
pub mod serial;
|
||||
|
||||
/// TODO
|
||||
@ -12,3 +13,5 @@ impl Transpose {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -2,6 +2,9 @@ use crate::csr::CsrMatrix;
|
||||
use crate::ops::{Transpose};
|
||||
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
|
||||
use num_traits::{Zero, One};
|
||||
use crate::ops::serial::{OperationError, OperationErrorType};
|
||||
use std::sync::Arc;
|
||||
use crate::SparseEntryMut;
|
||||
|
||||
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * trans(A) * trans(B)`.
|
||||
pub fn spmm_csr_dense<'a, T>(c: impl Into<DMatrixSliceMut<'a, T>>,
|
||||
@ -66,3 +69,89 @@ where
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn spadd_csr_unexpected_entry() -> OperationError {
|
||||
OperationError::from_type_and_message(
|
||||
OperationErrorType::InvalidPattern,
|
||||
String::from("Found entry in `a` that is not present in `c`."))
|
||||
}
|
||||
|
||||
/// Sparse matrix addition `C <- beta * C + alpha * trans(A)`.
|
||||
///
|
||||
/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is
|
||||
/// returned.
|
||||
pub fn spadd_csr<T>(c: &mut CsrMatrix<T>,
|
||||
beta: T,
|
||||
alpha: T,
|
||||
trans_a: Transpose,
|
||||
a: &CsrMatrix<T>)
|
||||
-> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
{
|
||||
// TODO: Proper error messages
|
||||
if trans_a.to_bool() {
|
||||
assert_eq!(c.nrows(), a.ncols());
|
||||
assert_eq!(c.ncols(), a.nrows());
|
||||
} else {
|
||||
assert_eq!(c.nrows(), a.nrows());
|
||||
assert_eq!(c.ncols(), a.ncols());
|
||||
}
|
||||
|
||||
// TODO: Change CsrMatrix::pattern() to return `&Arc` instead of `Arc`
|
||||
if Arc::ptr_eq(&c.pattern(), &a.pattern()) {
|
||||
// Special fast path: The two matrices have *exactly* the same sparsity pattern,
|
||||
// so we only need to sum the value arrays
|
||||
for (c_ij, a_ij) in c.values_mut().iter_mut().zip(a.values()) {
|
||||
let (alpha, beta) = (alpha.inlined_clone(), beta.inlined_clone());
|
||||
*c_ij = beta * c_ij.inlined_clone() + alpha * a_ij.inlined_clone();
|
||||
}
|
||||
Ok(())
|
||||
} else {
|
||||
if trans_a.to_bool()
|
||||
{
|
||||
if beta != T::one() {
|
||||
for c_ij in c.values_mut() {
|
||||
*c_ij *= beta.inlined_clone();
|
||||
}
|
||||
}
|
||||
|
||||
for (i, a_row_i) in a.row_iter().enumerate() {
|
||||
for (&j, a_val) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
|
||||
let a_val = a_val.inlined_clone();
|
||||
let alpha = alpha.inlined_clone();
|
||||
match c.index_entry_mut(j, i) {
|
||||
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val }
|
||||
SparseEntryMut::Zero => return Err(spadd_csr_unexpected_entry()),
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) {
|
||||
if beta != T::one() {
|
||||
for c_ij in c_row_i.values_mut() {
|
||||
*c_ij *= beta.inlined_clone();
|
||||
}
|
||||
}
|
||||
|
||||
let (mut c_cols, mut c_vals) = c_row_i.cols_and_values_mut();
|
||||
let (a_cols, a_vals) = (a_row_i.col_indices(), a_row_i.values());
|
||||
|
||||
for (a_col, a_val) in a_cols.iter().zip(a_vals) {
|
||||
// TODO: Use exponential search instead of linear search.
|
||||
// If C has substantially more entries in the row than A, then a line search
|
||||
// will needlessly visit many entries in C.
|
||||
let (c_idx, _) = c_cols.iter()
|
||||
.enumerate()
|
||||
.find(|(_, c_col)| *c_col == a_col)
|
||||
.ok_or_else(spadd_csr_unexpected_entry)?;
|
||||
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
|
||||
c_cols = &c_cols[c_idx ..];
|
||||
c_vals = &mut c_vals[c_idx ..];
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -37,3 +37,25 @@ mod pattern;
|
||||
pub use coo::*;
|
||||
pub use csr::*;
|
||||
pub use pattern::*;
|
||||
|
||||
/// TODO
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OperationError {
|
||||
error_type: OperationErrorType,
|
||||
message: String
|
||||
}
|
||||
|
||||
/// TODO
|
||||
#[non_exhaustive]
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum OperationErrorType {
|
||||
/// TODO
|
||||
InvalidPattern,
|
||||
}
|
||||
|
||||
impl OperationError {
|
||||
/// TODO
|
||||
pub fn from_type_and_message(error_type: OperationErrorType, message: String) -> Self {
|
||||
Self { error_type, message }
|
||||
}
|
||||
}
|
@ -1,5 +1,5 @@
|
||||
use nalgebra_sparse::coo::CooMatrix;
|
||||
use nalgebra_sparse::ops::serial::{spmv_coo, spmm_csr_dense, spadd_build_pattern};
|
||||
use nalgebra_sparse::ops::serial::{spmv_coo, spmm_csr_dense, spadd_build_pattern, spadd_csr};
|
||||
use nalgebra_sparse::ops::{Transpose};
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
|
||||
@ -15,6 +15,15 @@ use std::sync::Arc;
|
||||
|
||||
use crate::common::csr_strategy;
|
||||
|
||||
/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1
|
||||
fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
||||
let boolean_csr = CsrMatrix::try_from_pattern_and_values(
|
||||
Arc::new(pattern.clone()),
|
||||
vec![1; pattern.nnz()])
|
||||
.unwrap();
|
||||
DMatrix::from(&boolean_csr)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spmv_coo_agrees_with_dense_gemv() {
|
||||
let x = DVector::from_column_slice(&[2, 3, 4, 5]);
|
||||
@ -91,6 +100,37 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SpaddCsrArgs<T> {
|
||||
c: CsrMatrix<T>,
|
||||
beta: T,
|
||||
alpha: T,
|
||||
trans_a: Transpose,
|
||||
a: CsrMatrix<T>,
|
||||
}
|
||||
|
||||
fn spadd_csr_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
||||
let value_strategy = -5 ..= 5;
|
||||
|
||||
// TODO :Support transposition
|
||||
spadd_build_pattern_strategy()
|
||||
.prop_flat_map(move |(a_pattern, b_pattern)| {
|
||||
let mut c_pattern = SparsityPattern::new(a_pattern.major_dim(), b_pattern.major_dim());
|
||||
spadd_build_pattern(&mut c_pattern, &a_pattern, &b_pattern);
|
||||
|
||||
let a_values = vec![value_strategy.clone(); a_pattern.nnz()];
|
||||
let c_values = vec![value_strategy.clone(); c_pattern.nnz()];
|
||||
let alpha = value_strategy.clone();
|
||||
let beta = value_strategy.clone();
|
||||
(Just(c_pattern), Just(a_pattern), c_values, a_values, alpha, beta, trans_strategy())
|
||||
}).prop_map(|(c_pattern, a_pattern, c_values, a_values, alpha, beta, trans_a)| {
|
||||
let c = CsrMatrix::try_from_pattern_and_values(Arc::new(c_pattern), c_values).unwrap();
|
||||
let a = CsrMatrix::try_from_pattern_and_values(Arc::new(a_pattern), a_values).unwrap();
|
||||
|
||||
let a = if trans_a.to_bool() { a.transpose() } else { a };
|
||||
SpaddCsrArgs { c, beta, alpha, trans_a, a }
|
||||
})
|
||||
}
|
||||
|
||||
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
|
||||
matrix(-5 ..= 5, 0 ..= 6, 0 ..= 6)
|
||||
@ -203,4 +243,56 @@ proptest! {
|
||||
|
||||
prop_assert_eq!(&pattern_result, c_csr.pattern().as_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spadd_csr_test(SpaddCsrArgs { c, beta, alpha, trans_a, a } in spadd_csr_args_strategy()) {
|
||||
// Test that we get the expected result by comparing to an equivalent dense operation
|
||||
// (here we give in the C matrix, so the sparsity pattern is essentially fixed)
|
||||
|
||||
let mut c_sparse = c.clone();
|
||||
spadd_csr(&mut c_sparse, beta, alpha, trans_a, &a).unwrap();
|
||||
|
||||
let mut c_dense = DMatrix::from(&c);
|
||||
let op_a_dense = DMatrix::from(&a);
|
||||
let op_a_dense = if trans_a.to_bool() { op_a_dense.transpose() } else { op_a_dense };
|
||||
c_dense = beta * c_dense + alpha * &op_a_dense;
|
||||
|
||||
prop_assert_eq!(&DMatrix::from(&c_sparse), &c_dense);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_add_csr(
|
||||
// a and b have the same dimensions
|
||||
(a, b)
|
||||
in csr_strategy()
|
||||
.prop_flat_map(|a| {
|
||||
let b = csr(-5 ..= 5, Just(a.nrows()), Just(a.ncols()), 40);
|
||||
(Just(a), b)
|
||||
}))
|
||||
{
|
||||
// We use the dense result as the ground truth for the arithmetic result
|
||||
let c_dense = DMatrix::from(&a) + DMatrix::from(&b);
|
||||
// However, it's not enough only to cover the dense result, we also need to verify the
|
||||
// sparsity pattern. We can determine the exact sparsity pattern by using
|
||||
// dense arithmetic with positive integer values and extracting positive entries.
|
||||
let c_dense_pattern = dense_csr_pattern(a.pattern()) + dense_csr_pattern(b.pattern());
|
||||
let c_pattern = CsrMatrix::from(&c_dense_pattern).pattern().clone();
|
||||
|
||||
// Check each combination of owned matrices and references
|
||||
let c_owned_owned = a.clone() + b.clone();
|
||||
prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense);
|
||||
prop_assert_eq!(c_owned_owned.pattern(), &c_pattern);
|
||||
|
||||
let c_owned_ref = a.clone() + &b;
|
||||
prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense);
|
||||
prop_assert_eq!(c_owned_ref.pattern(), &c_pattern);
|
||||
|
||||
let c_ref_owned = &a + b.clone();
|
||||
prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense);
|
||||
prop_assert_eq!(c_ref_owned.pattern(), &c_pattern);
|
||||
|
||||
let c_ref_ref = &a + &b;
|
||||
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
|
||||
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user