Implement arithmetic operations for CSC matrices
This commit is contained in:
parent
6a1d12705f
commit
dbdf5567fc
|
@ -144,9 +144,19 @@ impl<T> CsMatrix<T> {
|
||||||
values: &mut values[range]
|
values: &mut values[range]
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn lane_iter(&self) -> CsLaneIter<T> {
|
||||||
|
CsLaneIter::new(self.pattern().as_ref(), self.values())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn lane_iter_mut(&mut self) -> CsLaneIterMut<T> {
|
||||||
|
CsLaneIterMut::new(self.sparsity_pattern.as_ref(), &mut self.values)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_entry_from_slices<'a, T>(
|
fn get_entry_from_slices<'a, T>(
|
||||||
minor_dim: usize,
|
minor_dim: usize,
|
||||||
minor_indices: &'a [usize],
|
minor_indices: &'a [usize],
|
||||||
values: &'a [T],
|
values: &'a [T],
|
||||||
|
@ -161,7 +171,7 @@ pub fn get_entry_from_slices<'a, T>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_mut_entry_from_slices<'a, T>(
|
fn get_mut_entry_from_slices<'a, T>(
|
||||||
minor_dim: usize,
|
minor_dim: usize,
|
||||||
minor_indices: &'a [usize],
|
minor_indices: &'a [usize],
|
||||||
values: &'a mut [T],
|
values: &'a mut [T],
|
||||||
|
@ -178,16 +188,16 @@ pub fn get_mut_entry_from_slices<'a, T>(
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct CsLane<'a, T> {
|
pub struct CsLane<'a, T> {
|
||||||
pub minor_dim: usize,
|
minor_dim: usize,
|
||||||
pub minor_indices: &'a [usize],
|
minor_indices: &'a [usize],
|
||||||
pub values: &'a [T]
|
values: &'a [T]
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
pub struct CsLaneMut<'a, T> {
|
pub struct CsLaneMut<'a, T> {
|
||||||
pub minor_dim: usize,
|
minor_dim: usize,
|
||||||
pub minor_indices: &'a [usize],
|
minor_indices: &'a [usize],
|
||||||
pub values: &'a mut [T]
|
values: &'a mut [T]
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct CsLaneIter<'a, T> {
|
pub struct CsLaneIter<'a, T> {
|
||||||
|
@ -280,4 +290,59 @@ impl<'a, T> Iterator for CsLaneIterMut<'a, T>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Implement the methods common to both CsLane and CsLaneMut. See the documentation for the
|
||||||
|
/// methods delegated here by CsrMatrix and CscMatrix members for more information.
|
||||||
|
macro_rules! impl_cs_lane_common_methods {
|
||||||
|
($name:ty) => {
|
||||||
|
impl<'a, T> $name {
|
||||||
|
#[inline]
|
||||||
|
pub fn minor_dim(&self) -> usize {
|
||||||
|
self.minor_dim
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn nnz(&self) -> usize {
|
||||||
|
self.minor_indices.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn minor_indices(&self) -> &[usize] {
|
||||||
|
self.minor_indices
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn values(&self) -> &[T] {
|
||||||
|
self.values
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> {
|
||||||
|
get_entry_from_slices(
|
||||||
|
self.minor_dim,
|
||||||
|
self.minor_indices,
|
||||||
|
self.values,
|
||||||
|
global_col_index)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_cs_lane_common_methods!(CsLane<'a, T>);
|
||||||
|
impl_cs_lane_common_methods!(CsLaneMut<'a, T>);
|
||||||
|
|
||||||
|
impl<'a, T> CsLaneMut<'a, T> {
|
||||||
|
pub fn values_mut(&mut self) -> &mut [T] {
|
||||||
|
self.values
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn indices_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
||||||
|
(self.minor_indices, self.values)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_entry_mut(&mut self, global_minor_index: usize) -> Option<SparseEntryMut<T>> {
|
||||||
|
get_mut_entry_from_slices(self.minor_dim,
|
||||||
|
self.minor_indices,
|
||||||
|
self.values,
|
||||||
|
global_minor_index)
|
||||||
|
}
|
||||||
|
}
|
|
@ -3,8 +3,7 @@
|
||||||
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
|
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
|
||||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut,
|
use crate::cs::{CsMatrix, CsLane, CsLaneMut, CsLaneIter, CsLaneIterMut};
|
||||||
get_entry_from_slices, get_mut_entry_from_slices};
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::slice::{IterMut, Iter};
|
use std::slice::{IterMut, Iter};
|
||||||
|
@ -21,7 +20,7 @@ use nalgebra::Scalar;
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct CscMatrix<T> {
|
pub struct CscMatrix<T> {
|
||||||
// Cols are major, rows are minor in the sparsity pattern
|
// Cols are major, rows are minor in the sparsity pattern
|
||||||
cs: CsMatrix<T>,
|
pub(crate) cs: CsMatrix<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> CscMatrix<T> {
|
impl<T> CscMatrix<T> {
|
||||||
|
@ -435,25 +434,25 @@ macro_rules! impl_csc_col_common_methods {
|
||||||
/// The number of global rows in the column.
|
/// The number of global rows in the column.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn nrows(&self) -> usize {
|
pub fn nrows(&self) -> usize {
|
||||||
self.lane.minor_dim
|
self.lane.minor_dim()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The number of non-zeros in this column.
|
/// The number of non-zeros in this column.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn nnz(&self) -> usize {
|
pub fn nnz(&self) -> usize {
|
||||||
self.lane.minor_indices.len()
|
self.lane.nnz()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The row indices corresponding to explicitly stored entries in this column.
|
/// The row indices corresponding to explicitly stored entries in this column.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn row_indices(&self) -> &[usize] {
|
pub fn row_indices(&self) -> &[usize] {
|
||||||
self.lane.minor_indices
|
self.lane.minor_indices()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The values corresponding to explicitly stored entries in this column.
|
/// The values corresponding to explicitly stored entries in this column.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn values(&self) -> &[T] {
|
pub fn values(&self) -> &[T] {
|
||||||
self.lane.values
|
self.lane.values()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns an entry for the given global row index.
|
/// Returns an entry for the given global row index.
|
||||||
|
@ -461,11 +460,7 @@ macro_rules! impl_csc_col_common_methods {
|
||||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||||
/// stored row entries.
|
/// stored row entries.
|
||||||
pub fn get_entry(&self, global_row_index: usize) -> Option<SparseEntry<T>> {
|
pub fn get_entry(&self, global_row_index: usize) -> Option<SparseEntry<T>> {
|
||||||
get_entry_from_slices(
|
self.lane.get_entry(global_row_index)
|
||||||
self.lane.minor_dim,
|
|
||||||
self.lane.minor_indices,
|
|
||||||
self.lane.values,
|
|
||||||
global_row_index)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -477,7 +472,7 @@ impl_csc_col_common_methods!(CscColMut<'a, T>);
|
||||||
impl<'a, T> CscColMut<'a, T> {
|
impl<'a, T> CscColMut<'a, T> {
|
||||||
/// Mutable access to the values corresponding to explicitly stored entries in this column.
|
/// Mutable access to the values corresponding to explicitly stored entries in this column.
|
||||||
pub fn values_mut(&mut self) -> &mut [T] {
|
pub fn values_mut(&mut self) -> &mut [T] {
|
||||||
self.lane.values
|
self.lane.values_mut()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Provides simultaneous access to row indices and mutable values corresponding to the
|
/// Provides simultaneous access to row indices and mutable values corresponding to the
|
||||||
|
@ -486,15 +481,12 @@ impl<'a, T> CscColMut<'a, T> {
|
||||||
/// This method primarily facilitates low-level access for methods that process data stored
|
/// This method primarily facilitates low-level access for methods that process data stored
|
||||||
/// in CSC format directly.
|
/// in CSC format directly.
|
||||||
pub fn rows_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
pub fn rows_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
||||||
(self.lane.minor_indices, self.lane.values)
|
self.lane.indices_and_values_mut()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a mutable entry for the given global row index.
|
/// Returns a mutable entry for the given global row index.
|
||||||
pub fn get_entry_mut(&mut self, global_row_index: usize) -> Option<SparseEntryMut<T>> {
|
pub fn get_entry_mut(&mut self, global_row_index: usize) -> Option<SparseEntryMut<T>> {
|
||||||
get_mut_entry_from_slices(self.lane.minor_dim,
|
self.lane.get_entry_mut(global_row_index)
|
||||||
self.lane.minor_indices,
|
|
||||||
self.lane.values,
|
|
||||||
global_row_index)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
|
use crate::{SparseFormatError, SparseFormatErrorKind, SparseEntry, SparseEntryMut};
|
||||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||||
use crate::csc::CscMatrix;
|
use crate::csc::CscMatrix;
|
||||||
use crate::cs::{CsMatrix, get_entry_from_slices, get_mut_entry_from_slices, CsLaneIterMut, CsLaneIter, CsLane, CsLaneMut};
|
use crate::cs::{CsMatrix, CsLaneIterMut, CsLaneIter, CsLane, CsLaneMut};
|
||||||
|
|
||||||
use nalgebra::Scalar;
|
use nalgebra::Scalar;
|
||||||
use num_traits::Zero;
|
use num_traits::Zero;
|
||||||
|
@ -20,7 +20,7 @@ use std::slice::{IterMut, Iter};
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct CsrMatrix<T> {
|
pub struct CsrMatrix<T> {
|
||||||
// Rows are major, cols are minor in the sparsity pattern
|
// Rows are major, cols are minor in the sparsity pattern
|
||||||
cs: CsMatrix<T>,
|
pub(crate) cs: CsMatrix<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> CsrMatrix<T> {
|
impl<T> CsrMatrix<T> {
|
||||||
|
@ -435,37 +435,34 @@ macro_rules! impl_csr_row_common_methods {
|
||||||
/// The number of global columns in the row.
|
/// The number of global columns in the row.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn ncols(&self) -> usize {
|
pub fn ncols(&self) -> usize {
|
||||||
self.lane.minor_dim
|
self.lane.minor_dim()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The number of non-zeros in this row.
|
/// The number of non-zeros in this row.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn nnz(&self) -> usize {
|
pub fn nnz(&self) -> usize {
|
||||||
self.lane.minor_indices.len()
|
self.lane.nnz()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The column indices corresponding to explicitly stored entries in this row.
|
/// The column indices corresponding to explicitly stored entries in this row.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn col_indices(&self) -> &[usize] {
|
pub fn col_indices(&self) -> &[usize] {
|
||||||
self.lane.minor_indices
|
self.lane.minor_indices()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The values corresponding to explicitly stored entries in this row.
|
/// The values corresponding to explicitly stored entries in this row.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn values(&self) -> &[T] {
|
pub fn values(&self) -> &[T] {
|
||||||
self.lane.values
|
self.lane.values()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns an entry for the given global column index.
|
/// Returns an entry for the given global column index.
|
||||||
///
|
///
|
||||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||||
/// stored column entries.
|
/// stored column entries.
|
||||||
|
#[inline]
|
||||||
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> {
|
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> {
|
||||||
get_entry_from_slices(
|
self.lane.get_entry(global_col_index)
|
||||||
self.lane.minor_dim,
|
|
||||||
self.lane.minor_indices,
|
|
||||||
self.lane.values,
|
|
||||||
global_col_index)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -476,8 +473,9 @@ impl_csr_row_common_methods!(CsrRowMut<'a, T>);
|
||||||
|
|
||||||
impl<'a, T> CsrRowMut<'a, T> {
|
impl<'a, T> CsrRowMut<'a, T> {
|
||||||
/// Mutable access to the values corresponding to explicitly stored entries in this row.
|
/// Mutable access to the values corresponding to explicitly stored entries in this row.
|
||||||
|
#[inline]
|
||||||
pub fn values_mut(&mut self) -> &mut [T] {
|
pub fn values_mut(&mut self) -> &mut [T] {
|
||||||
self.lane.values
|
self.lane.values_mut()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Provides simultaneous access to column indices and mutable values corresponding to the
|
/// Provides simultaneous access to column indices and mutable values corresponding to the
|
||||||
|
@ -485,16 +483,15 @@ impl<'a, T> CsrRowMut<'a, T> {
|
||||||
///
|
///
|
||||||
/// This method primarily facilitates low-level access for methods that process data stored
|
/// This method primarily facilitates low-level access for methods that process data stored
|
||||||
/// in CSR format directly.
|
/// in CSR format directly.
|
||||||
|
#[inline]
|
||||||
pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
||||||
(self.lane.minor_indices, self.lane.values)
|
self.lane.indices_and_values_mut()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a mutable entry for the given global column index.
|
/// Returns a mutable entry for the given global column index.
|
||||||
|
#[inline]
|
||||||
pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<T>> {
|
pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<T>> {
|
||||||
get_mut_entry_from_slices(self.lane.minor_dim,
|
self.lane.get_entry_mut(global_col_index)
|
||||||
self.lane.minor_indices,
|
|
||||||
self.lane.values,
|
|
||||||
global_col_index)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -90,7 +90,7 @@ pub mod pattern;
|
||||||
pub mod ops;
|
pub mod ops;
|
||||||
pub mod convert;
|
pub mod convert;
|
||||||
|
|
||||||
mod cs;
|
pub(crate) mod cs;
|
||||||
|
|
||||||
#[cfg(feature = "proptest-support")]
|
#[cfg(feature = "proptest-support")]
|
||||||
pub mod proptest;
|
pub mod proptest;
|
||||||
|
|
|
@ -1,103 +1,112 @@
|
||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
|
use crate::csc::CscMatrix;
|
||||||
|
|
||||||
use std::ops::{Add, Mul};
|
use std::ops::{Add, Mul};
|
||||||
use crate::ops::serial::{spadd_csr_prealloc, spadd_pattern, spmm_pattern, spmm_csr_prealloc};
|
use crate::ops::serial::{spadd_csr_prealloc, spadd_csc_prealloc, spadd_pattern,
|
||||||
|
spmm_pattern, spmm_csr_prealloc, spmm_csc_prealloc};
|
||||||
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
||||||
use num_traits::{Zero, One};
|
use num_traits::{Zero, One};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use crate::ops::{Op};
|
use crate::ops::{Op};
|
||||||
|
|
||||||
impl<'a, T> Add<&'a CsrMatrix<T>> for &'a CsrMatrix<T>
|
/// Helper macro for implementing binary operators for different matrix types
|
||||||
where
|
|
||||||
// TODO: Consider introducing wrapper trait for these things? It's technically a "Ring",
|
|
||||||
// I guess...
|
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
|
||||||
{
|
|
||||||
type Output = CsrMatrix<T>;
|
|
||||||
|
|
||||||
fn add(self, rhs: &'a CsrMatrix<T>) -> Self::Output {
|
|
||||||
let pattern = spadd_pattern(self.pattern(), rhs.pattern());
|
|
||||||
let values = vec![T::zero(); pattern.nnz()];
|
|
||||||
// We are giving data that is valid by definition, so it is safe to unwrap below
|
|
||||||
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
|
||||||
.unwrap();
|
|
||||||
spadd_csr_prealloc(T::zero(), &mut result, T::one(), Op::NoOp(&self)).unwrap();
|
|
||||||
spadd_csr_prealloc(T::one(), &mut result, T::one(), Op::NoOp(&rhs)).unwrap();
|
|
||||||
result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, T> Add<&'a CsrMatrix<T>> for CsrMatrix<T>
|
|
||||||
where
|
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
|
||||||
{
|
|
||||||
type Output = CsrMatrix<T>;
|
|
||||||
|
|
||||||
fn add(mut self, rhs: &'a CsrMatrix<T>) -> Self::Output {
|
|
||||||
if Arc::ptr_eq(self.pattern(), rhs.pattern()) {
|
|
||||||
spadd_csr_prealloc(T::one(), &mut self, T::one(), Op::NoOp(rhs)).unwrap();
|
|
||||||
self
|
|
||||||
} else {
|
|
||||||
&self + rhs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, T> Add<CsrMatrix<T>> for &'a CsrMatrix<T>
|
|
||||||
where
|
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
|
||||||
{
|
|
||||||
type Output = CsrMatrix<T>;
|
|
||||||
|
|
||||||
fn add(self, rhs: CsrMatrix<T>) -> Self::Output {
|
|
||||||
rhs + self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T> Add<CsrMatrix<T>> for CsrMatrix<T>
|
|
||||||
where
|
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
|
||||||
{
|
|
||||||
type Output = Self;
|
|
||||||
|
|
||||||
fn add(self, rhs: CsrMatrix<T>) -> Self::Output {
|
|
||||||
self + &rhs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Helper macro for implementing matrix multiplication for different matrix types
|
|
||||||
/// See below for usage.
|
/// See below for usage.
|
||||||
macro_rules! impl_matrix_mul {
|
macro_rules! impl_bin_op {
|
||||||
(<$($life:lifetime),*>($a_name:ident : $a:ty, $b_name:ident : $b:ty) -> $ret:ty $body:block)
|
($trait:ident, $method:ident,
|
||||||
|
<$($life:lifetime),*>($a:ident : $a_type:ty, $b:ident : $b_type:ty) -> $ret:ty $body:block)
|
||||||
=>
|
=>
|
||||||
{
|
{
|
||||||
impl<$($life,)* T> Mul<$b> for $a
|
impl<$($life,)* T> $trait<$b_type> for $a_type
|
||||||
where
|
where
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
{
|
{
|
||||||
type Output = $ret;
|
type Output = $ret;
|
||||||
fn mul(self, rhs: $b) -> Self::Output {
|
fn $method(self, rhs: $b_type) -> Self::Output {
|
||||||
let $a_name = self;
|
let $a = self;
|
||||||
let $b_name = rhs;
|
let $b = rhs;
|
||||||
$body
|
$body
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_matrix_mul!(<'a>(a: &'a CsrMatrix<T>, b: &'a CsrMatrix<T>) -> CsrMatrix<T> {
|
macro_rules! impl_add {
|
||||||
let pattern = spmm_pattern(a.pattern(), b.pattern());
|
($($args:tt)*) => {
|
||||||
let values = vec![T::zero(); pattern.nnz()];
|
impl_bin_op!(Add, add, $($args)*);
|
||||||
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
}
|
||||||
.unwrap();
|
}
|
||||||
spmm_csr_prealloc(T::zero(),
|
|
||||||
&mut result,
|
/// Implements a + b for all combinations of reference and owned matrices, for
|
||||||
T::one(),
|
/// CsrMatrix or CscMatrix.
|
||||||
Op::NoOp(a),
|
macro_rules! impl_spadd {
|
||||||
Op::NoOp(b))
|
($matrix_type:ident, $spadd_fn:ident) => {
|
||||||
.expect("Internal error: spmm failed (please debug).");
|
impl_add!(<'a>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||||
result
|
// If both matrices have the same pattern, then we can immediately re-use it
|
||||||
});
|
let pattern = if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
||||||
impl_matrix_mul!(<'a>(a: &'a CsrMatrix<T>, b: CsrMatrix<T>) -> CsrMatrix<T> { a * &b});
|
Arc::clone(a.pattern())
|
||||||
impl_matrix_mul!(<'a>(a: CsrMatrix<T>, b: &'a CsrMatrix<T>) -> CsrMatrix<T> { &a * b});
|
} else {
|
||||||
impl_matrix_mul!(<>(a: CsrMatrix<T>, b: CsrMatrix<T>) -> CsrMatrix<T> { &a * &b});
|
Arc::new(spadd_pattern(a.pattern(), b.pattern()))
|
||||||
|
};
|
||||||
|
let values = vec![T::zero(); pattern.nnz()];
|
||||||
|
// We are giving data that is valid by definition, so it is safe to unwrap below
|
||||||
|
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
|
||||||
|
.unwrap();
|
||||||
|
$spadd_fn(T::zero(), &mut result, T::one(), Op::NoOp(&a)).unwrap();
|
||||||
|
$spadd_fn(T::one(), &mut result, T::one(), Op::NoOp(&b)).unwrap();
|
||||||
|
result
|
||||||
|
});
|
||||||
|
|
||||||
|
impl_add!(<'a>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||||
|
let mut a = a;
|
||||||
|
if Arc::ptr_eq(a.pattern(), b.pattern()) {
|
||||||
|
$spadd_fn(T::one(), &mut a, T::one(), Op::NoOp(b)).unwrap();
|
||||||
|
a
|
||||||
|
} else {
|
||||||
|
&a + b
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
impl_add!(<'a>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||||
|
b + a
|
||||||
|
});
|
||||||
|
impl_add!(<>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||||
|
a + &b
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_spadd!(CsrMatrix, spadd_csr_prealloc);
|
||||||
|
impl_spadd!(CscMatrix, spadd_csc_prealloc);
|
||||||
|
|
||||||
|
macro_rules! impl_mul {
|
||||||
|
($($args:tt)*) => {
|
||||||
|
impl_bin_op!(Mul, mul, $($args)*);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implements a + b for all combinations of reference and owned matrices, for
|
||||||
|
/// CsrMatrix or CscMatrix.
|
||||||
|
macro_rules! impl_spmm {
|
||||||
|
($matrix_type:ident, $pattern_fn:expr, $spmm_fn:expr) => {
|
||||||
|
impl_mul!(<'a>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||||
|
let pattern = $pattern_fn(a.pattern(), b.pattern());
|
||||||
|
let values = vec![T::zero(); pattern.nnz()];
|
||||||
|
let mut result = $matrix_type::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||||
|
.unwrap();
|
||||||
|
$spmm_fn(T::zero(),
|
||||||
|
&mut result,
|
||||||
|
T::one(),
|
||||||
|
Op::NoOp(a),
|
||||||
|
Op::NoOp(b))
|
||||||
|
.expect("Internal error: spmm failed (please debug).");
|
||||||
|
result
|
||||||
|
});
|
||||||
|
impl_mul!(<'a>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { a * &b});
|
||||||
|
impl_mul!(<'a>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> { &a * b});
|
||||||
|
impl_mul!(<>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { &a * &b});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_spmm!(CsrMatrix, spmm_pattern, spmm_csr_prealloc);
|
||||||
|
// Need to switch order of operations for CSC pattern
|
||||||
|
impl_spmm!(CscMatrix, |a, b| spmm_pattern(b, a), spmm_csc_prealloc);
|
|
@ -48,6 +48,17 @@ impl<T> Op<T> {
|
||||||
Op::NoOp(obj) | Op::Transpose(obj) => obj,
|
Op::NoOp(obj) | Op::Transpose(obj) => obj,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Applies the transpose operation.
|
||||||
|
///
|
||||||
|
/// This operation follows the usual semantics of transposition. In particular, double
|
||||||
|
/// transposition is equivalent to no transposition.
|
||||||
|
pub fn transposed(self) -> Self {
|
||||||
|
match self {
|
||||||
|
Op::NoOp(obj) => Op::Transpose(obj),
|
||||||
|
Op::Transpose(obj) => Op::NoOp(obj)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> From<T> for Op<T> {
|
impl<T> From<T> for Op<T> {
|
||||||
|
|
|
@ -0,0 +1,194 @@
|
||||||
|
use crate::cs::CsMatrix;
|
||||||
|
use crate::ops::Op;
|
||||||
|
use crate::ops::serial::{OperationErrorType, OperationError};
|
||||||
|
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice};
|
||||||
|
use num_traits::{Zero, One};
|
||||||
|
use crate::SparseEntryMut;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
fn spmm_cs_unexpected_entry() -> OperationError {
|
||||||
|
OperationError::from_type_and_message(
|
||||||
|
OperationErrorType::InvalidPattern,
|
||||||
|
String::from("Found unexpected entry that is not present in `c`."))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper functionality for implementing CSR/CSC SPMM.
|
||||||
|
///
|
||||||
|
/// Since CSR/CSC matrices are basically transpositions of each other, which lets us use the same
|
||||||
|
/// algorithm for the SPMM implementation. The implementation here is written in a CSR-centric
|
||||||
|
/// manner. This means that when using it for CSC, the order of the matrices needs to be
|
||||||
|
/// reversed (since transpose(AB) = transpose(B) * transpose(A) and CSC(A) = transpose(CSR(A)).
|
||||||
|
///
|
||||||
|
/// We assume here that the matrices have already been verified to be dimensionally compatible.
|
||||||
|
pub fn spmm_cs_prealloc<T>(
|
||||||
|
beta: T,
|
||||||
|
c: &mut CsMatrix<T>,
|
||||||
|
alpha: T,
|
||||||
|
a: &CsMatrix<T>,
|
||||||
|
b: &CsMatrix<T>)
|
||||||
|
-> Result<(), OperationError>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
for i in 0 .. c.pattern().major_dim() {
|
||||||
|
let a_lane_i = a.get_lane(i).unwrap();
|
||||||
|
let mut c_lane_i = c.get_lane_mut(i).unwrap();
|
||||||
|
for c_ij in c_lane_i.values_mut() {
|
||||||
|
*c_ij = beta.inlined_clone() * c_ij.inlined_clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (&k, a_ik) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) {
|
||||||
|
let b_lane_k = b.get_lane(k).unwrap();
|
||||||
|
let (mut c_lane_i_cols, mut c_lane_i_values) = c_lane_i.indices_and_values_mut();
|
||||||
|
let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone();
|
||||||
|
for (j, b_kj) in b_lane_k.minor_indices().iter().zip(b_lane_k.values()) {
|
||||||
|
// Determine the location in C to append the value
|
||||||
|
let (c_local_idx, _) = c_lane_i_cols.iter()
|
||||||
|
.enumerate()
|
||||||
|
.find(|(_, c_col)| *c_col == j)
|
||||||
|
.ok_or_else(spmm_cs_unexpected_entry)?;
|
||||||
|
|
||||||
|
c_lane_i_values[c_local_idx] += alpha_aik.inlined_clone() * b_kj.inlined_clone();
|
||||||
|
c_lane_i_cols = &c_lane_i_cols[c_local_idx ..];
|
||||||
|
c_lane_i_values = &mut c_lane_i_values[c_local_idx ..];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spadd_cs_unexpected_entry() -> OperationError {
|
||||||
|
OperationError::from_type_and_message(
|
||||||
|
OperationErrorType::InvalidPattern,
|
||||||
|
String::from("Found entry in `op(a)` that is not present in `c`."))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper functionality for implementing CSR/CSC SPADD.
|
||||||
|
pub fn spadd_cs_prealloc<T>(beta: T,
|
||||||
|
c: &mut CsMatrix<T>,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<&CsMatrix<T>>)
|
||||||
|
-> Result<(), OperationError>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
|
||||||
|
if Arc::ptr_eq(&c.pattern(), &a.inner_ref().pattern()) {
|
||||||
|
// Special fast path: The two matrices have *exactly* the same sparsity pattern,
|
||||||
|
// so we only need to sum the value arrays
|
||||||
|
// TODO: Test this fast path
|
||||||
|
for (c_ij, a_ij) in c.values_mut().iter_mut().zip(a.inner_ref().values()) {
|
||||||
|
let (alpha, beta) = (alpha.inlined_clone(), beta.inlined_clone());
|
||||||
|
*c_ij = beta * c_ij.inlined_clone() + alpha * a_ij.inlined_clone();
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
match a {
|
||||||
|
Op::NoOp(a) => {
|
||||||
|
for (mut c_lane_i, a_lane_i) in c.lane_iter_mut().zip(a.lane_iter()) {
|
||||||
|
if beta != T::one() {
|
||||||
|
for c_ij in c_lane_i.values_mut() {
|
||||||
|
*c_ij *= beta.inlined_clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (mut c_minors, mut c_vals) = c_lane_i.indices_and_values_mut();
|
||||||
|
let (a_minors, a_vals) = (a_lane_i.minor_indices(), a_lane_i.values());
|
||||||
|
|
||||||
|
for (a_col, a_val) in a_minors.iter().zip(a_vals) {
|
||||||
|
// TODO: Use exponential search instead of linear search.
|
||||||
|
// If C has substantially more entries in the row than A, then a line search
|
||||||
|
// will needlessly visit many entries in C.
|
||||||
|
let (c_idx, _) = c_minors.iter()
|
||||||
|
.enumerate()
|
||||||
|
.find(|(_, c_col)| *c_col == a_col)
|
||||||
|
.ok_or_else(spadd_cs_unexpected_entry)?;
|
||||||
|
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
|
||||||
|
c_minors = &c_minors[c_idx ..];
|
||||||
|
c_vals = &mut c_vals[c_idx ..];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Op::Transpose(a) => {
|
||||||
|
if beta != T::one() {
|
||||||
|
for c_ij in c.values_mut() {
|
||||||
|
*c_ij *= beta.inlined_clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i, a_lane_i) in a.lane_iter().enumerate() {
|
||||||
|
for (&j, a_val) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) {
|
||||||
|
let a_val = a_val.inlined_clone();
|
||||||
|
let alpha = alpha.inlined_clone();
|
||||||
|
match c.get_entry_mut(j, i).unwrap() {
|
||||||
|
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val }
|
||||||
|
SparseEntryMut::Zero => return Err(spadd_cs_unexpected_entry()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper functionality for implementing CSR/CSC SPMM.
|
||||||
|
///
|
||||||
|
/// The implementation essentially assumes that `a` is a CSR matrix. To use it with CSC matrices,
|
||||||
|
/// the transposed operation must be specified for the CSC matrix.
|
||||||
|
pub fn spmm_cs_dense<T>(beta: T,
|
||||||
|
mut c: DMatrixSliceMut<T>,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<&CsMatrix<T>>,
|
||||||
|
b: Op<DMatrixSlice<T>>)
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
match a {
|
||||||
|
Op::NoOp(a) => {
|
||||||
|
for j in 0..c.ncols() {
|
||||||
|
let mut c_col_j = c.column_mut(j);
|
||||||
|
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.lane_iter()) {
|
||||||
|
let mut dot_ij = T::zero();
|
||||||
|
for (&k, a_ik) in a_row_i.minor_indices().iter().zip(a_row_i.values()) {
|
||||||
|
let b_contrib =
|
||||||
|
match b {
|
||||||
|
Op::NoOp(ref b) => b.index((k, j)),
|
||||||
|
Op::Transpose(ref b) => b.index((j, k))
|
||||||
|
};
|
||||||
|
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
|
||||||
|
}
|
||||||
|
*c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Op::Transpose(a) => {
|
||||||
|
// In this case, we have to pre-multiply C by beta
|
||||||
|
c *= beta;
|
||||||
|
|
||||||
|
for k in 0..a.pattern().major_dim() {
|
||||||
|
let a_row_k = a.get_lane(k).unwrap();
|
||||||
|
for (&i, a_ki) in a_row_k.minor_indices().iter().zip(a_row_k.values()) {
|
||||||
|
let gamma_ki = alpha.inlined_clone() * a_ki.inlined_clone();
|
||||||
|
let mut c_row_i = c.row_mut(i);
|
||||||
|
match b {
|
||||||
|
Op::NoOp(ref b) => {
|
||||||
|
let b_row_k = b.row(k);
|
||||||
|
for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) {
|
||||||
|
*c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Op::Transpose(ref b) => {
|
||||||
|
let b_col_k = b.column(k);
|
||||||
|
for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) {
|
||||||
|
*c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
use crate::csc::CscMatrix;
|
||||||
|
use crate::ops::Op;
|
||||||
|
use crate::ops::serial::cs::{spmm_cs_prealloc, spmm_cs_dense, spadd_cs_prealloc};
|
||||||
|
use crate::ops::serial::OperationError;
|
||||||
|
use nalgebra::{Scalar, ClosedAdd, ClosedMul, DMatrixSliceMut, DMatrixSlice};
|
||||||
|
use num_traits::{Zero, One};
|
||||||
|
|
||||||
|
use std::borrow::Cow;
|
||||||
|
|
||||||
|
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
|
||||||
|
pub fn spmm_csc_dense<'a, T>(beta: T,
|
||||||
|
c: impl Into<DMatrixSliceMut<'a, T>>,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<&CscMatrix<T>>,
|
||||||
|
b: Op<impl Into<DMatrixSlice<'a, T>>>)
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
let b = b.convert();
|
||||||
|
spmm_csc_dense_(beta, c.into(), alpha, a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spmm_csc_dense_<T>(beta: T,
|
||||||
|
c: DMatrixSliceMut<T>,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<&CscMatrix<T>>,
|
||||||
|
b: Op<DMatrixSlice<T>>)
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
assert_compatible_spmm_dims!(c, a, b);
|
||||||
|
// Need to interpret matrix as transposed since the spmm_cs_dense function assumes CSR layout
|
||||||
|
let a = a.transposed().map_same_op(|a| &a.cs);
|
||||||
|
spmm_cs_dense(beta, c, alpha, a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sparse matrix addition `C <- beta * C + alpha * op(A)`.
|
||||||
|
///
|
||||||
|
/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is
|
||||||
|
/// returned.
|
||||||
|
pub fn spadd_csc_prealloc<T>(beta: T,
|
||||||
|
c: &mut CscMatrix<T>,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<&CscMatrix<T>>)
|
||||||
|
-> Result<(), OperationError>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
assert_compatible_spadd_dims!(c, a);
|
||||||
|
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
|
||||||
|
pub fn spmm_csc_prealloc<T>(
|
||||||
|
beta: T,
|
||||||
|
c: &mut CscMatrix<T>,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<&CscMatrix<T>>,
|
||||||
|
b: Op<&CscMatrix<T>>)
|
||||||
|
-> Result<(), OperationError>
|
||||||
|
where
|
||||||
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
|
{
|
||||||
|
assert_compatible_spmm_dims!(c, a, b);
|
||||||
|
|
||||||
|
use Op::{NoOp, Transpose};
|
||||||
|
|
||||||
|
match (&a, &b) {
|
||||||
|
(NoOp(ref a), NoOp(ref b)) => {
|
||||||
|
// Note: We have to reverse the order for CSC matrices
|
||||||
|
spmm_cs_prealloc(beta, &mut c.cs, alpha, &b.cs, &a.cs)
|
||||||
|
},
|
||||||
|
_ => {
|
||||||
|
// Currently we handle transposition by explicitly precomputing transposed matrices
|
||||||
|
// and calling the operation again without transposition
|
||||||
|
let a_ref: &CscMatrix<T> = a.inner_ref();
|
||||||
|
let b_ref: &CscMatrix<T> = b.inner_ref();
|
||||||
|
let (a, b) = {
|
||||||
|
use Cow::*;
|
||||||
|
match (&a, &b) {
|
||||||
|
(NoOp(_), NoOp(_)) => unreachable!(),
|
||||||
|
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
|
||||||
|
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
|
||||||
|
(Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose()))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
spmm_csc_prealloc(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,11 +1,10 @@
|
||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
use crate::ops::{Op};
|
use crate::ops::{Op};
|
||||||
use crate::SparseEntryMut;
|
use crate::ops::serial::{OperationError};
|
||||||
use crate::ops::serial::{OperationError, OperationErrorType};
|
|
||||||
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
|
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
|
||||||
use num_traits::{Zero, One};
|
use num_traits::{Zero, One};
|
||||||
use std::sync::Arc;
|
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
|
use crate::ops::serial::cs::{spmm_cs_prealloc, spmm_cs_dense, spadd_cs_prealloc};
|
||||||
|
|
||||||
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
|
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
|
||||||
pub fn spmm_csr_dense<'a, T>(beta: T,
|
pub fn spmm_csr_dense<'a, T>(beta: T,
|
||||||
|
@ -21,7 +20,7 @@ pub fn spmm_csr_dense<'a, T>(beta: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn spmm_csr_dense_<T>(beta: T,
|
fn spmm_csr_dense_<T>(beta: T,
|
||||||
mut c: DMatrixSliceMut<T>,
|
c: DMatrixSliceMut<T>,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
a: Op<&CsrMatrix<T>>,
|
a: Op<&CsrMatrix<T>>,
|
||||||
b: Op<DMatrixSlice<T>>)
|
b: Op<DMatrixSlice<T>>)
|
||||||
|
@ -29,58 +28,7 @@ where
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
{
|
{
|
||||||
assert_compatible_spmm_dims!(c, a, b);
|
assert_compatible_spmm_dims!(c, a, b);
|
||||||
|
spmm_cs_dense(beta, c, alpha, a.map_same_op(|a| &a.cs), b)
|
||||||
match a {
|
|
||||||
Op::NoOp(ref a) => {
|
|
||||||
for j in 0..c.ncols() {
|
|
||||||
let mut c_col_j = c.column_mut(j);
|
|
||||||
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.row_iter()) {
|
|
||||||
let mut dot_ij = T::zero();
|
|
||||||
for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
|
|
||||||
let b_contrib =
|
|
||||||
match b {
|
|
||||||
Op::NoOp(ref b) => b.index((k, j)),
|
|
||||||
Op::Transpose(ref b) => b.index((j, k))
|
|
||||||
};
|
|
||||||
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
|
|
||||||
}
|
|
||||||
*c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Op::Transpose(ref a) => {
|
|
||||||
// In this case, we have to pre-multiply C by beta
|
|
||||||
c *= beta;
|
|
||||||
|
|
||||||
for k in 0..a.nrows() {
|
|
||||||
let a_row_k = a.row(k);
|
|
||||||
for (&i, a_ki) in a_row_k.col_indices().iter().zip(a_row_k.values()) {
|
|
||||||
let gamma_ki = alpha.inlined_clone() * a_ki.inlined_clone();
|
|
||||||
let mut c_row_i = c.row_mut(i);
|
|
||||||
match b {
|
|
||||||
Op::NoOp(ref b) => {
|
|
||||||
let b_row_k = b.row(k);
|
|
||||||
for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) {
|
|
||||||
*c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone();
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Op::Transpose(ref b) => {
|
|
||||||
let b_col_k = b.column(k);
|
|
||||||
for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) {
|
|
||||||
*c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone();
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn spadd_csr_unexpected_entry() -> OperationError {
|
|
||||||
OperationError::from_type_and_message(
|
|
||||||
OperationErrorType::InvalidPattern,
|
|
||||||
String::from("Found entry in `a` that is not present in `c`."))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sparse matrix addition `C <- beta * C + alpha * op(A)`.
|
/// Sparse matrix addition `C <- beta * C + alpha * op(A)`.
|
||||||
|
@ -96,70 +44,7 @@ where
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
{
|
{
|
||||||
assert_compatible_spadd_dims!(c, a);
|
assert_compatible_spadd_dims!(c, a);
|
||||||
|
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
|
||||||
// TODO: Change CsrMatrix::pattern() to return `&Arc` instead of `Arc`
|
|
||||||
if Arc::ptr_eq(&c.pattern(), &a.inner_ref().pattern()) {
|
|
||||||
// Special fast path: The two matrices have *exactly* the same sparsity pattern,
|
|
||||||
// so we only need to sum the value arrays
|
|
||||||
for (c_ij, a_ij) in c.values_mut().iter_mut().zip(a.inner_ref().values()) {
|
|
||||||
let (alpha, beta) = (alpha.inlined_clone(), beta.inlined_clone());
|
|
||||||
*c_ij = beta * c_ij.inlined_clone() + alpha * a_ij.inlined_clone();
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
match a {
|
|
||||||
Op::NoOp(a) => {
|
|
||||||
for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) {
|
|
||||||
if beta != T::one() {
|
|
||||||
for c_ij in c_row_i.values_mut() {
|
|
||||||
*c_ij *= beta.inlined_clone();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let (mut c_cols, mut c_vals) = c_row_i.cols_and_values_mut();
|
|
||||||
let (a_cols, a_vals) = (a_row_i.col_indices(), a_row_i.values());
|
|
||||||
|
|
||||||
for (a_col, a_val) in a_cols.iter().zip(a_vals) {
|
|
||||||
// TODO: Use exponential search instead of linear search.
|
|
||||||
// If C has substantially more entries in the row than A, then a line search
|
|
||||||
// will needlessly visit many entries in C.
|
|
||||||
let (c_idx, _) = c_cols.iter()
|
|
||||||
.enumerate()
|
|
||||||
.find(|(_, c_col)| *c_col == a_col)
|
|
||||||
.ok_or_else(spadd_csr_unexpected_entry)?;
|
|
||||||
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
|
|
||||||
c_cols = &c_cols[c_idx ..];
|
|
||||||
c_vals = &mut c_vals[c_idx ..];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Op::Transpose(a) => {
|
|
||||||
if beta != T::one() {
|
|
||||||
for c_ij in c.values_mut() {
|
|
||||||
*c_ij *= beta.inlined_clone();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (i, a_row_i) in a.row_iter().enumerate() {
|
|
||||||
for (&j, a_val) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
|
|
||||||
let a_val = a_val.inlined_clone();
|
|
||||||
let alpha = alpha.inlined_clone();
|
|
||||||
match c.index_entry_mut(j, i) {
|
|
||||||
SparseEntryMut::NonZero(c_ji) => { *c_ji += alpha * a_val }
|
|
||||||
SparseEntryMut::Zero => return Err(spadd_csr_unexpected_entry()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn spmm_csr_unexpected_entry() -> OperationError {
|
|
||||||
OperationError::from_type_and_message(
|
|
||||||
OperationErrorType::InvalidPattern,
|
|
||||||
String::from("Found unexpected entry that is not present in `c`."))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
|
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
|
||||||
|
@ -179,29 +64,7 @@ where
|
||||||
|
|
||||||
match (&a, &b) {
|
match (&a, &b) {
|
||||||
(NoOp(ref a), NoOp(ref b)) => {
|
(NoOp(ref a), NoOp(ref b)) => {
|
||||||
for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) {
|
spmm_cs_prealloc(beta, &mut c.cs, alpha, &a.cs, &b.cs)
|
||||||
for c_ij in c_row_i.values_mut() {
|
|
||||||
*c_ij = beta.inlined_clone() * c_ij.inlined_clone();
|
|
||||||
}
|
|
||||||
|
|
||||||
for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
|
|
||||||
let b_row_k = b.row(k);
|
|
||||||
let (mut c_row_i_cols, mut c_row_i_values) = c_row_i.cols_and_values_mut();
|
|
||||||
let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone();
|
|
||||||
for (j, b_kj) in b_row_k.col_indices().iter().zip(b_row_k.values()) {
|
|
||||||
// Determine the location in C to append the value
|
|
||||||
let (c_local_idx, _) = c_row_i_cols.iter()
|
|
||||||
.enumerate()
|
|
||||||
.find(|(_, c_col)| *c_col == j)
|
|
||||||
.ok_or_else(spmm_csr_unexpected_entry)?;
|
|
||||||
|
|
||||||
c_row_i_values[c_local_idx] += alpha_aik.inlined_clone() * b_kj.inlined_clone();
|
|
||||||
c_row_i_cols = &c_row_i_cols[c_local_idx ..];
|
|
||||||
c_row_i_values = &mut c_row_i_values[c_local_idx ..];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
},
|
},
|
||||||
_ => {
|
_ => {
|
||||||
// Currently we handle transposition by explicitly precomputing transposed matrices
|
// Currently we handle transposition by explicitly precomputing transposed matrices
|
||||||
|
|
|
@ -49,9 +49,12 @@ macro_rules! assert_compatible_spadd_dims {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mod csc;
|
||||||
mod csr;
|
mod csr;
|
||||||
mod pattern;
|
mod pattern;
|
||||||
|
mod cs;
|
||||||
|
|
||||||
|
pub use csc::*;
|
||||||
pub use csr::*;
|
pub use csr::*;
|
||||||
pub use pattern::*;
|
pub use pattern::*;
|
||||||
|
|
||||||
|
|
|
@ -30,9 +30,9 @@ pub const PROPTEST_MAX_NNZ: usize = 40;
|
||||||
pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5 ..= 5;
|
pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5 ..= 5;
|
||||||
|
|
||||||
pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
pub fn csr_strategy() -> impl Strategy<Value=CsrMatrix<i32>> {
|
||||||
csr(-5 ..= 5, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
csr(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn csc_strategy() -> impl Strategy<Value=CscMatrix<i32>> {
|
pub fn csc_strategy() -> impl Strategy<Value=CscMatrix<i32>> {
|
||||||
csc(-5 ..= 5, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
csc(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
|
use crate::common::{csc_strategy, csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
|
||||||
PROPTEST_I32_VALUE_STRATEGY};
|
PROPTEST_I32_VALUE_STRATEGY};
|
||||||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_pattern, spmm_pattern, spadd_csr_prealloc, spmm_csr_prealloc};
|
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spmm_csc_dense, spadd_pattern, spmm_pattern,
|
||||||
|
spadd_csr_prealloc, spadd_csc_prealloc,
|
||||||
|
spmm_csr_prealloc, spmm_csc_prealloc};
|
||||||
use nalgebra_sparse::ops::{Op};
|
use nalgebra_sparse::ops::{Op};
|
||||||
use nalgebra_sparse::csr::CsrMatrix;
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
|
use nalgebra_sparse::csc::CscMatrix;
|
||||||
|
use nalgebra_sparse::proptest::{csc, csr, sparsity_pattern};
|
||||||
use nalgebra_sparse::pattern::SparsityPattern;
|
use nalgebra_sparse::pattern::SparsityPattern;
|
||||||
|
|
||||||
use nalgebra::{DMatrix, Scalar, DMatrixSliceMut, DMatrixSlice};
|
use nalgebra::{DMatrix, Scalar, DMatrixSliceMut, DMatrixSlice};
|
||||||
|
@ -23,6 +26,15 @@ fn dense_csr_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
||||||
DMatrix::from(&boolean_csr)
|
DMatrix::from(&boolean_csr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Represents the sparsity pattern of a CSC matrix as a dense matrix with 0/1
|
||||||
|
fn dense_csc_pattern(pattern: &SparsityPattern) -> DMatrix<i32> {
|
||||||
|
let boolean_csc = CscMatrix::try_from_pattern_and_values(
|
||||||
|
Arc::new(pattern.clone()),
|
||||||
|
vec![1; pattern.nnz()])
|
||||||
|
.unwrap();
|
||||||
|
DMatrix::from(&boolean_csc)
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct SpmmCsrDenseArgs<T: Scalar> {
|
struct SpmmCsrDenseArgs<T: Scalar> {
|
||||||
c: DMatrix<T>,
|
c: DMatrix<T>,
|
||||||
|
@ -32,6 +44,15 @@ struct SpmmCsrDenseArgs<T: Scalar> {
|
||||||
b: Op<DMatrix<T>>,
|
b: Op<DMatrix<T>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct SpmmCscDenseArgs<T: Scalar> {
|
||||||
|
c: DMatrix<T>,
|
||||||
|
beta: T,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<CscMatrix<T>>,
|
||||||
|
b: Op<DMatrix<T>>,
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns matrices C, A and B with compatible dimensions such that it can be used
|
/// Returns matrices C, A and B with compatible dimensions such that it can be used
|
||||||
/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`.
|
/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`.
|
||||||
fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>> {
|
fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>> {
|
||||||
|
@ -70,6 +91,21 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns matrices C, A and B with compatible dimensions such that it can be used
|
||||||
|
/// in an `spmm` operation `C = beta * C + alpha * trans(A) * trans(B)`.
|
||||||
|
fn spmm_csc_dense_args_strategy() -> impl Strategy<Value=SpmmCscDenseArgs<i32>> {
|
||||||
|
spmm_csr_dense_args_strategy()
|
||||||
|
.prop_map(|args| {
|
||||||
|
SpmmCscDenseArgs {
|
||||||
|
c: args.c,
|
||||||
|
beta: args.beta,
|
||||||
|
alpha: args.alpha,
|
||||||
|
a: args.a.map_same_op(|a| CscMatrix::from(&a)),
|
||||||
|
b: args.b
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct SpaddCsrArgs<T> {
|
struct SpaddCsrArgs<T> {
|
||||||
c: CsrMatrix<T>,
|
c: CsrMatrix<T>,
|
||||||
|
@ -78,6 +114,14 @@ struct SpaddCsrArgs<T> {
|
||||||
a: Op<CsrMatrix<T>>,
|
a: Op<CsrMatrix<T>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct SpaddCscArgs<T> {
|
||||||
|
c: CscMatrix<T>,
|
||||||
|
beta: T,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<CscMatrix<T>>,
|
||||||
|
}
|
||||||
|
|
||||||
fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
||||||
let value_strategy = PROPTEST_I32_VALUE_STRATEGY;
|
let value_strategy = PROPTEST_I32_VALUE_STRATEGY;
|
||||||
|
|
||||||
|
@ -99,6 +143,16 @@ fn spadd_csr_prealloc_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>>
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn spadd_csc_prealloc_args_strategy() -> impl Strategy<Value=SpaddCscArgs<i32>> {
|
||||||
|
spadd_csr_prealloc_args_strategy()
|
||||||
|
.prop_map(|args| SpaddCscArgs {
|
||||||
|
c: CscMatrix::from(&args.c),
|
||||||
|
beta: args.beta,
|
||||||
|
alpha: args.alpha,
|
||||||
|
a: args.a.map_same_op(|a| CscMatrix::from(&a))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
|
fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
|
||||||
matrix(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM)
|
matrix(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM)
|
||||||
}
|
}
|
||||||
|
@ -150,6 +204,15 @@ struct SpmmCsrArgs<T> {
|
||||||
b: Op<CsrMatrix<T>>,
|
b: Op<CsrMatrix<T>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct SpmmCscArgs<T> {
|
||||||
|
c: CscMatrix<T>,
|
||||||
|
beta: T,
|
||||||
|
alpha: T,
|
||||||
|
a: Op<CscMatrix<T>>,
|
||||||
|
b: Op<CscMatrix<T>>,
|
||||||
|
}
|
||||||
|
|
||||||
fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
||||||
spmm_pattern_strategy()
|
spmm_pattern_strategy()
|
||||||
.prop_flat_map(|(a_pattern, b_pattern)| {
|
.prop_flat_map(|(a_pattern, b_pattern)| {
|
||||||
|
@ -181,6 +244,21 @@ fn spmm_csr_prealloc_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn spmm_csc_prealloc_args_strategy() -> impl Strategy<Value=SpmmCscArgs<i32>> {
|
||||||
|
// Note: Converting from CSR is simple, but might be significantly slower than
|
||||||
|
// writing a common implementation that can be shared between CSR and CSC args
|
||||||
|
spmm_csr_prealloc_args_strategy()
|
||||||
|
.prop_map(|args| {
|
||||||
|
SpmmCscArgs {
|
||||||
|
c: CscMatrix::from(&args.c),
|
||||||
|
beta: args.beta,
|
||||||
|
alpha: args.alpha,
|
||||||
|
a: args.a.map_same_op(|a| CscMatrix::from(&a)),
|
||||||
|
b: args.b.map_same_op(|b| CscMatrix::from(&b))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Helper function to help us call dense GEMM with our `Op` type
|
/// Helper function to help us call dense GEMM with our `Op` type
|
||||||
fn dense_gemm<'a>(beta: i32,
|
fn dense_gemm<'a>(beta: i32,
|
||||||
c: impl Into<DMatrixSliceMut<'a, i32>>,
|
c: impl Into<DMatrixSliceMut<'a, i32>>,
|
||||||
|
@ -310,7 +388,7 @@ proptest! {
|
||||||
(a, b)
|
(a, b)
|
||||||
in csr_strategy()
|
in csr_strategy()
|
||||||
.prop_flat_map(|a| {
|
.prop_flat_map(|a| {
|
||||||
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), 40);
|
let b = csr(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
|
||||||
(Just(a), b)
|
(Just(a), b)
|
||||||
}))
|
}))
|
||||||
{
|
{
|
||||||
|
@ -500,4 +578,263 @@ proptest! {
|
||||||
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
|
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
|
||||||
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spmm_csc_prealloc_test(SpmmCscArgs { c, beta, alpha, a, b }
|
||||||
|
in spmm_csc_prealloc_args_strategy()
|
||||||
|
) {
|
||||||
|
// Test that we get the expected result by comparing to an equivalent dense operation
|
||||||
|
// (here we give in the C matrix, so the sparsity pattern is essentially fixed)
|
||||||
|
let mut c_sparse = c.clone();
|
||||||
|
spmm_csc_prealloc(beta, &mut c_sparse, alpha, a.as_ref(), b.as_ref()).unwrap();
|
||||||
|
|
||||||
|
let mut c_dense = DMatrix::from(&c);
|
||||||
|
let op_a_dense = match a {
|
||||||
|
Op::NoOp(ref a) => DMatrix::from(a),
|
||||||
|
Op::Transpose(ref a) => DMatrix::from(a).transpose(),
|
||||||
|
};
|
||||||
|
let op_b_dense = match b {
|
||||||
|
Op::NoOp(ref b) => DMatrix::from(b),
|
||||||
|
Op::Transpose(ref b) => DMatrix::from(b).transpose(),
|
||||||
|
};
|
||||||
|
c_dense = beta * c_dense + alpha * &op_a_dense * op_b_dense;
|
||||||
|
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_sparse), &c_dense);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spmm_csc_prealloc_panics_on_dim_mismatch(
|
||||||
|
(alpha, beta, c, a, b)
|
||||||
|
in (PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
csc_strategy(),
|
||||||
|
op_strategy(csc_strategy()),
|
||||||
|
op_strategy(csc_strategy()))
|
||||||
|
) {
|
||||||
|
// We refer to `A * B` as the "product"
|
||||||
|
let product_rows = match &a {
|
||||||
|
Op::NoOp(ref a) => a.nrows(),
|
||||||
|
Op::Transpose(ref a) => a.ncols(),
|
||||||
|
};
|
||||||
|
let product_cols = match &b {
|
||||||
|
Op::NoOp(ref b) => b.ncols(),
|
||||||
|
Op::Transpose(ref b) => b.nrows(),
|
||||||
|
};
|
||||||
|
// Determine the common dimension in the product
|
||||||
|
// from the perspective of a and b, respectively
|
||||||
|
let product_a_common = match &a {
|
||||||
|
Op::NoOp(ref a) => a.ncols(),
|
||||||
|
Op::Transpose(ref a) => a.nrows(),
|
||||||
|
};
|
||||||
|
let product_b_common = match &b {
|
||||||
|
Op::NoOp(ref b) => b.nrows(),
|
||||||
|
Op::Transpose(ref b) => b.ncols(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let dims_are_compatible = product_rows == c.nrows()
|
||||||
|
&& product_cols == c.ncols()
|
||||||
|
&& product_a_common == product_b_common;
|
||||||
|
|
||||||
|
// If the dimensions randomly happen to be compatible, then of course we need to
|
||||||
|
// skip the test, so we assume that they are not.
|
||||||
|
prop_assume!(!dims_are_compatible);
|
||||||
|
|
||||||
|
let result = catch_unwind(|| {
|
||||||
|
let mut spmm_result = c.clone();
|
||||||
|
spmm_csc_prealloc(beta, &mut spmm_result, alpha, a.as_ref(), b.as_ref()).unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
prop_assert!(result.is_err(),
|
||||||
|
"The SPMM kernel executed successfully despite mismatch dimensions");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_mul_csc(
|
||||||
|
// a and b have dimensions compatible for multiplication
|
||||||
|
(a, b)
|
||||||
|
in csc_strategy()
|
||||||
|
.prop_flat_map(|a| {
|
||||||
|
let max_nnz = PROPTEST_MAX_NNZ;
|
||||||
|
let cols = PROPTEST_MATRIX_DIM;
|
||||||
|
let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.ncols()), cols, max_nnz);
|
||||||
|
(Just(a), b)
|
||||||
|
})
|
||||||
|
.prop_map(|(a, b)| {
|
||||||
|
println!("a: {} x {}, b: {} x {}", a.nrows(), a.ncols(), b.nrows(), b.ncols());
|
||||||
|
(a, b)
|
||||||
|
}))
|
||||||
|
{
|
||||||
|
assert_eq!(a.ncols(), b.nrows());
|
||||||
|
// We use the dense result as the ground truth for the arithmetic result
|
||||||
|
let c_dense = DMatrix::from(&a) * DMatrix::from(&b);
|
||||||
|
// However, it's not enough only to cover the dense result, we also need to verify the
|
||||||
|
// sparsity pattern. We can determine the exact sparsity pattern by using
|
||||||
|
// dense arithmetic with positive integer values and extracting positive entries.
|
||||||
|
let c_dense_pattern = dense_csc_pattern(a.pattern()) * dense_csc_pattern(b.pattern());
|
||||||
|
let c_pattern = CscMatrix::from(&c_dense_pattern).pattern().clone();
|
||||||
|
|
||||||
|
// Check each combination of owned matrices and references
|
||||||
|
let c_owned_owned = a.clone() * b.clone();
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense);
|
||||||
|
prop_assert_eq!(c_owned_owned.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_owned_ref = a.clone() * &b;
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense);
|
||||||
|
prop_assert_eq!(c_owned_ref.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_ref_owned = &a * b.clone();
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense);
|
||||||
|
prop_assert_eq!(c_ref_owned.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_ref_ref = &a * &b;
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
|
||||||
|
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spmm_csc_dense_agrees_with_dense_result(
|
||||||
|
SpmmCscDenseArgs { c, beta, alpha, a, b }
|
||||||
|
in spmm_csc_dense_args_strategy()
|
||||||
|
) {
|
||||||
|
let mut spmm_result = c.clone();
|
||||||
|
spmm_csc_dense(beta, &mut spmm_result, alpha, a.as_ref(), b.as_ref());
|
||||||
|
|
||||||
|
let mut gemm_result = c.clone();
|
||||||
|
let a_dense = a.map_same_op(|a| DMatrix::from(&a));
|
||||||
|
dense_gemm(beta, &mut gemm_result, alpha, a_dense.as_ref(), b.as_ref());
|
||||||
|
|
||||||
|
prop_assert_eq!(spmm_result, gemm_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spmm_csc_dense_panics_on_dim_mismatch(
|
||||||
|
(alpha, beta, c, a, b)
|
||||||
|
in (PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
dense_strategy(),
|
||||||
|
op_strategy(csc_strategy()),
|
||||||
|
op_strategy(dense_strategy()))
|
||||||
|
) {
|
||||||
|
// We refer to `A * B` as the "product"
|
||||||
|
let product_rows = match &a {
|
||||||
|
Op::NoOp(ref a) => a.nrows(),
|
||||||
|
Op::Transpose(ref a) => a.ncols(),
|
||||||
|
};
|
||||||
|
let product_cols = match &b {
|
||||||
|
Op::NoOp(ref b) => b.ncols(),
|
||||||
|
Op::Transpose(ref b) => b.nrows(),
|
||||||
|
};
|
||||||
|
// Determine the common dimension in the product
|
||||||
|
// from the perspective of a and b, respectively
|
||||||
|
let product_a_common = match &a {
|
||||||
|
Op::NoOp(ref a) => a.ncols(),
|
||||||
|
Op::Transpose(ref a) => a.nrows(),
|
||||||
|
};
|
||||||
|
let product_b_common = match &b {
|
||||||
|
Op::NoOp(ref b) => b.nrows(),
|
||||||
|
Op::Transpose(ref b) => b.ncols()
|
||||||
|
};
|
||||||
|
|
||||||
|
let dims_are_compatible = product_rows == c.nrows()
|
||||||
|
&& product_cols == c.ncols()
|
||||||
|
&& product_a_common == product_b_common;
|
||||||
|
|
||||||
|
// If the dimensions randomly happen to be compatible, then of course we need to
|
||||||
|
// skip the test, so we assume that they are not.
|
||||||
|
prop_assume!(!dims_are_compatible);
|
||||||
|
|
||||||
|
let result = catch_unwind(|| {
|
||||||
|
let mut spmm_result = c.clone();
|
||||||
|
spmm_csc_dense(beta, &mut spmm_result, alpha, a.as_ref(), b.as_ref());
|
||||||
|
});
|
||||||
|
|
||||||
|
prop_assert!(result.is_err(),
|
||||||
|
"The SPMM kernel executed successfully despite mismatch dimensions");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spadd_csc_prealloc_test(SpaddCscArgs { c, beta, alpha, a } in spadd_csc_prealloc_args_strategy()) {
|
||||||
|
// Test that we get the expected result by comparing to an equivalent dense operation
|
||||||
|
// (here we give in the C matrix, so the sparsity pattern is essentially fixed)
|
||||||
|
|
||||||
|
let mut c_sparse = c.clone();
|
||||||
|
spadd_csc_prealloc(beta, &mut c_sparse, alpha, a.as_ref()).unwrap();
|
||||||
|
|
||||||
|
let mut c_dense = DMatrix::from(&c);
|
||||||
|
let op_a_dense = match a {
|
||||||
|
Op::NoOp(a) => DMatrix::from(&a),
|
||||||
|
Op::Transpose(a) => DMatrix::from(&a).transpose(),
|
||||||
|
};
|
||||||
|
c_dense = beta * c_dense + alpha * &op_a_dense;
|
||||||
|
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_sparse), &c_dense);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spadd_csc_prealloc_panics_on_dim_mismatch(
|
||||||
|
(alpha, beta, c, op_a)
|
||||||
|
in (PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
csc_strategy(),
|
||||||
|
op_strategy(csc_strategy()))
|
||||||
|
) {
|
||||||
|
let op_a_rows = match &op_a {
|
||||||
|
&Op::NoOp(ref a) => a.nrows(),
|
||||||
|
&Op::Transpose(ref a) => a.ncols()
|
||||||
|
};
|
||||||
|
let op_a_cols = match &op_a {
|
||||||
|
&Op::NoOp(ref a) => a.ncols(),
|
||||||
|
&Op::Transpose(ref a) => a.nrows()
|
||||||
|
};
|
||||||
|
|
||||||
|
let dims_are_compatible = c.nrows() == op_a_rows && c.ncols() == op_a_cols;
|
||||||
|
|
||||||
|
// If the dimensions randomly happen to be compatible, then of course we need to
|
||||||
|
// skip the test, so we assume that they are not.
|
||||||
|
prop_assume!(!dims_are_compatible);
|
||||||
|
|
||||||
|
let result = catch_unwind(|| {
|
||||||
|
let mut spmm_result = c.clone();
|
||||||
|
spadd_csc_prealloc(beta, &mut spmm_result, alpha, op_a.as_ref()).unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
prop_assert!(result.is_err(),
|
||||||
|
"The SPMM kernel executed successfully despite mismatch dimensions");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn csc_add_csc(
|
||||||
|
// a and b have the same dimensions
|
||||||
|
(a, b)
|
||||||
|
in csc_strategy()
|
||||||
|
.prop_flat_map(|a| {
|
||||||
|
let b = csc(PROPTEST_I32_VALUE_STRATEGY, Just(a.nrows()), Just(a.ncols()), PROPTEST_MAX_NNZ);
|
||||||
|
(Just(a), b)
|
||||||
|
}))
|
||||||
|
{
|
||||||
|
// We use the dense result as the ground truth for the arithmetic result
|
||||||
|
let c_dense = DMatrix::from(&a) + DMatrix::from(&b);
|
||||||
|
// However, it's not enough only to cover the dense result, we also need to verify the
|
||||||
|
// sparsity pattern. We can determine the exact sparsity pattern by using
|
||||||
|
// dense arithmetic with positive integer values and extracting positive entries.
|
||||||
|
let c_dense_pattern = dense_csc_pattern(a.pattern()) + dense_csc_pattern(b.pattern());
|
||||||
|
let c_pattern = CscMatrix::from(&c_dense_pattern).pattern().clone();
|
||||||
|
|
||||||
|
// Check each combination of owned matrices and references
|
||||||
|
let c_owned_owned = a.clone() + b.clone();
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_owned_owned), &c_dense);
|
||||||
|
prop_assert_eq!(c_owned_owned.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_owned_ref = a.clone() + &b;
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_owned_ref), &c_dense);
|
||||||
|
prop_assert_eq!(c_owned_ref.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_ref_owned = &a + b.clone();
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_ref_owned), &c_dense);
|
||||||
|
prop_assert_eq!(c_ref_owned.pattern(), &c_pattern);
|
||||||
|
|
||||||
|
let c_ref_ref = &a + &b;
|
||||||
|
prop_assert_eq!(&DMatrix::from(&c_ref_ref), &c_dense);
|
||||||
|
prop_assert_eq!(c_ref_ref.pattern(), &c_pattern);
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue