Implement arithmetic operations for CSC matrices

This commit is contained in:
Andreas Longva 2020-12-30 16:09:46 +01:00
parent 6a1d12705f
commit dbdf5567fc
12 changed files with 838 additions and 275 deletions

View File

@ -144,9 +144,19 @@ impl<T> CsMatrix<T> {
values: &mut values[range] values: &mut values[range]
}) })
} }
#[inline]
pub fn lane_iter(&self) -> CsLaneIter<T> {
CsLaneIter::new(self.pattern().as_ref(), self.values())
} }
pub fn get_entry_from_slices<'a, T>( #[inline]
pub fn lane_iter_mut(&mut self) -> CsLaneIterMut<T> {
CsLaneIterMut::new(self.sparsity_pattern.as_ref(), &mut self.values)
}
}
fn get_entry_from_slices<'a, T>(
minor_dim: usize, minor_dim: usize,
minor_indices: &'a [usize], minor_indices: &'a [usize],
values: &'a [T], values: &'a [T],
@ -161,7 +171,7 @@ pub fn get_entry_from_slices<'a, T>(
} }
} }
pub fn get_mut_entry_from_slices<'a, T>( fn get_mut_entry_from_slices<'a, T>(
minor_dim: usize, minor_dim: usize,
minor_indices: &'a [usize], minor_indices: &'a [usize],
values: &'a mut [T], values: &'a mut [T],
@ -178,16 +188,16 @@ pub fn get_mut_entry_from_slices<'a, T>(
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsLane<'a, T> { pub struct CsLane<'a, T> {
pub minor_dim: usize, minor_dim: usize,
pub minor_indices: &'a [usize], minor_indices: &'a [usize],
pub values: &'a [T] values: &'a [T]
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct CsLaneMut<'a, T> { pub struct CsLaneMut<'a, T> {
pub minor_dim: usize, minor_dim: usize,
pub minor_indices: &'a [usize], minor_indices: &'a [usize],
pub values: &'a mut [T] values: &'a mut [T]
} }
pub struct CsLaneIter<'a, T> { pub struct CsLaneIter<'a, T> {
@ -280,4 +290,59 @@ impl<'a, T> Iterator for CsLaneIterMut<'a, T>
} }
} }
/// Implement the methods common to both CsLane and CsLaneMut. See the documentation for the
/// methods delegated here by CsrMatrix and CscMatrix members for more information.
macro_rules! impl_cs_lane_common_methods {
($name:ty) => {
impl<'a, T> $name {
#[inline]
pub fn minor_dim(&self) -> usize {
self.minor_dim
}
#[inline]
pub fn nnz(&self) -> usize {
self.minor_indices.len()
}
#[inline]
pub fn minor_indices(&self) -> &[usize] {
self.minor_indices
}
#[inline]
pub fn values(&self) -> &[T] {
self.values
}
#[inline]
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> {
get_entry_from_slices(
self.minor_dim,
self.minor_indices,
self.values,
global_col_index)
}
}
}
}
impl_cs_lane_common_methods!(CsLane<'a, T>);
impl_cs_lane_common_methods!(CsLaneMut<'a, T>);
impl<'a, T> CsLaneMut<'a, T> {
pub fn values_mut(&mut self) -> &mut [T] {
self.values
}
pub fn indices_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
(self.minor_indices, self.values)
}
pub fn get_entry_mut(&mut self, global_minor_index: usize) -> Option<SparseEntryMut<T>> {
get_mut_entry_from_slices(self.minor_dim,
self.minor_indices,
self.values,
global_minor_index)
}
}

View File

@ -3,8 +3,7 @@
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut}; use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::csr::CsrMatrix; use crate::csr::CsrMatrix;
use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut, use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut};
get_entry_from_slices, get_mut_entry_from_slices};
use std::sync::Arc; use std::sync::Arc;
use std::slice::{IterMut, Iter}; use std::slice::{IterMut, Iter};
@ -21,7 +20,7 @@ use nalgebra::Scalar;
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct CscMatrix<T> { pub struct CscMatrix<T> {
// Cols are major, rows are minor in the sparsity pattern // Cols are major, rows are minor in the sparsity pattern
cs: CsMatrix<T>, pub(crate) cs: CsMatrix<T>,
} }
impl<T> CscMatrix<T> { impl<T> CscMatrix<T> {
@ -435,25 +434,25 @@ macro_rules! impl_csc_col_common_methods {
/// The number of global rows in the column. /// The number of global rows in the column.
#[inline] #[inline]
pub fn nrows(&self) -> usize { pub fn nrows(&self) -> usize {
self.lane.minor_dim self.lane.minor_dim()
} }
/// The number of non-zeros in this column. /// The number of non-zeros in this column.
#[inline] #[inline]
pub fn nnz(&self) -> usize { pub fn nnz(&self) -> usize {
self.lane.minor_indices.len() self.lane.nnz()
} }
/// The row indices corresponding to explicitly stored entries in this column. /// The row indices corresponding to explicitly stored entries in this column.
#[inline] #[inline]
pub fn row_indices(&self) -> &[usize] { pub fn row_indices(&self) -> &[usize] {
self.lane.minor_indices self.lane.minor_indices()
} }
/// The values corresponding to explicitly stored entries in this column. /// The values corresponding to explicitly stored entries in this column.
#[inline] #[inline]
pub fn values(&self) -> &[T] { pub fn values(&self) -> &[T] {
self.lane.values self.lane.values()
} }
/// Returns an entry for the given global row index. /// Returns an entry for the given global row index.
@ -461,11 +460,7 @@ macro_rules! impl_csc_col_common_methods {
/// Each call to this function incurs the cost of a binary search among the explicitly /// Each call to this function incurs the cost of a binary search among the explicitly
/// stored row entries. /// stored row entries.
pub fn get_entry(&self, global_row_index: usize) -> Option<SparseEntry<T>> { pub fn get_entry(&self, global_row_index: usize) -> Option<SparseEntry<T>> {
get_entry_from_slices( self.lane.get_entry(global_row_index)
self.lane.minor_dim,
self.lane.minor_indices,
self.lane.values,
global_row_index)
} }
} }
} }
@ -477,7 +472,7 @@ impl_csc_col_common_methods!(CscColMut<'a, T>);
impl<'a, T> CscColMut<'a, T> { impl<'a, T> CscColMut<'a, T> {
/// Mutable access to the values corresponding to explicitly stored entries in this column. /// Mutable access to the values corresponding to explicitly stored entries in this column.
pub fn values_mut(&mut self) -> &mut [T] { pub fn values_mut(&mut self) -> &mut [T] {
self.lane.values self.lane.values_mut()
} }
/// Provides simultaneous access to row indices and mutable values corresponding to the /// Provides simultaneous access to row indices and mutable values corresponding to the
@ -486,15 +481,12 @@ impl<'a, T> CscColMut<'a, T> {
/// This method primarily facilitates low-level access for methods that process data stored /// This method primarily facilitates low-level access for methods that process data stored
/// in CSC format directly. /// in CSC format directly.
pub fn rows_and_values_mut(&mut self) -> (&[usize], &mut [T]) { pub fn rows_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
(self.lane.minor_indices, self.lane.values) self.lane.indices_and_values_mut()
} }
/// Returns a mutable entry for the given global row index. /// Returns a mutable entry for the given global row index.
pub fn get_entry_mut(&mut self, global_row_index: usize) -> Option<SparseEntryMut<T>> { pub fn get_entry_mut(&mut self, global_row_index: usize) -> Option<SparseEntryMut<T>> {
get_mut_entry_from_slices(self.lane.minor_dim, self.lane.get_entry_mut(global_row_index)
self.lane.minor_indices,
self.lane.values,
global_row_index)
} }
} }

View File

@ -2,7 +2,7 @@
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut}; use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::csc::CscMatrix; use crate::csc::CscMatrix;
use crate::cs::{CsMatrix, get_entry_from_slices, get_mut_entry_from_slices, CsLaneIterMut, CsLaneIter, CsLane, CsLaneMut}; use crate::cs::{CsMatrix, CsLaneIterMut, CsLaneIter, CsLane, CsLaneMut};
use nalgebra::Scalar; use nalgebra::Scalar;
use num_traits::Zero; use num_traits::Zero;
@ -20,7 +20,7 @@ use std::slice::{IterMut, Iter};
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsrMatrix<T> { pub struct CsrMatrix<T> {
// Rows are major, cols are minor in the sparsity pattern // Rows are major, cols are minor in the sparsity pattern
cs: CsMatrix<T>, pub(crate) cs: CsMatrix<T>,
} }
impl<T> CsrMatrix<T> { impl<T> CsrMatrix<T> {
@ -435,37 +435,34 @@ macro_rules! impl_csr_row_common_methods {
/// The number of global columns in the row. /// The number of global columns in the row.
#[inline] #[inline]
pub fn ncols(&self) -> usize { pub fn ncols(&self) -> usize {
self.lane.minor_dim self.lane.minor_dim()
} }
/// The number of non-zeros in this row. /// The number of non-zeros in this row.
#[inline] #[inline]
pub fn nnz(&self) -> usize { pub fn nnz(&self) -> usize {
self.lane.minor_indices.len() self.lane.nnz()
} }
/// The column indices corresponding to explicitly stored entries in this row. /// The column indices corresponding to explicitly stored entries in this row.
#[inline] #[inline]
pub fn col_indices(&self) -> &[usize] { pub fn col_indices(&self) -> &[usize] {
self.lane.minor_indices self.lane.minor_indices()
} }
/// The values corresponding to explicitly stored entries in this row. /// The values corresponding to explicitly stored entries in this row.
#[inline] #[inline]
pub fn values(&self) -> &[T] { pub fn values(&self) -> &[T] {
self.lane.values self.lane.values()
} }
/// Returns an entry for the given global column index. /// Returns an entry for the given global column index.
/// ///
/// Each call to this function incurs the cost of a binary search among the explicitly /// Each call to this function incurs the cost of a binary search among the explicitly
/// stored column entries. /// stored column entries.
#[inline]
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> { pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> {
get_entry_from_slices( self.lane.get_entry(global_col_index)
self.lane.minor_dim,
self.lane.minor_indices,
self.lane.values,
global_col_index)
} }
} }
} }
@ -476,8 +473,9 @@ impl_csr_row_common_methods!(CsrRowMut<'a, T>);
impl<'a, T> CsrRowMut<'a, T> { impl<'a, T> CsrRowMut<'a, T> {
/// Mutable access to the values corresponding to explicitly stored entries in this row. /// Mutable access to the values corresponding to explicitly stored entries in this row.
#[inline]
pub fn values_mut(&mut self) -> &mut [T] { pub fn values_mut(&mut self) -> &mut [T] {
self.lane.values self.lane.values_mut()
} }
/// Provides simultaneous access to column indices and mutable values corresponding to the /// Provides simultaneous access to column indices and mutable values corresponding to the
@ -485,16 +483,15 @@ impl<'a, T> CsrRowMut<'a, T> {
/// ///
/// This method primarily facilitates low-level access for methods that process data stored /// This method primarily facilitates low-level access for methods that process data stored
/// in CSR format directly. /// in CSR format directly.
#[inline]
pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) { pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
(self.lane.minor_indices, self.lane.values) self.lane.indices_and_values_mut()
} }
/// Returns a mutable entry for the given global column index. /// Returns a mutable entry for the given global column index.
#[inline]
pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<T>> { pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<T>> {
get_mut_entry_from_slices(self.lane.minor_dim, self.lane.get_entry_mut(global_col_index)
self.lane.minor_indices,
self.lane.values,
global_col_index)
} }
} }

View File

@ -90,7 +90,7 @@ pub mod pattern;
pub mod ops; pub mod ops;
pub mod convert; pub mod convert;
mod cs; pub(crate) mod cs;
#[cfg(feature = "proptest-support")] #[cfg(feature = "proptest-support")]
pub mod proptest; pub mod proptest;

View File

@ -1,96 +1,99 @@
use crate::csr::CsrMatrix; use crate::csr::CsrMatrix;
use crate::csc::CscMatrix;
use std::ops::{Add, Mul}; use std::ops::{Add, Mul};
use crate::ops::serial::{spadd_csr_prealloc, spadd_pattern, spmm_pattern, spmm_csr_prealloc}; use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern,
spmm_pattern, spmm_csr_prealloc, spmm_csc_prealloc};
use nalgebra::{ClosedAdd, ClosedMul, Scalar}; use nalgebra::{ClosedAdd, ClosedMul, Scalar};
use num_traits::{Zero, One}; use num_traits::{Zero, One};
use std::sync::Arc; use std::sync::Arc;
use crate::ops::{Op}; use crate::ops::{Op};
impl<'a, T> Add<&'a CsrMatrix<T>> for &'a CsrMatrix<T> /// Helper macro for implementing binary operators for different matrix types
where
// TODO: Consider introducing wrapper trait for these things? It's technically a "Ring",
// I guess...
T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
type Output = CsrMatrix<T>;
fn add(self, rhs: &'a CsrMatrix<T>) -> Self::Output {
let pattern = spadd_pattern(self.pattern(), rhs.pattern());
let values = vec![T::zero(); pattern.nnz()];
// We are giving data that is valid by definition, so it is safe to unwrap below
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
.unwrap();
spadd_csr_prealloc(T::zero(), &mut result, T::one(), Op::NoOp(&self)).unwrap();
spadd_csr_prealloc(T::one(), &mut result, T::one(), Op::NoOp(&rhs)).unwrap();
result
}
}
impl<'a, T> Add<&'a CsrMatrix<T>> for CsrMatrix<T>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
type Output = CsrMatrix<T>;
fn add(mut self, rhs: &'a CsrMatrix<T>) -> Self::Output {
if Arc::ptr_eq(self.pattern(), rhs.pattern()) {
spadd_csr_prealloc(T::one(), &mut self, T::one(), Op::NoOp(rhs)).unwrap();
self
} else {
&self + rhs
}
}
}
impl<'a, T> Add<CsrMatrix<T>> for &'a CsrMatrix<T>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
type Output = CsrMatrix<T>;
fn add(self, rhs: CsrMatrix<T>) -> Self::Output {
rhs + self
}
}
impl<T> Add<CsrMatrix<T>> for CsrMatrix<T>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
type Output = Self;
fn add(self, rhs: CsrMatrix<T>) -> Self::Output {
self + &rhs
}
}
/// Helper macro for implementing matrix multiplication for different matrix types
/// See below for usage. /// See below for usage.
macro_rules! impl_matrix_mul { macro_rules! impl_bin_op {
(<$($life:lifetime),*>($a_name:ident : $a:ty, $b_name:ident : $b:ty) -> $ret:ty $body:block) ($trait:ident, $method:ident,
<$($life:lifetime),*>($a:ident : $a_type:ty, $b:ident : $b_type:ty) -> $ret:ty $body:block)
=> =>
{ {
impl<$($life,)* T> Mul<$b> for $a impl<$($life,)* T> $trait<$b_type> for $a_type
where where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One
{ {
type Output = $ret; type Output = $ret;
fn mul(self, rhs: $b) -> Self::Output { fn $method(self, rhs: $b_type) -> Self::Output {
let $a_name = self; let $a = self;
let $b_name = rhs; let $b = rhs;
$body $body
} }
} }
} }
} }
impl_matrix_mul!(<'a>(a: &'a CsrMatrix<T>, b: &'a CsrMatrix<T>) -> CsrMatrix<T> { macro_rules! impl_add {
let pattern = spmm_pattern(a.pattern(), b.pattern()); ($($args:tt)*) => {
impl_bin_op!(Add, add, $($args)*);
}
}
/// Implements a + b for all combinations of reference and owned matrices, for
/// CsrMatrix or CscMatrix.
macro_rules! impl_spadd {
($matrix_type:ident, $spadd_fn:ident) => {
impl_add!(<'a>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
// If both matrices have the same pattern, then we can immediately re-use it
let pattern = if Arc::ptr_eq(a.pattern(), b.pattern()) {
Arc::clone(a.pattern())
} else {
Arc::new(spadd_pattern(a.pattern(), b.pattern()))
};
let values = vec![T::zero(); pattern.nnz()]; let values = vec![T::zero(); pattern.nnz()];
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values) // We are giving data that is valid by definition, so it is safe to unwrap below
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
.unwrap(); .unwrap();
spmm_csr_prealloc(T::zero(), $spadd_fn(T::zero(), &mut result, T::one(), Op::NoOp(&a)).unwrap();
$spadd_fn(T::one(), &mut result, T::one(), Op::NoOp(&b)).unwrap();
result
});
impl_add!(<'a>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
let mut a = a;
if Arc::ptr_eq(a.pattern(), b.pattern()) {
$spadd_fn(T::one(), &mut a, T::one(), Op::NoOp(b)).unwrap();
a
} else {
&a + b
}
});
impl_add!(<'a>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
b + a
});
impl_add!(<>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
a + &b
});
}
}
impl_spadd!(CsrMatrix, spadd_csr_prealloc);
impl_spadd!(CscMatrix, spadd_csc_prealloc);
macro_rules! impl_mul {
($($args:tt)*) => {
impl_bin_op!(Mul, mul, $($args)*);
}
}
/// Implements a + b for all combinations of reference and owned matrices, for
/// CsrMatrix or CscMatrix.
macro_rules! impl_spmm {
($matrix_type:ident, $pattern_fn:expr, $spmm_fn:expr) => {
impl_mul!(<'a>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
let pattern = $pattern_fn(a.pattern(), b.pattern());
let values = vec![T::zero(); pattern.nnz()];
let mut result = $matrix_type::try_from_pattern_and_values(Arc::new(pattern), values)
.unwrap();
$spmm_fn(T::zero(),
&mut result, &mut result,
T::one(), T::one(),
Op::NoOp(a), Op::NoOp(a),
@ -98,6 +101,12 @@ impl_matrix_mul!(<'a>(a: &'a CsrMatrix<T>, b: &'a CsrMatrix<T>) -> CsrMatrix<T>
.expect("Internal error: spmm failed (please debug)."); .expect("Internal error: spmm failed (please debug).");
result result
}); });
impl_matrix_mul!(<'a>(a: &'a CsrMatrix<T>, b: CsrMatrix<T>) -> CsrMatrix<T> { a * &b}); impl_mul!(<'a>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { a * &b});
impl_matrix_mul!(<'a>(a: CsrMatrix<T>, b: &'a CsrMatrix<T>) -> CsrMatrix<T> { &a * b}); impl_mul!(<'a>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> { &a * b});
impl_matrix_mul!(<>(a: CsrMatrix<T>, b: CsrMatrix<T>) -> CsrMatrix<T> { &a * &b}); impl_mul!(<>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { &a * &b});
}
}
impl_spmm!(CsrMatrix, spmm_pattern, spmm_csr_prealloc);
// Need to switch order of operations for CSC pattern
impl_spmm!(CscMatrix, |a, b| spmm_pattern(b, a), spmm_csc_prealloc);

View File

@ -48,6 +48,17 @@ impl<T> Op<T> {
Op::NoOp(obj) | Op::Transpose(obj) => obj, Op::NoOp(obj) | Op::Transpose(obj) => obj,
} }
} }
/// Applies the transpose operation.
///
/// This operation follows the usual semantics of transposition. In particular, double
/// transposition is equivalent to no transposition.
pub fn transposed(self) -> Self {
match self {
Op::NoOp(obj) => Op::Transpose(obj),
Op::Transpose(obj) => Op::NoOp(obj)
}
}
} }
impl<T> From<T> for Op<T> { impl<T> From<T> for Op<T> {

View File

@ -0,0 +1,194 @@
use crate::cs::CsMatrix;
use crate::ops::Op;
use crate::ops::serial::{OperationErrorType, OperationError};
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice};
use num_traits::{Zero, One};
use crate::SparseEntryMut;
use std::sync::Arc;
fn spmm_cs_unexpected_entry() -> OperationError {
OperationError::from_type_and_message(
OperationErrorType::InvalidPattern,
String::from("Found unexpected entry that is not present in `c`."))
}
/// Helper functionality for implementing CSR/CSC SPMM.
///
/// Since CSR/CSC matrices are basically transpositions of each other, which lets us use the same
/// algorithm for the SPMM implementation. The implementation here is written in a CSR-centric
/// manner. This means that when using it for CSC, the order of the matrices needs to be
/// reversed (since transpose(AB) = transpose(B) * transpose(A) and CSC(A) = transpose(CSR(A)).
///
/// We assume here that the matrices have already been verified to be dimensionally compatible.
pub fn spmm_cs_prealloc<T>(
beta: T,
c: &mut CsMatrix<T>,
alpha: T,
a: &CsMatrix<T>,
b: &CsMatrix<T>)
-> Result<(), OperationError>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
for i in 0 .. c.pattern().major_dim() {
let a_lane_i = a.get_lane(i).unwrap();
let mut c_lane_i = c.get_lane_mut(i).unwrap();
for c_ij in c_lane_i.values_mut() {
*c_ij = beta.inlined_clone() * c_ij.inlined_clone();
}
for (&k, a_ik) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) {
let b_lane_k = b.get_lane(k).unwrap();
let (mut c_lane_i_cols, mut c_lane_i_values) = c_lane_i.indices_and_values_mut();
let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone();
for (j, b_kj) in b_lane_k.minor_indices().iter().zip(b_lane_k.values()) {
// Determine the location in C to append the value
let (c_local_idx, _) = c_lane_i_cols.iter()
.enumerate()
.find(|(_, c_col)| *c_col == j)
.ok_or_else(spmm_cs_unexpected_entry)?;
c_lane_i_values[c_local_idx] += alpha_aik.inlined_clone() * b_kj.inlined_clone();
c_lane_i_cols = &c_lane_i_cols[c_local_idx ..];
c_lane_i_values = &mut c_lane_i_values[c_local_idx ..];
}
}
}
Ok(())
}
fn spadd_cs_unexpected_entry() -> OperationError {
OperationError::from_type_and_message(
OperationErrorType::InvalidPattern,
String::from("Found entry in `op(a)` that is not present in `c`."))
}
/// Helper functionality for implementing CSR/CSC SPADD.
pub fn spadd_cs_prealloc<T>(beta: T,
c: &mut CsMatrix<T>,
alpha: T,
a: Op<&CsMatrix<T>>)
-> Result<(), OperationError>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
if Arc::ptr_eq(&c.pattern(), &a.inner_ref().pattern()) {
// Special fast path: The two matrices have *exactly* the same sparsity pattern,
// so we only need to sum the value arrays
// TODO: Test this fast path
for (c_ij, a_ij) in c.values_mut().iter_mut().zip(a.inner_ref().values()) {
let (alpha, beta) = (alpha.inlined_clone(), beta.inlined_clone());
*c_ij = beta * c_ij.inlined_clone() + alpha * a_ij.inlined_clone();
}
Ok(())
} else {
match a {
Op::NoOp(a) => {
for (mut c_lane_i, a_lane_i) in c.lane_iter_mut().zip(a.lane_iter()) {
if beta != T::one() {
for c_ij in c_lane_i.values_mut() {
*c_ij *= beta.inlined_clone();
}
}
let (mut c_minors, mut c_vals) = c_lane_i.indices_and_values_mut();
let (a_minors, a_vals) = (a_lane_i.minor_indices(), a_lane_i.values());
for (a_col, a_val) in a_minors.iter().zip(a_vals) {
// TODO: Use exponential search instead of linear search.
// If C has substantially more entries in the row than A, then a line search
// will needlessly visit many entries in C.
let (c_idx, _) = c_minors.iter()
.enumerate()
.find(|(_, c_col)| *c_col == a_col)
.ok_or_else(spadd_cs_unexpected_entry)?;
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
c_minors = &c_minors[c_idx ..];
c_vals = &mut c_vals[c_idx ..];
}
}
}
Op::Transpose(a) => {
if beta != T::one() {
for c_ij in c.values_mut() {
*c_ij *= beta.inlined_clone();
}
}
for (i, a_lane_i) in a.lane_iter().enumerate() {
for (&j, a_val) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) {
let a_val = a_val.inlined_clone();
let alpha = alpha.inlined_clone();
match c.get_entry_mut(j, i).unwrap() {
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val }
SparseEntryMut::Zero => return Err(spadd_cs_unexpected_entry()),
}
}
}
}
}
Ok(())
}
}
/// Helper functionality for implementing CSR/CSC SPMM.
///
/// The implementation essentially assumes that `a` is a CSR matrix. To use it with CSC matrices,
/// the transposed operation must be specified for the CSC matrix.
pub fn spmm_cs_dense<T>(beta: T,
mut c: DMatrixSliceMut<T>,
alpha: T,
a: Op<&CsMatrix<T>>,
b: Op<DMatrixSlice<T>>)
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
match a {
Op::NoOp(a) => {
for j in 0..c.ncols() {
let mut c_col_j = c.column_mut(j);
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.lane_iter()) {
let mut dot_ij = T::zero();
for (&k, a_ik) in a_row_i.minor_indices().iter().zip(a_row_i.values()) {
let b_contrib =
match b {
Op::NoOp(ref b) => b.index((k, j)),
Op::Transpose(ref b) => b.index((j, k))
};
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
}
*c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij;
}
}
},
Op::Transpose(a) => {
// In this case, we have to pre-multiply C by beta
c *= beta;
for k in 0..a.pattern().major_dim() {
let a_row_k = a.get_lane(k).unwrap();
for (&i, a_ki) in a_row_k.minor_indices().iter().zip(a_row_k.values()) {
let gamma_ki = alpha.inlined_clone() * a_ki.inlined_clone();
let mut c_row_i = c.row_mut(i);
match b {
Op::NoOp(ref b) => {
let b_row_k = b.row(k);
for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) {
*c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone();
}
},
Op::Transpose(ref b) => {
let b_col_k = b.column(k);
for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) {
*c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone();
}
},
}
}
}
},
}
}

View File

@ -0,0 +1,92 @@
use crate::csc::CscMatrix;
use crate::ops::Op;
use crate::ops::serial::cs::{spmm_cs_prealloc, spmm_cs_dense, spadd_cs_prealloc};
use crate::ops::serial::OperationError;
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice};
use num_traits::{Zero, One};
use std::borrow::Cow;
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
pub fn spmm_csc_dense<'a, T>(beta: T,
c: impl Into<DMatrixSliceMut<'a, T>>,
alpha: T,
a: Op<&CscMatrix<T>>,
b: Op<impl Into<DMatrixSlice<'a, T>>>)
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
let b = b.convert();
spmm_csc_dense_(beta, c.into(), alpha, a, b)
}
fn spmm_csc_dense_<T>(beta: T,
c: DMatrixSliceMut<T>,
alpha: T,
a: Op<&CscMatrix<T>>,
b: Op<DMatrixSlice<T>>)
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
assert_compatible_spmm_dims!(c, a, b);
// Need to interpret matrix as transposed since the spmm_cs_dense function assumes CSR layout
let a = a.transposed().map_same_op(|a| &a.cs);
spmm_cs_dense(beta, c, alpha, a, b)
}
/// Sparse matrix addition `C <- beta * C + alpha * op(A)`.
///
/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is
/// returned.
pub fn spadd_csc_prealloc<T>(beta: T,
c: &mut CscMatrix<T>,
alpha: T,
a: Op<&CscMatrix<T>>)
-> Result<(), OperationError>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
assert_compatible_spadd_dims!(c, a);
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
}
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
pub fn spmm_csc_prealloc<T>(
beta: T,
c: &mut CscMatrix<T>,
alpha: T,
a: Op<&CscMatrix<T>>,
b: Op<&CscMatrix<T>>)
-> Result<(), OperationError>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
assert_compatible_spmm_dims!(c, a, b);
use Op::{NoOp, Transpose};
match (&a, &b) {
(NoOp(ref a), NoOp(ref b)) => {
// Note: We have to reverse the order for CSC matrices
spmm_cs_prealloc(beta, &mut c.cs, alpha, &b.cs, &a.cs)
},
_ => {
// Currently we handle transposition by explicitly precomputing transposed matrices
// and calling the operation again without transposition
let a_ref: &CscMatrix<T> = a.inner_ref();
let b_ref: &CscMatrix<T> = b.inner_ref();
let (a, b) = {
use Cow::*;
match (&a, &b) {
(NoOp(_), NoOp(_)) => unreachable!(),
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
(Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose()))
}
};
spmm_csc_prealloc(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref()))
}
}
}

