diff --git a/nalgebra-sparse/src/ops/impl_std_ops.rs b/nalgebra-sparse/src/ops/impl_std_ops.rs index e69de29b..c8e9e800 100644 --- a/nalgebra-sparse/src/ops/impl_std_ops.rs +++ b/nalgebra-sparse/src/ops/impl_std_ops.rs @@ -0,0 +1,68 @@ +use crate::csr::CsrMatrix; + +use std::ops::Add; +use crate::ops::serial::{spadd_csr, spadd_build_pattern}; +use nalgebra::{ClosedAdd, ClosedMul, Scalar}; +use num_traits::{Zero, One}; +use std::sync::Arc; +use crate::ops::Transpose; +use crate::pattern::SparsityPattern; + +impl<'a, T> Add<&'a CsrMatrix> for &'a CsrMatrix +where + // TODO: Consider introducing wrapper trait for these things? It's technically a "Ring", + // I guess... + T: Scalar + ClosedAdd + ClosedMul + Zero + One +{ + type Output = CsrMatrix; + + fn add(self, rhs: &'a CsrMatrix) -> Self::Output { + let mut pattern = SparsityPattern::new(self.nrows(), self.ncols()); + spadd_build_pattern(&mut pattern, self.pattern(), rhs.pattern()); + let values = vec![T::zero(); pattern.nnz()]; + // We are giving data that is valid by definition, so it is safe to unwrap below + let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values) + .unwrap(); + spadd_csr(&mut result, T::zero(), T::one(), Transpose(false), &self).unwrap(); + spadd_csr(&mut result, T::one(), T::one(), Transpose(false), &rhs).unwrap(); + result + } +} + +impl<'a, T> Add<&'a CsrMatrix> for CsrMatrix +where + T: Scalar + ClosedAdd + ClosedMul + Zero + One +{ + type Output = CsrMatrix; + + fn add(mut self, rhs: &'a CsrMatrix) -> Self::Output { + if Arc::ptr_eq(self.pattern(), rhs.pattern()) { + spadd_csr(&mut self, T::one(), T::one(), Transpose(false), &rhs).unwrap(); + self + } else { + &self + rhs + } + } +} + +impl<'a, T> Add> for &'a CsrMatrix + where + T: Scalar + ClosedAdd + ClosedMul + Zero + One +{ + type Output = CsrMatrix; + + fn add(self, rhs: CsrMatrix) -> Self::Output { + rhs + self + } +} + +impl Add> for CsrMatrix +where + T: Scalar + ClosedAdd + ClosedMul + Zero + One +{ + type Output = Self; + + fn add(self, rhs: CsrMatrix) -> Self::Output { + self + &rhs + } +} \ No newline at end of file diff --git a/nalgebra-sparse/src/ops/mod.rs b/nalgebra-sparse/src/ops/mod.rs index bf1698ec..08939d8a 100644 --- a/nalgebra-sparse/src/ops/mod.rs +++ b/nalgebra-sparse/src/ops/mod.rs @@ -1,5 +1,6 @@ //! TODO +mod impl_std_ops; pub mod serial; /// TODO @@ -11,4 +12,6 @@ impl Transpose { pub fn to_bool(&self) -> bool { self.0 } -} \ No newline at end of file +} + + diff --git a/nalgebra-sparse/src/ops/serial/csr.rs b/nalgebra-sparse/src/ops/serial/csr.rs index b77d1112..42ef6121 100644 --- a/nalgebra-sparse/src/ops/serial/csr.rs +++ b/nalgebra-sparse/src/ops/serial/csr.rs @@ -2,6 +2,9 @@ use crate::csr::CsrMatrix; use crate::ops::{Transpose}; use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut}; use num_traits::{Zero, One}; +use crate::ops::serial::{OperationError, OperationErrorType}; +use std::sync::Arc; +use crate::SparseEntryMut; /// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * trans(A) * trans(B)`. pub fn spmm_csr_dense<'a, T>(c: impl Into>, @@ -65,4 +68,90 @@ where } } } -} \ No newline at end of file +} + +fn spadd_csr_unexpected_entry() -> OperationError { + OperationError::from_type_and_message( + OperationErrorType::InvalidPattern, + String::from("Found entry in `a` that is not present in `c`.")) +} + +/// Sparse matrix addition `C <- beta * C + alpha * trans(A)`. +/// +/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is +/// returned. +pub fn spadd_csr(c: &mut CsrMatrix, + beta: T, + alpha: T, + trans_a: Transpose, + a: &CsrMatrix) + -> Result<(), OperationError> +where + T: Scalar + ClosedAdd + ClosedMul + Zero + One +{ + // TODO: Proper error messages + if trans_a.to_bool() { + assert_eq!(c.nrows(), a.ncols()); + assert_eq!(c.ncols(), a.nrows()); + } else { + assert_eq!(c.nrows(), a.nrows()); + assert_eq!(c.ncols(), a.ncols()); + } + + // TODO: Change CsrMatrix::pattern() to return `&Arc` instead of `Arc` + if Arc::ptr_eq(&c.pattern(), &a.pattern()) { + // Special fast path: The two matrices have *exactly* the same sparsity pattern, + // so we only need to sum the value arrays + for (c_ij, a_ij) in c.values_mut().iter_mut().zip(a.values()) { + let (alpha, beta) = (alpha.inlined_clone(), beta.inlined_clone()); + *c_ij = beta * c_ij.inlined_clone() + alpha * a_ij.inlined_clone(); + } + Ok(()) + } else { + if trans_a.to_bool() + { + if beta != T::one() { + for c_ij in c.values_mut() { + *c_ij *= beta.inlined_clone(); + } + } + + for (i, a_row_i) in a.row_iter().enumerate() { + for (&j, a_val) in a_row_i.col_indices().iter().zip(a_row_i.values()) { + let a_val = a_val.inlined_clone(); + let alpha = alpha.inlined_clone(); + match c.index_entry_mut(j, i) { + SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val } + SparseEntryMut::Zero => return Err(spadd_csr_unexpected_entry()), + } + } + } + } else { + for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) { + if beta != T::one() { + for c_ij in c_row_i.values_mut() { + *c_ij *= beta.inlined_clone(); + } + } + + let (mut c_cols, mut c_vals) = c_row_i.cols_and_values_mut(); + let (a_cols, a_vals) = (a_row_i.col_indices(), a_row_i.values()); + + for (a_col, a_val) in a_cols.iter().zip(a_vals) { + // TODO: Use exponential search instead of linear search. + // If C has substantially more entries in the row than A, then a line search + // will needlessly visit many entries in C. + let (c_idx, _) = c_cols.iter() + .enumerate() + .find(|(_, c_col)| *c_col == a_col) + .ok_or_else(spadd_csr_unexpected_entry)?; + c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone(); + c_cols = &c_cols[c_idx ..]; + c_vals = &mut c_vals[c_idx ..]; + } + } + } + Ok(()) + } +} + diff --git a/nalgebra-sparse/src/ops/serial/mod.rs b/nalgebra-sparse/src/ops/serial/mod.rs index bb40419f..cd0dc09c 100644 --- a/nalgebra-sparse/src/ops/serial/mod.rs +++ b/nalgebra-sparse/src/ops/serial/mod.rs @@ -36,4 +36,26 @@ mod pattern; pub use coo::*; pub use csr::*; -pub use pattern::*; \ No newline at end of file +pub use pattern::*; + +/// TODO +#[derive(Clone, Debug)] +pub struct OperationError { + error_type: OperationErrorType, + message: String +} + +/// TODO +#[non_exhaustive] +#[derive(Clone, Debug)] +pub enum OperationErrorType { + /// TODO + InvalidPattern, +} + +impl OperationError { + /// TODO + pub fn from_type_and_message(error_type: OperationErrorType, message: String) -> Self { + Self { error_type, message } + } +} \ No newline at end of file diff --git a/nalgebra-sparse/tests/unit_tests/ops.rs b/nalgebra-sparse/tests/unit_tests/ops.rs index cc416790..6c4f2006 100644 --- a/nalgebra-sparse/tests/unit_tests/ops.rs +++ b/nalgebra-sparse/tests/unit_tests/ops.rs @@ -1,5 +1,5 @@ use nalgebra_sparse::coo::CooMatrix; -use nalgebra_sparse::ops::serial::{spmv_coo, spmm_csr_dense, spadd_build_pattern}; +use nalgebra_sparse::ops::serial::{spmv_coo, spmm_csr_dense, spadd_build_pattern, spadd_csr}; use nalgebra_sparse::ops::{Transpose}; use nalgebra_sparse::csr::CsrMatrix; use nalgebra_sparse::proptest::{csr, sparsity_pattern}; @@ -15,6 +15,15 @@ use std::sync::Arc; use crate::common::csr_strategy; +/// Represents the sparsity pattern of a CSR matrix as a dense matrix with 0/1 +fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix { + let boolean_csr = CsrMatrix::try_from_pattern_and_values( + Arc::new(pattern.clone()), + vec![1; pattern.nnz()]) + .unwrap(); + DMatrix::from(&boolean_csr) +} + #[test] fn spmv_coo_agrees_with_dense_gemv() { let x = DVector::from_column_slice(&[2, 3, 4, 5]); @@ -91,6 +100,37 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy> }) } +#[derive(Debug)] +struct SpaddCsrArgs { + c: CsrMatrix, + beta: T, + alpha: T, + trans_a: Transpose, + a: CsrMatrix, +} + +fn spadd_csr_args_strategy() -> impl Strategy> { + let value_strategy = -5 ..= 5; + + // TODO :Support transposition + spadd_build_pattern_strategy() + .prop_flat_map(move |(a_pattern, b_pattern)| { + let mut c_pattern = SparsityPattern::new(a_pattern.major_dim(), b_pattern.major_dim()); + spadd_build_pattern(&mut c_pattern, &a_pattern, &b_pattern); + + let a_values = vec![value_strategy.clone(); a_pattern.nnz()]; + let c_values = vec![value_strategy.clone(); c_pattern.nnz()]; + let alpha = value_strategy.clone(); + let beta = value_strategy.clone(); + (Just(c_pattern), Just(a_pattern), c_values, a_values, alpha, beta, trans_strategy()) + }).prop_map(|(c_pattern, a_pattern, c_values, a_values, alpha, beta, trans_a)| { + let c = CsrMatrix::try_from_pattern_and_values(Arc::new(c_pattern), c_values).unwrap(); + let a = CsrMatrix::try_from_pattern_and_values(Arc::new(a_pattern), a_values).unwrap(); + + let a = if trans_a.to_bool() { a.transpose() } else { a }; + SpaddCsrArgs { c, beta, alpha, trans_a, a } + }) +} fn dense_strategy() -> impl Strategy> { matrix(-5 ..= 5, 0 ..= 6, 0 ..= 6) @@ -203,4 +243,56 @@ proptest! { prop_assert_eq!(&pattern_result, c_csr.pattern().as_ref()); } + + #[test] + fn spadd_csr_test(SpaddCsrArgs { c, beta, alpha, trans_a, a } in spadd_csr_args_strategy()) { + // Test that we get the expected result by comparing to an equivalent dense operation + // (here we give in the C matrix, so the sparsity pattern is essentially fixed) + + let mut c_sparse = c.clone(); + spadd_csr(&mut c_sparse, beta, alpha, trans_a, &a).unwrap(); + + let mut c_dense = DMatrix::from(&c); + let op_a_dense = DMatrix::from(&a); + let op_a_dense = if trans_a.to_bool() { op_a_dense.transpose() } else { op_a_dense }; + c_dense = beta * c_dense + alpha * &op_a_dense; + + prop_assert_eq!(&DMatrix::from(&c_sparse), &c_dense); + } + + #[test] + fn csr_add_csr( + // a and b have the same dimensions + (a, b) + in csr_strategy() + .prop_flat_map(|a| { + let b = csr(-5 ..= 5, Just(a.nrows()), Just(a.ncols()), 40); + (Just(a), b) + })) + { + // We use the dense result as the ground truth for the arithmetic result + let c_dense = DMatrix::from(&a) + DMatrix::from(&b); + // However, it's not enough only to cover the dense result, we also need to verify the + // sparsity pattern. We can determine the exact sparsity pattern by using + // dense arithmetic with positive integer values and extracting positive entries. + let c_dense_pattern = dense_csr_pattern(a.pattern()) + dense_csr_pattern(b.pattern()); + let c_pattern = CsrMatrix::from(&c_dense_pattern).pattern().clone(); + + // Check each combination of owned matrices and references + let c_owned_owned = a.clone() + b.clone(); + prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense); + prop_assert_eq!(c_owned_owned.pattern(), &c_pattern); + + let c_owned_ref = a.clone() + &b; + prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense); + prop_assert_eq!(c_owned_ref.pattern(), &c_pattern); + + let c_ref_owned = &a + b.clone(); + prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense); + prop_assert_eq!(c_ref_owned.pattern(), &c_pattern); + + let c_ref_ref = &a + &b; + prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense); + prop_assert_eq!(c_ref_ref.pattern(), &c_pattern); + } } \ No newline at end of file