View File

@ -1,11 +1,10 @@
use crate::csr::CsrMatrix; use crate::csr::CsrMatrix;
use crate::ops::{Op}; use crate::ops::{Op};
use crate::SparseEntryMut; use crate::ops::serial::{OperationError};
use crate::ops::serial::{OperationError, OperationErrorType};
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut}; use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
use num_traits::{Zero, One}; use num_traits::{Zero, One};
use std::sync::Arc;
use std::borrow::Cow; use std::borrow::Cow;
use crate::ops::serial::cs::{spmm_cs_prealloc, spmm_cs_dense, spadd_cs_prealloc};
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`. /// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
pub fn spmm_csr_dense<'a, T>(beta: T, pub fn spmm_csr_dense<'a, T>(beta: T,
@ -21,7 +20,7 @@ pub fn spmm_csr_dense<'a, T>(beta: T,
} }
fn spmm_csr_dense_<T>(beta: T, fn spmm_csr_dense_<T>(beta: T,
mut c: DMatrixSliceMut<T>, c: DMatrixSliceMut<T>,
alpha: T, alpha: T,
a: Op<&CsrMatrix<T>>, a: Op<&CsrMatrix<T>>,
b: Op<DMatrixSlice<T>>) b: Op<DMatrixSlice<T>>)
@ -29,58 +28,7 @@ where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One
{ {
assert_compatible_spmm_dims!(c, a, b); assert_compatible_spmm_dims!(c, a, b);
spmm_cs_dense(beta, c, alpha, a.map_same_op(|a| &a.cs), b)
match a {
Op::NoOp(ref a) => {
for j in 0..c.ncols() {
let mut c_col_j = c.column_mut(j);
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.row_iter()) {
let mut dot_ij = T::zero();
for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
let b_contrib =
match b {
Op::NoOp(ref b) => b.index((k, j)),
Op::Transpose(ref b) => b.index((j, k))
};
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
}
*c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij;
}
}
},
Op::Transpose(ref a) => {
// In this case, we have to pre-multiply C by beta
c *= beta;
for k in 0..a.nrows() {
let a_row_k = a.row(k);
for (&i, a_ki) in a_row_k.col_indices().iter().zip(a_row_k.values()) {
let gamma_ki = alpha.inlined_clone() * a_ki.inlined_clone();
let mut c_row_i = c.row_mut(i);
match b {
Op::NoOp(ref b) => {
let b_row_k = b.row(k);
for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) {
*c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone();
}
},
Op::Transpose(ref b) => {
let b_col_k = b.column(k);
for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) {
*c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone();
}
},
}
}
}
},
}
}
fn spadd_csr_unexpected_entry() -> OperationError {
OperationError::from_type_and_message(
OperationErrorType::InvalidPattern,
String::from("Found entry in `a` that is not present in `c`."))
} }
/// Sparse matrix addition `C <- beta * C + alpha * op(A)`. /// Sparse matrix addition `C <- beta * C + alpha * op(A)`.
@ -96,70 +44,7 @@ where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One
{ {
assert_compatible_spadd_dims!(c, a); assert_compatible_spadd_dims!(c, a);
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
// TODO: Change CsrMatrix::pattern() to return `&Arc` instead of `Arc`
if Arc::ptr_eq(&c.pattern(), &a.inner_ref().pattern()) {
// Special fast path: The two matrices have *exactly* the same sparsity pattern,
// so we only need to sum the value arrays
for (c_ij, a_ij) in c.values_mut().iter_mut().zip(a.inner_ref().values()) {
let (alpha, beta) = (alpha.inlined_clone(), beta.inlined_clone());
*c_ij = beta * c_ij.inlined_clone() + alpha * a_ij.inlined_clone();
}
Ok(())
} else {
match a {
Op::NoOp(a) => {
for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) {
if beta != T::one() {
for c_ij in c_row_i.values_mut() {
*c_ij *= beta.inlined_clone();
}
}
let (mut c_cols, mut c_vals) = c_row_i.cols_and_values_mut();
let (a_cols, a_vals) = (a_row_i.col_indices(), a_row_i.values());
for (a_col, a_val) in a_cols.iter().zip(a_vals) {
// TODO: Use exponential search instead of linear search.
// If C has substantially more entries in the row than A, then a line search
// will needlessly visit many entries in C.
let (c_idx, _) = c_cols.iter()
.enumerate()
.find(|(_, c_col)| *c_col == a_col)
.ok_or_else(spadd_csr_unexpected_entry)?;
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
c_cols = &c_cols[c_idx ..];
c_vals = &mut c_vals[c_idx ..];
}
}
}
Op::Transpose(a) => {
if beta != T::one() {
for c_ij in c.values_mut() {
*c_ij *= beta.inlined_clone();
}
}
for (i, a_row_i) in a.row_iter().enumerate() {
for (&j, a_val) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
let a_val = a_val.inlined_clone();
let alpha = alpha.inlined_clone();
match c.index_entry_mut(j, i) {
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val }
SparseEntryMut::Zero => return Err(spadd_csr_unexpected_entry()),
}
}
}
}
}
Ok(())
}
}
fn spmm_csr_unexpected_entry() -> OperationError {
OperationError::from_type_and_message(
OperationErrorType::InvalidPattern,
String::from("Found unexpected entry that is not present in `c`."))
} }
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`. /// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
@ -179,29 +64,7 @@ where
match (&a, &b) { match (&a, &b) {
(NoOp(ref a), NoOp(ref b)) => { (NoOp(ref a), NoOp(ref b)) => {
for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) { spmm_cs_prealloc(beta, &mut c.cs, alpha, &a.cs, &b.cs)
for c_ij in c_row_i.values_mut() {
*c_ij = beta.inlined_clone() * c_ij.inlined_clone();
}
for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
let b_row_k = b.row(k);
let (mut c_row_i_cols, mut c_row_i_values) = c_row_i.cols_and_values_mut();
let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone();
for (j, b_kj) in b_row_k.col_indices().iter().zip(b_row_k.values()) {
// Determine the location in C to append the value
let (c_local_idx, _) = c_row_i_cols.iter()
.enumerate()
.find(|(_, c_col)| *c_col == j)
.ok_or_else(spmm_csr_unexpected_entry)?;
c_row_i_values[c_local_idx] += alpha_aik.inlined_clone() * b_kj.inlined_clone();
c_row_i_cols = &c_row_i_cols[c_local_idx ..];
c_row_i_values = &mut c_row_i_values[c_local_idx ..];
}
}
}
Ok(())
}, },
_ => { _ => {
// Currently we handle transposition by explicitly precomputing transposed matrices // Currently we handle transposition by explicitly precomputing transposed matrices

View File

@ -49,9 +49,12 @@ macro_rules! assert_compatible_spadd_dims {
} }
} }
mod csc;
mod csr; mod csr;
mod pattern; mod pattern;
mod cs;
pub use csc::*;
pub use csr::*; pub use csr::*;
pub use pattern::*; pub use pattern::*;

View File

@ -30,9 +30,9 @@ pub const PROPTEST_MAX_NNZ: usize = 40;
pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5 ..= 5; pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5 ..= 5;
pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> { pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
csr(-5 ..= 5, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ) csr(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
} }
pub fn csc_strategy() -> impl Strategy<Value=CscMatrix<i32>> { pub fn csc_strategy() -> impl Strategy<Value=CscMatrix<i32>> {
csc(-5 ..= 5, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ) csc(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
} }

View File

@ -1,9 +1,12 @@
use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ, use crate::common::{csc_strategy, csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
PROPTEST_I32_VALUE_STRATEGY}; PROPTEST_I32_VALUE_STRATEGY};
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_pattern, spmm_pattern, spadd_csr_prealloc, spmm_csr_prealloc}; use nalgebra_sparse::ops::serial::{spmm_csr_dense, spmm_csc_dense, spadd_pattern, spmm_pattern,
spadd_csr_prealloc, spadd_csc_prealloc,
spmm_csr_prealloc, spmm_csc_prealloc};
use nalgebra_sparse::ops::{Op}; use nalgebra_sparse::ops::{Op};
use nalgebra_sparse::csr::CsrMatrix; use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::proptest::{csr, sparsity_pattern}; use nalgebra_sparse::csc::CscMatrix;
use nalgebra_sparse::proptest::{csc, csr, sparsity_pattern};
use nalgebra_sparse::pattern::SparsityPattern; use nalgebra_sparse::pattern::SparsityPattern;
use nalgebra::{DMatrix, Scalar, DMatrixSliceMut, DMatrixSlice}; use nalgebra::{DMatrix, Scalar, DMatrixSliceMut, DMatrixSlice};
@ -23,6 +26,15 @@ fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
DMatrix::from(&boolean_csr) DMatrix::from(&boolean_csr)
} }
/// Represents the sparsity pattern of a CSC matrix as a dense matrix with 0/1
fn dense_csc_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
let boolean_csc = CscMatrix::try_from_pattern_and_values(
Arc::new(pattern.clone()),
vec![1; pattern.nnz()])
.unwrap();
DMatrix::from(&boolean_csc)
}
#[derive(Debug)] #[derive(Debug)]
struct SpmmCsrDenseArgs<T: Scalar> { struct SpmmCsrDenseArgs<T: Scalar> {
c: DMatrix<T>, c: DMatrix<T>,
@ -32,6 +44,15 @@ struct SpmmCsrDenseArgs<T: Scalar> {
b: Op<DMatrix<T>>, b: Op<DMatrix<T>>,
} }
#[derive(Debug)]
struct SpmmCscDenseArgs<T: Scalar> {
c: DMatrix<T>,
beta: T,
alpha: T,
a: Op<CscMatrix<T>>,
b: Op<DMatrix<T>>,
}
/// Returns matrices C, A and B with compatible dimensions such that it can be used /// Returns matrices C, A and B with compatible dimensions such that it can be used
/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`. /// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`.
fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>> { fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>> {
@ -70,6 +91,21 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
}) })
} }
/// Returns matrices C, A and B with compatible dimensions such that it can be used
/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`.
fn spmm_csc_dense_args_strategy() -> impl Strategy<Value=SpmmCscDenseArgs<i32>> {
spmm_csr_dense_args_strategy()
.prop_map(|args| {
SpmmCscDenseArgs {
c: args.c,
beta: args.beta,
alpha: args.alpha,
a: args.a.map_same_op(|a| CscMatrix::from(&a)),
b: args.b
}
})
}
#[derive(Debug)] #[derive(Debug)]
struct SpaddCsrArgs<T> { struct SpaddCsrArgs<T> {
c: CsrMatrix<T>, c: CsrMatrix<T>,
@ -78,6 +114,14 @@ struct SpaddCsrArgs<T> {
a: Op<CsrMatrix<T>>, a: Op<CsrMatrix<T>>,
} }
#[derive(Debug)]
struct SpaddCscArgs<T> {
c: CscMatrix<T>,
beta: T,
alpha: T,
a: Op<CscMatrix<T>>,
}
fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> { fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
let value_strategy = PROPTEST_I32_VALUE_STRATEGY; let value_strategy = PROPTEST_I32_VALUE_STRATEGY;
@ -99,6 +143,16 @@ fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>>
}) })
} }
fn spadd_csc_prealloc_args_strategy() -> impl Strategy<Value=SpaddCscArgs<i32>> {
spadd_csr_prealloc_args_strategy()
.prop_map(|args| SpaddCscArgs {
c: CscMatrix::from(&args.c),
beta: args.beta,
alpha: args.alpha,
a: args.a.map_same_op(|a| CscMatrix::from(&a))
})
}
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> { fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
matrix(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM) matrix(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM)
} }
@ -150,6 +204,15 @@ struct SpmmCsrArgs<T> {
b: Op<CsrMatrix<T>>, b: Op<CsrMatrix<T>>,
} }
#[derive(Debug)]
struct SpmmCscArgs<T> {
c: CscMatrix<T>,
beta: T,
alpha: T,
a: Op<CscMatrix<T>>,
b: Op<CscMatrix<T>>,
}
fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> { fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
spmm_pattern_strategy() spmm_pattern_strategy()
.prop_flat_map(|(a_pattern, b_pattern)| { .prop_flat_map(|(a_pattern, b_pattern)| {
@ -181,6 +244,21 @@ fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
}) })
} }
fn spmm_csc_prealloc_args_strategy() -> impl Strategy<Value=SpmmCscArgs<i32>> {
// Note: Converting from CSR is simple, but might be significantly slower than
// writing a common implementation that can be shared between CSR and CSC args
spmm_csr_prealloc_args_strategy()
.prop_map(|args| {
SpmmCscArgs {
c: CscMatrix::from(&args.c),
beta: args.beta,
alpha: args.alpha,
a: args.a.map_same_op(|a| CscMatrix::from(&a)),
b: args.b.map_same_op(|b| CscMatrix::from(&b))
}
})
}
/// Helper function to help us call dense GEMM with our `Op` type /// Helper function to help us call dense GEMM with our `Op` type
fn dense_gemm<'a>(beta: i32, fn dense_gemm<'a>(beta: i32,
c: impl Into<DMatrixSliceMut<'a, i32>>, c: impl Into<DMatrixSliceMut<'a, i32>>,
@ -310,7 +388,7 @@ proptest! {
(a, b) (a, b)
in csr_strategy() in csr_strategy()
.prop_flat_map(|a| { .prop_flat_map(|a| {
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), 40); let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
(Just(a), b) (Just(a), b)
})) }))
{ {
@ -500,4 +578,263 @@ proptest! {
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense); prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern); prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
} }
#[test]
fn spmm_csc_prealloc_test(SpmmCscArgs { c, beta, alpha, a, b }
in spmm_csc_prealloc_args_strategy()
) {
// Test that we get the expected result by comparing to an equivalent dense operation
// (here we give in the C matrix, so the sparsity pattern is essentially fixed)
let mut c_sparse = c.clone();
spmm_csc_prealloc(beta, &mut c_sparse, alpha, a.as_ref(), b.as_ref()).unwrap();
let mut c_dense = DMatrix::from(&c);
let op_a_dense = match a {
Op::NoOp(ref a) => DMatrix::from(a),
Op::Transpose(ref a) => DMatrix::from(a).transpose(),
};
let op_b_dense = match b {
Op::NoOp(ref b) => DMatrix::from(b),
Op::Transpose(ref b) => DMatrix::from(b).transpose(),
};
c_dense = beta * c_dense + alpha * &op_a_dense * op_b_dense;
prop_assert_eq!(&DMatrix::from(&c_sparse), &c_dense);
}
#[test]
fn spmm_csc_prealloc_panics_on_dim_mismatch(
(alpha, beta, c, a, b)
in (PROPTEST_I32_VALUE_STRATEGY,
PROPTEST_I32_VALUE_STRATEGY,
csc_strategy(),
op_strategy(csc_strategy()),
op_strategy(csc_strategy()))
) {
// We refer to `A * B` as the "product"
let product_rows = match &a {
Op::NoOp(ref a) => a.nrows(),
Op::Transpose(ref a) => a.ncols(),
};
let product_cols = match &b {
Op::NoOp(ref b) => b.ncols(),
Op::Transpose(ref b) => b.nrows(),
};
// Determine the common dimension in the product
// from the perspective of a and b, respectively
let product_a_common = match &a {
Op::NoOp(ref a) => a.ncols(),
Op::Transpose(ref a) => a.nrows(),
};
let product_b_common = match &b {
Op::NoOp(ref b) => b.nrows(),
Op::Transpose(ref b) => b.ncols(),
};
let dims_are_compatible = product_rows == c.nrows()
&& product_cols == c.ncols()
&& product_a_common == product_b_common;
// If the dimensions randomly happen to be compatible, then of course we need to
// skip the test, so we assume that they are not.
prop_assume!(!dims_are_compatible);
let result = catch_unwind(|| {
let mut spmm_result = c.clone();
spmm_csc_prealloc(beta, &mut spmm_result, alpha, a.as_ref(), b.as_ref()).unwrap();
});
prop_assert!(result.is_err(),
"The SPMM kernel executed successfully despite mismatch dimensions");
}
#[test]
fn csc_mul_csc(
// a and b have dimensions compatible for multiplication
(a, b)
in csc_strategy()
.prop_flat_map(|a| {
let max_nnz = PROPTEST_MAX_NNZ;
let cols = PROPTEST_MATRIX_DIM;
let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.ncols()), cols, max_nnz);
(Just(a), b)
})
.prop_map(|(a, b)| {
println!("a: {} x {}, b: {} x {}", a.nrows(), a.ncols(), b.nrows(), b.ncols());
(a, b)
}))
{
assert_eq!(a.ncols(), b.nrows());
// We use the dense result as the ground truth for the arithmetic result
let c_dense = DMatrix::from(&a) * DMatrix::from(&b);
// However, it's not enough only to cover the dense result, we also need to verify the
// sparsity pattern. We can determine the exact sparsity pattern by using
// dense arithmetic with positive integer values and extracting positive entries.
let c_dense_pattern = dense_csc_pattern(a.pattern()) * dense_csc_pattern(b.pattern());
let c_pattern = CscMatrix::from(&c_dense_pattern).pattern().clone();
// Check each combination of owned matrices and references
let c_owned_owned = a.clone() * b.clone();
prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense);
prop_assert_eq!(c_owned_owned.pattern(), &c_pattern);
let c_owned_ref = a.clone() * &b;
prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense);
prop_assert_eq!(c_owned_ref.pattern(), &c_pattern);
let c_ref_owned = &a * b.clone();
prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense);
prop_assert_eq!(c_ref_owned.pattern(), &c_pattern);
let c_ref_ref = &a * &b;
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
}
#[test]
fn spmm_csc_dense_agrees_with_dense_result(
SpmmCscDenseArgs { c, beta, alpha, a, b }
in spmm_csc_dense_args_strategy()
) {
let mut spmm_result = c.clone();
spmm_csc_dense(beta, &mut spmm_result, alpha, a.as_ref(), b.as_ref());
let mut gemm_result = c.clone();
let a_dense = a.map_same_op(|a| DMatrix::from(&a));
dense_gemm(beta, &mut gemm_result, alpha, a_dense.as_ref(), b.as_ref());
prop_assert_eq!(spmm_result, gemm_result);
}
#[test]
fn spmm_csc_dense_panics_on_dim_mismatch(
(alpha, beta, c, a, b)
in (PROPTEST_I32_VALUE_STRATEGY,
PROPTEST_I32_VALUE_STRATEGY,
dense_strategy(),
op_strategy(csc_strategy()),
op_strategy(dense_strategy()))
) {
// We refer to `A * B` as the "product"
let product_rows = match &a {
Op::NoOp(ref a) => a.nrows(),
Op::Transpose(ref a) => a.ncols(),
};
let product_cols = match &b {
Op::NoOp(ref b) => b.ncols(),
Op::Transpose(ref b) => b.nrows(),
};
// Determine the common dimension in the product
// from the perspective of a and b, respectively
let product_a_common = match &a {
Op::NoOp(ref a) => a.ncols(),
Op::Transpose(ref a) => a.nrows(),
};
let product_b_common = match &b {
Op::NoOp(ref b) => b.nrows(),
Op::Transpose(ref b) => b.ncols()
};
let dims_are_compatible = product_rows == c.nrows()
&& product_cols == c.ncols()
&& product_a_common == product_b_common;
// If the dimensions randomly happen to be compatible, then of course we need to
// skip the test, so we assume that they are not.
prop_assume!(!dims_are_compatible);
let result = catch_unwind(|| {
let mut spmm_result = c.clone();
spmm_csc_dense(beta, &mut spmm_result, alpha, a.as_ref(), b.as_ref());
});
prop_assert!(result.is_err(),
"The SPMM kernel executed successfully despite mismatch dimensions");
}
#[test]
fn spadd_csc_prealloc_test(SpaddCscArgs { c, beta, alpha, a } in spadd_csc_prealloc_args_strategy()) {
// Test that we get the expected result by comparing to an equivalent dense operation
// (here we give in the C matrix, so the sparsity pattern is essentially fixed)
let mut c_sparse = c.clone();
spadd_csc_prealloc(beta, &mut c_sparse, alpha, a.as_ref()).unwrap();
let mut c_dense = DMatrix::from(&c);
let op_a_dense = match a {
Op::NoOp(a) => DMatrix::from(&a),
Op::Transpose(a) => DMatrix::from(&a).transpose(),
};
c_dense = beta * c_dense + alpha * &op_a_dense;
prop_assert_eq!(&DMatrix::from(&c_sparse), &c_dense);
}
#[test]
fn spadd_csc_prealloc_panics_on_dim_mismatch(
(alpha, beta, c, op_a)
in (PROPTEST_I32_VALUE_STRATEGY,
PROPTEST_I32_VALUE_STRATEGY,
csc_strategy(),
op_strategy(csc_strategy()))
) {
let op_a_rows = match &op_a {
&Op::NoOp(ref a) => a.nrows(),
&Op::Transpose(ref a) => a.ncols()
};
let op_a_cols = match &op_a {
&Op::NoOp(ref a) => a.ncols(),
&Op::Transpose(ref a) => a.nrows()
};
let dims_are_compatible = c.nrows() == op_a_rows && c.ncols() == op_a_cols;
// If the dimensions randomly happen to be compatible, then of course we need to
// skip the test, so we assume that they are not.
prop_assume!(!dims_are_compatible);
let result = catch_unwind(|| {
let mut spmm_result = c.clone();
spadd_csc_prealloc(beta, &mut spmm_result, alpha, op_a.as_ref()).unwrap();
});
prop_assert!(result.is_err(),
"The SPMM kernel executed successfully despite mismatch dimensions");
}
#[test]
fn csc_add_csc(
// a and b have the same dimensions
(a, b)
in csc_strategy()
.prop_flat_map(|a| {
let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
(Just(a), b)
}))
{
// We use the dense result as the ground truth for the arithmetic result
let c_dense = DMatrix::from(&a) + DMatrix::from(&b);
// However, it's not enough only to cover the dense result, we also need to verify the
// sparsity pattern. We can determine the exact sparsity pattern by using
// dense arithmetic with positive integer values and extracting positive entries.
let c_dense_pattern = dense_csc_pattern(a.pattern()) + dense_csc_pattern(b.pattern());
let c_pattern = CscMatrix::from(&c_dense_pattern).pattern().clone();
// Check each combination of owned matrices and references
let c_owned_owned = a.clone() + b.clone();
prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense);
prop_assert_eq!(c_owned_owned.pattern(), &c_pattern);
let c_owned_ref = a.clone() + &b;
prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense);
prop_assert_eq!(c_owned_ref.pattern(), &c_pattern);
let c_ref_owned = &a + b.clone();
prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense);
prop_assert_eq!(c_ref_owned.pattern(), &c_pattern);
let c_ref_ref = &a + &b;
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
}
} }