Refactor ops to use new Op type instead of separate Transpose flag
This commit is contained in:
parent
c6a8fcdee2
commit
fe8592fde1
|
@ -5,7 +5,7 @@ use crate::ops::serial::{spadd_csr, spadd_pattern, spmm_pattern, spmm_csr};
|
||||||
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
use nalgebra::{ClosedAdd, ClosedMul, Scalar};
|
||||||
use num_traits::{Zero, One};
|
use num_traits::{Zero, One};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use crate::ops::Transpose;
|
use crate::ops::{Op};
|
||||||
|
|
||||||
impl<'a, T> Add<&'a CsrMatrix<T>> for &'a CsrMatrix<T>
|
impl<'a, T> Add<&'a CsrMatrix<T>> for &'a CsrMatrix<T>
|
||||||
where
|
where
|
||||||
|
@ -21,8 +21,8 @@ where
|
||||||
// We are giving data that is valid by definition, so it is safe to unwrap below
|
// We are giving data that is valid by definition, so it is safe to unwrap below
|
||||||
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
let mut result = CsrMatrix::try_from_pattern_and_values(Arc::new(pattern), values)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
spadd_csr(&mut result, T::zero(), T::one(), Transpose(false), &self).unwrap();
|
spadd_csr(&mut result, T::zero(), T::one(), Op::NoOp(&self)).unwrap();
|
||||||
spadd_csr(&mut result, T::one(), T::one(), Transpose(false), &rhs).unwrap();
|
spadd_csr(&mut result, T::one(), T::one(), Op::NoOp(&rhs)).unwrap();
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -35,7 +35,7 @@ where
|
||||||
|
|
||||||
fn add(mut self, rhs: &'a CsrMatrix<T>) -> Self::Output {
|
fn add(mut self, rhs: &'a CsrMatrix<T>) -> Self::Output {
|
||||||
if Arc::ptr_eq(self.pattern(), rhs.pattern()) {
|
if Arc::ptr_eq(self.pattern(), rhs.pattern()) {
|
||||||
spadd_csr(&mut self, T::one(), T::one(), Transpose(false), &rhs).unwrap();
|
spadd_csr(&mut self, T::one(), T::one(), Op::NoOp(rhs)).unwrap();
|
||||||
self
|
self
|
||||||
} else {
|
} else {
|
||||||
&self + rhs
|
&self + rhs
|
||||||
|
@ -93,10 +93,8 @@ impl_matrix_mul!(<'a>(a: &'a CsrMatrix<T>, b: &'a CsrMatrix<T>) -> CsrMatrix<T>
|
||||||
spmm_csr(&mut result,
|
spmm_csr(&mut result,
|
||||||
T::zero(),
|
T::zero(),
|
||||||
T::one(),
|
T::one(),
|
||||||
Transpose(false),
|
Op::NoOp(a),
|
||||||
a,
|
Op::NoOp(b))
|
||||||
Transpose(false),
|
|
||||||
b)
|
|
||||||
.expect("Internal error: spmm failed (please debug).");
|
.expect("Internal error: spmm failed (please debug).");
|
||||||
result
|
result
|
||||||
});
|
});
|
||||||
|
|
|
@ -4,14 +4,54 @@ mod impl_std_ops;
|
||||||
pub mod serial;
|
pub mod serial;
|
||||||
|
|
||||||
/// TODO
|
/// TODO
|
||||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct Transpose(pub bool);
|
pub enum Op<T> {
|
||||||
|
|
||||||
impl Transpose {
|
|
||||||
/// TODO
|
/// TODO
|
||||||
pub fn to_bool(&self) -> bool {
|
NoOp(T),
|
||||||
self.0
|
/// TODO
|
||||||
|
Transpose(T),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Op<T> {
|
||||||
|
/// TODO
|
||||||
|
pub fn inner_ref(&self) -> &T {
|
||||||
|
match self {
|
||||||
|
Op::NoOp(obj) => &obj,
|
||||||
|
Op::Transpose(obj) => &obj
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TODO
|
||||||
|
pub fn as_ref(&self) -> Op<&T> {
|
||||||
|
match self {
|
||||||
|
Op::NoOp(obj) => Op::NoOp(&obj),
|
||||||
|
Op::Transpose(obj) => Op::Transpose(&obj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TODO
|
||||||
|
pub fn convert<U>(self) -> Op<U>
|
||||||
|
where T: Into<U>
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
Op::NoOp(obj) => Op::NoOp(obj.into()),
|
||||||
|
Op::Transpose(obj) => Op::Transpose(obj.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TODO
|
||||||
|
/// TODO: Rewrite the other functions by leveraging this one
|
||||||
|
pub fn map_same_op<U, F: FnOnce(T) -> U>(self, f: F) -> Op<U> {
|
||||||
|
match self {
|
||||||
|
Op::NoOp(obj) => Op::NoOp(f(obj)),
|
||||||
|
Op::Transpose(obj) => Op::Transpose(f(obj))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T> From<T> for Op<T> {
|
||||||
|
fn from(obj: T) -> Self {
|
||||||
|
Self::NoOp(obj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::csr::CsrMatrix;
|
use crate::csr::CsrMatrix;
|
||||||
use crate::ops::{Transpose};
|
use crate::ops::{Op};
|
||||||
use crate::SparseEntryMut;
|
use crate::SparseEntryMut;
|
||||||
use crate::ops::serial::{OperationError, OperationErrorType};
|
use crate::ops::serial::{OperationError, OperationErrorType};
|
||||||
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
|
use nalgebra::{Scalar, DMatrixSlice, ClosedAdd, ClosedMul, DMatrixSliceMut};
|
||||||
|
@ -7,65 +7,71 @@ use num_traits::{Zero, One};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
|
|
||||||
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * trans(A) * trans(B)`.
|
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
|
||||||
pub fn spmm_csr_dense<'a, T>(c: impl Into<DMatrixSliceMut<'a, T>>,
|
pub fn spmm_csr_dense<'a, T>(c: impl Into<DMatrixSliceMut<'a, T>>,
|
||||||
beta: T,
|
beta: T,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
trans_a: Transpose,
|
a: Op<&CsrMatrix<T>>,
|
||||||
a: &CsrMatrix<T>,
|
b: Op<impl Into<DMatrixSlice<'a, T>>>)
|
||||||
trans_b: Transpose,
|
|
||||||
b: impl Into<DMatrixSlice<'a, T>>)
|
|
||||||
where
|
where
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
{
|
{
|
||||||
spmm_csr_dense_(c.into(), beta, alpha, trans_a, a, trans_b, b.into())
|
let b = b.convert();
|
||||||
|
spmm_csr_dense_(c.into(), beta, alpha, a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn spmm_csr_dense_<T>(mut c: DMatrixSliceMut<T>,
|
fn spmm_csr_dense_<T>(mut c: DMatrixSliceMut<T>,
|
||||||
beta: T,
|
beta: T,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
trans_a: Transpose,
|
a: Op<&CsrMatrix<T>>,
|
||||||
a: &CsrMatrix<T>,
|
b: Op<DMatrixSlice<T>>)
|
||||||
trans_b: Transpose,
|
|
||||||
b: DMatrixSlice<T>)
|
|
||||||
where
|
where
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
{
|
{
|
||||||
assert_compatible_spmm_dims!(c, a, b, trans_a, trans_b);
|
assert_compatible_spmm_dims!(c, a, b);
|
||||||
|
|
||||||
if trans_a.to_bool() {
|
match a {
|
||||||
// In this case, we have to pre-multiply C by beta
|
Op::Transpose(ref a) => {
|
||||||
c *= beta;
|
// In this case, we have to pre-multiply C by beta
|
||||||
|
c *= beta;
|
||||||
|
|
||||||
for k in 0..a.nrows() {
|
for k in 0..a.nrows() {
|
||||||
let a_row_k = a.row(k);
|
let a_row_k = a.row(k);
|
||||||
for (&i, a_ki) in a_row_k.col_indices().iter().zip(a_row_k.values()) {
|
for (&i, a_ki) in a_row_k.col_indices().iter().zip(a_row_k.values()) {
|
||||||
let gamma_ki = alpha.inlined_clone() * a_ki.inlined_clone();
|
let gamma_ki = alpha.inlined_clone() * a_ki.inlined_clone();
|
||||||
let mut c_row_i = c.row_mut(i);
|
let mut c_row_i = c.row_mut(i);
|
||||||
if trans_b.to_bool() {
|
match b {
|
||||||
let b_col_k = b.column(k);
|
Op::NoOp(ref b) => {
|
||||||
for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) {
|
let b_row_k = b.row(k);
|
||||||
*c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone();
|
for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) {
|
||||||
}
|
*c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone();
|
||||||
} else {
|
}
|
||||||
let b_row_k = b.row(k);
|
},
|
||||||
for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) {
|
Op::Transpose(ref b) => {
|
||||||
*c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone();
|
let b_col_k = b.column(k);
|
||||||
|
for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) {
|
||||||
|
*c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone();
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
} else {
|
Op::NoOp(ref a) => {
|
||||||
for j in 0..c.ncols() {
|
for j in 0..c.ncols() {
|
||||||
let mut c_col_j = c.column_mut(j);
|
let mut c_col_j = c.column_mut(j);
|
||||||
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.row_iter()) {
|
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.row_iter()) {
|
||||||
let mut dot_ij = T::zero();
|
let mut dot_ij = T::zero();
|
||||||
for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
|
for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
|
||||||
let b_contrib =
|
let b_contrib =
|
||||||
if trans_b.to_bool() { b.index((j, k)) } else { b.index((k, j)) };
|
match b {
|
||||||
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
|
Op::NoOp(ref b) => b.index((k, j)),
|
||||||
|
Op::Transpose(ref b) => b.index((j, k))
|
||||||
|
};
|
||||||
|
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
|
||||||
|
}
|
||||||
|
*c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij;
|
||||||
}
|
}
|
||||||
*c_ij = beta.inlined_clone() * c_ij.inlined_clone() + alpha.inlined_clone() * dot_ij;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -77,32 +83,31 @@ fn spadd_csr_unexpected_entry() -> OperationError {
|
||||||
String::from("Found entry in `a` that is not present in `c`."))
|
String::from("Found entry in `a` that is not present in `c`."))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sparse matrix addition `C <- beta * C + alpha * trans(A)`.
|
/// Sparse matrix addition `C <- beta * C + alpha * op(A)`.
|
||||||
///
|
///
|
||||||
/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is
|
/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is
|
||||||
/// returned.
|
/// returned.
|
||||||
pub fn spadd_csr<T>(c: &mut CsrMatrix<T>,
|
pub fn spadd_csr<T>(c: &mut CsrMatrix<T>,
|
||||||
beta: T,
|
beta: T,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
trans_a: Transpose,
|
a: Op<&CsrMatrix<T>>)
|
||||||
a: &CsrMatrix<T>)
|
|
||||||
-> Result<(), OperationError>
|
-> Result<(), OperationError>
|
||||||
where
|
where
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
{
|
{
|
||||||
assert_compatible_spadd_dims!(c, a, trans_a);
|
assert_compatible_spadd_dims!(c, a);
|
||||||
|
|
||||||
// TODO: Change CsrMatrix::pattern() to return `&Arc` instead of `Arc`
|
// TODO: Change CsrMatrix::pattern() to return `&Arc` instead of `Arc`
|
||||||
if Arc::ptr_eq(&c.pattern(), &a.pattern()) {
|
if Arc::ptr_eq(&c.pattern(), &a.inner_ref().pattern()) {
|
||||||
// Special fast path: The two matrices have *exactly* the same sparsity pattern,
|
// Special fast path: The two matrices have *exactly* the same sparsity pattern,
|
||||||
// so we only need to sum the value arrays
|
// so we only need to sum the value arrays
|
||||||
for (c_ij, a_ij) in c.values_mut().iter_mut().zip(a.values()) {
|
for (c_ij, a_ij) in c.values_mut().iter_mut().zip(a.inner_ref().values()) {
|
||||||
let (alpha, beta) = (alpha.inlined_clone(), beta.inlined_clone());
|
let (alpha, beta) = (alpha.inlined_clone(), beta.inlined_clone());
|
||||||
*c_ij = beta * c_ij.inlined_clone() + alpha * a_ij.inlined_clone();
|
*c_ij = beta * c_ij.inlined_clone() + alpha * a_ij.inlined_clone();
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
if trans_a.to_bool()
|
if let Op::Transpose(a) = a
|
||||||
{
|
{
|
||||||
if beta != T::one() {
|
if beta != T::one() {
|
||||||
for c_ij in c.values_mut() {
|
for c_ij in c.values_mut() {
|
||||||
|
@ -120,7 +125,7 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if let Op::NoOp(a) = a {
|
||||||
for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) {
|
for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) {
|
||||||
if beta != T::one() {
|
if beta != T::one() {
|
||||||
for c_ij in c_row_i.values_mut() {
|
for c_ij in c_row_i.values_mut() {
|
||||||
|
@ -160,56 +165,61 @@ pub fn spmm_csr<'a, T>(
|
||||||
c: &mut CsrMatrix<T>,
|
c: &mut CsrMatrix<T>,
|
||||||
beta: T,
|
beta: T,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
trans_a: Transpose,
|
a: Op<&CsrMatrix<T>>,
|
||||||
a: &CsrMatrix<T>,
|
b: Op<&CsrMatrix<T>>)
|
||||||
trans_b: Transpose,
|
|
||||||
b: &CsrMatrix<T>)
|
|
||||||
-> Result<(), OperationError>
|
-> Result<(), OperationError>
|
||||||
where
|
where
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
{
|
{
|
||||||
assert_compatible_spmm_dims!(c, a, b, trans_a, trans_b);
|
assert_compatible_spmm_dims!(c, a, b);
|
||||||
|
|
||||||
if !trans_a.to_bool() && !trans_b.to_bool() {
|
use Op::{NoOp, Transpose};
|
||||||
for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) {
|
|
||||||
for c_ij in c_row_i.values_mut() {
|
|
||||||
*c_ij = beta.inlined_clone() * c_ij.inlined_clone();
|
|
||||||
}
|
|
||||||
|
|
||||||
for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
|
match (&a, &b) {
|
||||||
let b_row_k = b.row(k);
|
(NoOp(ref a), NoOp(ref b)) => {
|
||||||
let (mut c_row_i_cols, mut c_row_i_values) = c_row_i.cols_and_values_mut();
|
for (mut c_row_i, a_row_i) in c.row_iter_mut().zip(a.row_iter()) {
|
||||||
let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone();
|
for c_ij in c_row_i.values_mut() {
|
||||||
for (j, b_kj) in b_row_k.col_indices().iter().zip(b_row_k.values()) {
|
*c_ij = beta.inlined_clone() * c_ij.inlined_clone();
|
||||||
// Determine the location in C to append the value
|
}
|
||||||
let (c_local_idx, _) = c_row_i_cols.iter()
|
|
||||||
.enumerate()
|
|
||||||
.find(|(_, c_col)| *c_col == j)
|
|
||||||
.ok_or_else(spmm_csr_unexpected_entry)?;
|
|
||||||
|
|
||||||
c_row_i_values[c_local_idx] += alpha_aik.inlined_clone() * b_kj.inlined_clone();
|
for (&k, a_ik) in a_row_i.col_indices().iter().zip(a_row_i.values()) {
|
||||||
c_row_i_cols = &c_row_i_cols[c_local_idx ..];
|
let b_row_k = b.row(k);
|
||||||
c_row_i_values = &mut c_row_i_values[c_local_idx ..];
|
let (mut c_row_i_cols, mut c_row_i_values) = c_row_i.cols_and_values_mut();
|
||||||
|
let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone();
|
||||||
|
for (j, b_kj) in b_row_k.col_indices().iter().zip(b_row_k.values()) {
|
||||||
|
// Determine the location in C to append the value
|
||||||
|
let (c_local_idx, _) = c_row_i_cols.iter()
|
||||||
|
.enumerate()
|
||||||
|
.find(|(_, c_col)| *c_col == j)
|
||||||
|
.ok_or_else(spmm_csr_unexpected_entry)?;
|
||||||
|
|
||||||
|
c_row_i_values[c_local_idx] += alpha_aik.inlined_clone() * b_kj.inlined_clone();
|
||||||
|
c_row_i_cols = &c_row_i_cols[c_local_idx ..];
|
||||||
|
c_row_i_values = &mut c_row_i_values[c_local_idx ..];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
Ok(())
|
||||||
Ok(())
|
},
|
||||||
} else {
|
_ => {
|
||||||
// Currently we handle transposition by explicitly precomputing transposed matrices
|
// Currently we handle transposition by explicitly precomputing transposed matrices
|
||||||
// and calling the operation again without transposition
|
// and calling the operation again without transposition
|
||||||
// TODO: At least use workspaces to allow control of allocations. Maybe
|
// TODO: At least use workspaces to allow control of allocations. Maybe
|
||||||
// consider implementing certain patterns (like A^T * B) explicitly
|
// consider implementing certain patterns (like A^T * B) explicitly
|
||||||
let (a, b) = {
|
let a_ref: &CsrMatrix<T> = a.inner_ref();
|
||||||
use Cow::*;
|
let b_ref: &CsrMatrix<T> = b.inner_ref();
|
||||||
match (trans_a, trans_b) {
|
let (a, b) = {
|
||||||
(Transpose(false), Transpose(false)) => unreachable!(),
|
use Cow::*;
|
||||||
(Transpose(true), Transpose(false)) => (Owned(a.transpose()), Borrowed(b)),
|
match (&a, &b) {
|
||||||
(Transpose(false), Transpose(true)) => (Borrowed(a), Owned(b.transpose())),
|
(NoOp(_), NoOp(_)) => unreachable!(),
|
||||||
(Transpose(true), Transpose(true)) => (Owned(a.transpose()), Owned(b.transpose()))
|
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
|
||||||
}
|
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
|
||||||
};
|
(Transpose(ref a), Transpose(ref b)) => (Owned(a.transpose()), Owned(b.transpose()))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
spmm_csr(c, beta, alpha, Transpose(false), a.as_ref(), Transpose(false), b.as_ref())
|
spmm_csr(c, beta, alpha, NoOp(a.as_ref()), NoOp(b.as_ref()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,46 +2,47 @@
|
||||||
|
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
macro_rules! assert_compatible_spmm_dims {
|
macro_rules! assert_compatible_spmm_dims {
|
||||||
($c:expr, $a:expr, $b:expr, $trans_a:expr, $trans_b:expr) => {
|
($c:expr, $a:expr, $b:expr) => {
|
||||||
use crate::ops::Transpose;
|
{
|
||||||
match ($trans_a, $trans_b) {
|
use crate::ops::Op::{NoOp, Transpose};
|
||||||
(Transpose(false), Transpose(false)) => {
|
match (&$a, &$b) {
|
||||||
assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()");
|
(NoOp(ref a), NoOp(ref b)) => {
|
||||||
assert_eq!($c.ncols(), $b.ncols(), "C.ncols() != B.ncols()");
|
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
||||||
assert_eq!($a.ncols(), $b.nrows(), "A.ncols() != B.nrows()");
|
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
|
||||||
},
|
assert_eq!(a.ncols(), b.nrows(), "A.ncols() != B.nrows()");
|
||||||
(Transpose(true), Transpose(false)) => {
|
},
|
||||||
assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()");
|
(Transpose(ref a), NoOp(ref b)) => {
|
||||||
assert_eq!($c.ncols(), $b.ncols(), "C.ncols() != B.ncols()");
|
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
||||||
assert_eq!($a.nrows(), $b.nrows(), "A.nrows() != B.nrows()");
|
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
|
||||||
},
|
assert_eq!(a.nrows(), b.nrows(), "A.nrows() != B.nrows()");
|
||||||
(Transpose(false), Transpose(true)) => {
|
},
|
||||||
assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()");
|
(NoOp(ref a), Transpose(ref b)) => {
|
||||||
assert_eq!($c.ncols(), $b.nrows(), "C.ncols() != B.nrows()");
|
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
||||||
assert_eq!($a.ncols(), $b.ncols(), "A.ncols() != B.ncols()");
|
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
|
||||||
},
|
assert_eq!(a.ncols(), b.ncols(), "A.ncols() != B.ncols()");
|
||||||
(Transpose(true), Transpose(true)) => {
|
},
|
||||||
assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()");
|
(Transpose(ref a), Transpose(ref b)) => {
|
||||||
assert_eq!($c.ncols(), $b.nrows(), "C.ncols() != B.nrows()");
|
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
||||||
assert_eq!($a.nrows(), $b.ncols(), "A.nrows() != B.ncols()");
|
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
|
||||||
|
assert_eq!(a.nrows(), b.ncols(), "A.nrows() != B.ncols()");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
macro_rules! assert_compatible_spadd_dims {
|
macro_rules! assert_compatible_spadd_dims {
|
||||||
($c:expr, $a:expr, $trans_a:expr) => {
|
($c:expr, $a:expr) => {
|
||||||
use crate::ops::Transpose;
|
use crate::ops::Op;
|
||||||
match $trans_a {
|
match $a {
|
||||||
Transpose(false) => {
|
Op::NoOp(a) => {
|
||||||
assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()");
|
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
||||||
assert_eq!($c.ncols(), $a.ncols(), "C.ncols() != A.ncols()");
|
assert_eq!($c.ncols(), a.ncols(), "C.ncols() != A.ncols()");
|
||||||
},
|
},
|
||||||
Transpose(true) => {
|
Op::Transpose(a) => {
|
||||||
assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()");
|
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
||||||
assert_eq!($c.ncols(), $a.nrows(), "C.ncols() != A.nrows()");
|
assert_eq!($c.ncols(), a.nrows(), "C.ncols() != A.nrows()");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
|
use crate::common::{csr_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ,
|
||||||
PROPTEST_I32_VALUE_STRATEGY};
|
PROPTEST_I32_VALUE_STRATEGY};
|
||||||
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_pattern, spmm_pattern, spadd_csr, spmm_csr};
|
use nalgebra_sparse::ops::serial::{spmm_csr_dense, spadd_pattern, spmm_pattern, spadd_csr, spmm_csr};
|
||||||
use nalgebra_sparse::ops::{Transpose};
|
use nalgebra_sparse::ops::{Op};
|
||||||
use nalgebra_sparse::csr::CsrMatrix;
|
use nalgebra_sparse::csr::CsrMatrix;
|
||||||
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
|
use nalgebra_sparse::proptest::{csr, sparsity_pattern};
|
||||||
use nalgebra_sparse::pattern::SparsityPattern;
|
use nalgebra_sparse::pattern::SparsityPattern;
|
||||||
|
@ -28,10 +28,8 @@ struct SpmmCsrDenseArgs<T: Scalar> {
|
||||||
c: DMatrix<T>,
|
c: DMatrix<T>,
|
||||||
beta: T,
|
beta: T,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
trans_a: Transpose,
|
a: Op<CsrMatrix<T>>,
|
||||||
a: CsrMatrix<T>,
|
b: Op<DMatrix<T>>,
|
||||||
trans_b: Transpose,
|
|
||||||
b: DMatrix<T>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns matrices C, A and B with compatible dimensions such that it can be used
|
/// Returns matrices C, A and B with compatible dimensions such that it can be used
|
||||||
|
@ -48,10 +46,10 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
|
||||||
(c_matrix_strategy, common_dim, trans_strategy.clone(), trans_strategy.clone())
|
(c_matrix_strategy, common_dim, trans_strategy.clone(), trans_strategy.clone())
|
||||||
.prop_flat_map(move |(c, common_dim, trans_a, trans_b)| {
|
.prop_flat_map(move |(c, common_dim, trans_a, trans_b)| {
|
||||||
let a_shape =
|
let a_shape =
|
||||||
if trans_a.to_bool() { (common_dim, c.nrows()) }
|
if trans_a { (common_dim, c.nrows()) }
|
||||||
else { (c.nrows(), common_dim) };
|
else { (c.nrows(), common_dim) };
|
||||||
let b_shape =
|
let b_shape =
|
||||||
if trans_b.to_bool() { (c.ncols(), common_dim) }
|
if trans_b { (c.ncols(), common_dim) }
|
||||||
else { (common_dim, c.ncols()) };
|
else { (common_dim, c.ncols()) };
|
||||||
let a = csr(value_strategy.clone(), Just(a_shape.0), Just(a_shape.1), max_nnz);
|
let a = csr(value_strategy.clone(), Just(a_shape.0), Just(a_shape.1), max_nnz);
|
||||||
let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1);
|
let b = matrix(value_strategy.clone(), b_shape.0, b_shape.1);
|
||||||
|
@ -66,10 +64,8 @@ fn spmm_csr_dense_args_strategy() -> impl Strategy<Value=SpmmCsrDenseArgs<i32>>
|
||||||
c,
|
c,
|
||||||
beta,
|
beta,
|
||||||
alpha,
|
alpha,
|
||||||
trans_a,
|
a: if trans_a { Op::Transpose(a) } else { Op::NoOp(a) },
|
||||||
a,
|
b: if trans_b { Op::Transpose(b) } else { Op::NoOp(b) },
|
||||||
trans_b,
|
|
||||||
b,
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -79,14 +75,13 @@ struct SpaddCsrArgs<T> {
|
||||||
c: CsrMatrix<T>,
|
c: CsrMatrix<T>,
|
||||||
beta: T,
|
beta: T,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
trans_a: Transpose,
|
a: Op<CsrMatrix<T>>,
|
||||||
a: CsrMatrix<T>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn spadd_csr_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
fn spadd_csr_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
||||||
let value_strategy = PROPTEST_I32_VALUE_STRATEGY;
|
let value_strategy = PROPTEST_I32_VALUE_STRATEGY;
|
||||||
|
|
||||||
spadd_build_pattern_strategy()
|
spadd_pattern_strategy()
|
||||||
.prop_flat_map(move |(a_pattern, b_pattern)| {
|
.prop_flat_map(move |(a_pattern, b_pattern)| {
|
||||||
let c_pattern = spadd_pattern(&a_pattern, &b_pattern);
|
let c_pattern = spadd_pattern(&a_pattern, &b_pattern);
|
||||||
|
|
||||||
|
@ -99,8 +94,8 @@ fn spadd_csr_args_strategy() -> impl Strategy<Value=SpaddCsrArgs<i32>> {
|
||||||
let c = CsrMatrix::try_from_pattern_and_values(Arc::new(c_pattern), c_values).unwrap();
|
let c = CsrMatrix::try_from_pattern_and_values(Arc::new(c_pattern), c_values).unwrap();
|
||||||
let a = CsrMatrix::try_from_pattern_and_values(Arc::new(a_pattern), a_values).unwrap();
|
let a = CsrMatrix::try_from_pattern_and_values(Arc::new(a_pattern), a_values).unwrap();
|
||||||
|
|
||||||
let a = if trans_a.to_bool() { a.transpose() } else { a };
|
let a = if trans_a { Op::Transpose(a.transpose()) } else { Op::NoOp(a) };
|
||||||
SpaddCsrArgs { c, beta, alpha, trans_a, a }
|
SpaddCsrArgs { c, beta, alpha, a }
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,8 +103,20 @@ fn dense_strategy() -> impl Strategy<Value=DMatrix<i32>> {
|
||||||
matrix(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM)
|
matrix(PROPTEST_I32_VALUE_STRATEGY, PROPTEST_MATRIX_DIM, PROPTEST_MATRIX_DIM)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn trans_strategy() -> impl Strategy<Value=Transpose> + Clone {
|
fn trans_strategy() -> impl Strategy<Value=bool> + Clone {
|
||||||
proptest::bool::ANY.prop_map(Transpose)
|
proptest::bool::ANY
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wraps the values of the given strategy in `Op`, producing both transposed and non-transposed
|
||||||
|
/// values.
|
||||||
|
fn op_strategy<S: Strategy>(strategy: S) -> impl Strategy<Value=Op<S::Value>> {
|
||||||
|
let is_transposed = proptest::bool::ANY;
|
||||||
|
(strategy, is_transposed)
|
||||||
|
.prop_map(|(obj, is_trans)| if is_trans {
|
||||||
|
Op::Transpose(obj)
|
||||||
|
} else {
|
||||||
|
Op::NoOp(obj)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn pattern_strategy() -> impl Strategy<Value=SparsityPattern> {
|
fn pattern_strategy() -> impl Strategy<Value=SparsityPattern> {
|
||||||
|
@ -117,7 +124,7 @@ fn pattern_strategy() -> impl Strategy<Value=SparsityPattern> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Constructs pairs (a, b) where a and b have the same dimensions
|
/// Constructs pairs (a, b) where a and b have the same dimensions
|
||||||
fn spadd_build_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
|
fn spadd_pattern_strategy() -> impl Strategy<Value=(SparsityPattern, SparsityPattern)> {
|
||||||
pattern_strategy()
|
pattern_strategy()
|
||||||
.prop_flat_map(|a| {
|
.prop_flat_map(|a| {
|
||||||
let b = sparsity_pattern(Just(a.major_dim()), Just(a.minor_dim()), PROPTEST_MAX_NNZ);
|
let b = sparsity_pattern(Just(a.major_dim()), Just(a.minor_dim()), PROPTEST_MAX_NNZ);
|
||||||
|
@ -139,10 +146,8 @@ struct SpmmCsrArgs<T> {
|
||||||
c: CsrMatrix<T>,
|
c: CsrMatrix<T>,
|
||||||
beta: T,
|
beta: T,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
trans_a: Transpose,
|
a: Op<CsrMatrix<T>>,
|
||||||
a: CsrMatrix<T>,
|
b: Op<CsrMatrix<T>>,
|
||||||
trans_b: Transpose,
|
|
||||||
b: CsrMatrix<T>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn spmm_csr_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
fn spmm_csr_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
||||||
|
@ -170,10 +175,8 @@ fn spmm_csr_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
||||||
c,
|
c,
|
||||||
beta,
|
beta,
|
||||||
alpha,
|
alpha,
|
||||||
trans_a,
|
a: if trans_a { Op::Transpose(a.transpose()) } else { Op::NoOp(a) },
|
||||||
a: if trans_a.to_bool() { a.transpose() } else { a },
|
b: if trans_b { Op::Transpose(b.transpose()) } else { Op::NoOp(b) }
|
||||||
trans_b,
|
|
||||||
b: if trans_b.to_bool() { b.transpose() } else { b }
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -182,52 +185,67 @@ fn spmm_csr_args_strategy() -> impl Strategy<Value=SpmmCsrArgs<i32>> {
|
||||||
fn dense_gemm<'a>(c: impl Into<DMatrixSliceMut<'a, i32>>,
|
fn dense_gemm<'a>(c: impl Into<DMatrixSliceMut<'a, i32>>,
|
||||||
beta: i32,
|
beta: i32,
|
||||||
alpha: i32,
|
alpha: i32,
|
||||||
trans_a: Transpose,
|
a: Op<impl Into<DMatrixSlice<'a, i32>>>,
|
||||||
a: impl Into<DMatrixSlice<'a, i32>>,
|
b: Op<impl Into<DMatrixSlice<'a, i32>>>)
|
||||||
trans_b: Transpose,
|
|
||||||
b: impl Into<DMatrixSlice<'a, i32>>)
|
|
||||||
{
|
{
|
||||||
let mut c = c.into();
|
let mut c = c.into();
|
||||||
let a = a.into();
|
let a = a.convert();
|
||||||
let b = b.into();
|
let b = b.convert();
|
||||||
|
|
||||||
match (trans_a, trans_b) {
|
use Op::{NoOp, Transpose};
|
||||||
(Transpose(false), Transpose(false)) => c.gemm(alpha, &a, &b, beta),
|
match (a, b) {
|
||||||
(Transpose(true), Transpose(false)) => c.gemm(alpha, &a.transpose(), &b, beta),
|
(NoOp(a), NoOp(b)) => c.gemm(alpha, &a, &b, beta),
|
||||||
(Transpose(false), Transpose(true)) => c.gemm(alpha, &a, &b.transpose(), beta),
|
(Transpose(a), NoOp(b)) => c.gemm(alpha, &a.transpose(), &b, beta),
|
||||||
(Transpose(true), Transpose(true)) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta)
|
(NoOp(a), Transpose(b)) => c.gemm(alpha, &a, &b.transpose(), beta),
|
||||||
};
|
(Transpose(a), Transpose(b)) => c.gemm(alpha, &a.transpose(), &b.transpose(), beta)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
proptest! {
|
proptest! {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spmm_csr_dense_agrees_with_dense_result(
|
fn spmm_csr_dense_agrees_with_dense_result(
|
||||||
SpmmCsrDenseArgs { c, beta, alpha, trans_a, a, trans_b, b }
|
SpmmCsrDenseArgs { c, beta, alpha, a, b }
|
||||||
in spmm_csr_dense_args_strategy()
|
in spmm_csr_dense_args_strategy()
|
||||||
) {
|
) {
|
||||||
let mut spmm_result = c.clone();
|
let mut spmm_result = c.clone();
|
||||||
spmm_csr_dense(&mut spmm_result, beta, alpha, trans_a, &a, trans_b, &b);
|
spmm_csr_dense(&mut spmm_result, beta, alpha, a.as_ref(), b.as_ref());
|
||||||
|
|
||||||
let mut gemm_result = c.clone();
|
let mut gemm_result = c.clone();
|
||||||
dense_gemm(&mut gemm_result, beta, alpha, trans_a, &DMatrix::from(&a), trans_b, &b);
|
let a_dense = a.map_same_op(|a| DMatrix::from(&a));
|
||||||
|
dense_gemm(&mut gemm_result, beta, alpha, a_dense.as_ref(), b.as_ref());
|
||||||
|
|
||||||
prop_assert_eq!(spmm_result, gemm_result);
|
prop_assert_eq!(spmm_result, gemm_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spmm_csr_dense_panics_on_dim_mismatch(
|
fn spmm_csr_dense_panics_on_dim_mismatch(
|
||||||
(alpha, beta, c, a, b, trans_a, trans_b)
|
(alpha, beta, c, a, b)
|
||||||
in (-5 ..= 5, -5 ..= 5, dense_strategy(), csr_strategy(),
|
in (PROPTEST_I32_VALUE_STRATEGY,
|
||||||
dense_strategy(), trans_strategy(), trans_strategy())
|
PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
dense_strategy(),
|
||||||
|
op_strategy(csr_strategy()),
|
||||||
|
op_strategy(dense_strategy()))
|
||||||
) {
|
) {
|
||||||
// We refer to `A * B` as the "product"
|
// We refer to `A * B` as the "product"
|
||||||
let product_rows = if trans_a.to_bool() { a.ncols() } else { a.nrows() };
|
let product_rows = match &a {
|
||||||
let product_cols = if trans_b.to_bool() { b.nrows() } else { b.ncols() };
|
Op::NoOp(ref a) => a.nrows(),
|
||||||
|
Op::Transpose(ref a) => a.ncols(),
|
||||||
|
};
|
||||||
|
let product_cols = match &b {
|
||||||
|
Op::NoOp(ref b) => b.ncols(),
|
||||||
|
Op::Transpose(ref b) => b.nrows(),
|
||||||
|
};
|
||||||
// Determine the common dimension in the product
|
// Determine the common dimension in the product
|
||||||
// from the perspective of a and b, respectively
|
// from the perspective of a and b, respectively
|
||||||
let product_a_common = if trans_a.to_bool() { a.nrows() } else { a.ncols() };
|
let product_a_common = match &a {
|
||||||
let product_b_common = if trans_b.to_bool() { b.ncols() } else { b.nrows() };
|
Op::NoOp(ref a) => a.ncols(),
|
||||||
|
Op::Transpose(ref a) => a.nrows(),
|
||||||
|
};
|
||||||
|
let product_b_common = match &b {
|
||||||
|
Op::NoOp(ref b) => b.nrows(),
|
||||||
|
Op::Transpose(ref b) => b.ncols()
|
||||||
|
};
|
||||||
|
|
||||||
let dims_are_compatible = product_rows == c.nrows()
|
let dims_are_compatible = product_rows == c.nrows()
|
||||||
&& product_cols == c.ncols()
|
&& product_cols == c.ncols()
|
||||||
|
@ -239,7 +257,7 @@ proptest! {
|
||||||
|
|
||||||
let result = catch_unwind(|| {
|
let result = catch_unwind(|| {
|
||||||
let mut spmm_result = c.clone();
|
let mut spmm_result = c.clone();
|
||||||
spmm_csr_dense(&mut spmm_result, beta, alpha, trans_a, &a, trans_b, &b);
|
spmm_csr_dense(&mut spmm_result, beta, alpha, a.as_ref(), b.as_ref());
|
||||||
});
|
});
|
||||||
|
|
||||||
prop_assert!(result.is_err(),
|
prop_assert!(result.is_err(),
|
||||||
|
@ -247,7 +265,7 @@ proptest! {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spadd_pattern_test((a, b) in spadd_build_pattern_strategy())
|
fn spadd_pattern_test((a, b) in spadd_pattern_strategy())
|
||||||
{
|
{
|
||||||
// (a, b) are dimensionally compatible patterns
|
// (a, b) are dimensionally compatible patterns
|
||||||
let pattern_result = spadd_pattern(&a, &b);
|
let pattern_result = spadd_pattern(&a, &b);
|
||||||
|
@ -269,16 +287,18 @@ proptest! {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spadd_csr_test(SpaddCsrArgs { c, beta, alpha, trans_a, a } in spadd_csr_args_strategy()) {
|
fn spadd_csr_test(SpaddCsrArgs { c, beta, alpha, a } in spadd_csr_args_strategy()) {
|
||||||
// Test that we get the expected result by comparing to an equivalent dense operation
|
// Test that we get the expected result by comparing to an equivalent dense operation
|
||||||
// (here we give in the C matrix, so the sparsity pattern is essentially fixed)
|
// (here we give in the C matrix, so the sparsity pattern is essentially fixed)
|
||||||
|
|
||||||
let mut c_sparse = c.clone();
|
let mut c_sparse = c.clone();
|
||||||
spadd_csr(&mut c_sparse, beta, alpha, trans_a, &a).unwrap();
|
spadd_csr(&mut c_sparse, beta, alpha, a.as_ref()).unwrap();
|
||||||
|
|
||||||
let mut c_dense = DMatrix::from(&c);
|
let mut c_dense = DMatrix::from(&c);
|
||||||
let op_a_dense = DMatrix::from(&a);
|
let op_a_dense = match a {
|
||||||
let op_a_dense = if trans_a.to_bool() { op_a_dense.transpose() } else { op_a_dense };
|
Op::NoOp(a) => DMatrix::from(&a),
|
||||||
|
Op::Transpose(a) => DMatrix::from(&a).transpose(),
|
||||||
|
};
|
||||||
c_dense = beta * c_dense + alpha * &op_a_dense;
|
c_dense = beta * c_dense + alpha * &op_a_dense;
|
||||||
|
|
||||||
prop_assert_eq!(&DMatrix::from(&c_sparse), &c_dense);
|
prop_assert_eq!(&DMatrix::from(&c_sparse), &c_dense);
|
||||||
|
@ -343,19 +363,23 @@ proptest! {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spmm_csr_test(SpmmCsrArgs { c, beta, alpha, trans_a, a, trans_b, b }
|
fn spmm_csr_test(SpmmCsrArgs { c, beta, alpha, a, b }
|
||||||
in spmm_csr_args_strategy()
|
in spmm_csr_args_strategy()
|
||||||
) {
|
) {
|
||||||
// Test that we get the expected result by comparing to an equivalent dense operation
|
// Test that we get the expected result by comparing to an equivalent dense operation
|
||||||
// (here we give in the C matrix, so the sparsity pattern is essentially fixed)
|
// (here we give in the C matrix, so the sparsity pattern is essentially fixed)
|
||||||
let mut c_sparse = c.clone();
|
let mut c_sparse = c.clone();
|
||||||
spmm_csr(&mut c_sparse, beta, alpha, trans_a, &a, trans_b, &b).unwrap();
|
spmm_csr(&mut c_sparse, beta, alpha, a.as_ref(), b.as_ref()).unwrap();
|
||||||
|
|
||||||
let mut c_dense = DMatrix::from(&c);
|
let mut c_dense = DMatrix::from(&c);
|
||||||
let op_a_dense = DMatrix::from(&a);
|
let op_a_dense = match a {
|
||||||
let op_a_dense = if trans_a.to_bool() { op_a_dense.transpose() } else { op_a_dense };
|
Op::NoOp(ref a) => DMatrix::from(a),
|
||||||
let op_b_dense = DMatrix::from(&b);
|
Op::Transpose(ref a) => DMatrix::from(a).transpose(),
|
||||||
let op_b_dense = if trans_b.to_bool() { op_b_dense.transpose() } else { op_b_dense };
|
};
|
||||||
|
let op_b_dense = match b {
|
||||||
|
Op::NoOp(ref b) => DMatrix::from(b),
|
||||||
|
Op::Transpose(ref b) => DMatrix::from(b).transpose(),
|
||||||
|
};
|
||||||
c_dense = beta * c_dense + alpha * &op_a_dense * op_b_dense;
|
c_dense = beta * c_dense + alpha * &op_a_dense * op_b_dense;
|
||||||
|
|
||||||
prop_assert_eq!(&DMatrix::from(&c_sparse), &c_dense);
|
prop_assert_eq!(&DMatrix::from(&c_sparse), &c_dense);
|
||||||
|
@ -363,22 +387,32 @@ proptest! {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spmm_csr_panics_on_dim_mismatch(
|
fn spmm_csr_panics_on_dim_mismatch(
|
||||||
(alpha, beta, c, a, b, trans_a, trans_b)
|
(alpha, beta, c, a, b)
|
||||||
in (PROPTEST_I32_VALUE_STRATEGY,
|
in (PROPTEST_I32_VALUE_STRATEGY,
|
||||||
PROPTEST_I32_VALUE_STRATEGY,
|
PROPTEST_I32_VALUE_STRATEGY,
|
||||||
csr_strategy(),
|
csr_strategy(),
|
||||||
csr_strategy(),
|
op_strategy(csr_strategy()),
|
||||||
csr_strategy(),
|
op_strategy(csr_strategy()))
|
||||||
trans_strategy(),
|
|
||||||
trans_strategy())
|
|
||||||
) {
|
) {
|
||||||
// We refer to `A * B` as the "product"
|
// We refer to `A * B` as the "product"
|
||||||
let product_rows = if trans_a.to_bool() { a.ncols() } else { a.nrows() };
|
let product_rows = match &a {
|
||||||
let product_cols = if trans_b.to_bool() { b.nrows() } else { b.ncols() };
|
Op::NoOp(ref a) => a.nrows(),
|
||||||
|
Op::Transpose(ref a) => a.ncols(),
|
||||||
|
};
|
||||||
|
let product_cols = match &b {
|
||||||
|
Op::NoOp(ref b) => b.ncols(),
|
||||||
|
Op::Transpose(ref b) => b.nrows(),
|
||||||
|
};
|
||||||
// Determine the common dimension in the product
|
// Determine the common dimension in the product
|
||||||
// from the perspective of a and b, respectively
|
// from the perspective of a and b, respectively
|
||||||
let product_a_common = if trans_a.to_bool() { a.nrows() } else { a.ncols() };
|
let product_a_common = match &a {
|
||||||
let product_b_common = if trans_b.to_bool() { b.ncols() } else { b.nrows() };
|
Op::NoOp(ref a) => a.ncols(),
|
||||||
|
Op::Transpose(ref a) => a.nrows(),
|
||||||
|
};
|
||||||
|
let product_b_common = match &b {
|
||||||
|
Op::NoOp(ref b) => b.nrows(),
|
||||||
|
Op::Transpose(ref b) => b.ncols(),
|
||||||
|
};
|
||||||
|
|
||||||
let dims_are_compatible = product_rows == c.nrows()
|
let dims_are_compatible = product_rows == c.nrows()
|
||||||
&& product_cols == c.ncols()
|
&& product_cols == c.ncols()
|
||||||
|
@ -390,7 +424,7 @@ proptest! {
|
||||||
|
|
||||||
let result = catch_unwind(|| {
|
let result = catch_unwind(|| {
|
||||||
let mut spmm_result = c.clone();
|
let mut spmm_result = c.clone();
|
||||||
spmm_csr(&mut spmm_result, beta, alpha, trans_a, &a, trans_b, &b).unwrap();
|
spmm_csr(&mut spmm_result, beta, alpha, a.as_ref(), b.as_ref()).unwrap();
|
||||||
});
|
});
|
||||||
|
|
||||||
prop_assert!(result.is_err(),
|
prop_assert!(result.is_err(),
|
||||||
|
@ -399,15 +433,20 @@ proptest! {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spadd_csr_panics_on_dim_mismatch(
|
fn spadd_csr_panics_on_dim_mismatch(
|
||||||
(alpha, beta, c, a, trans_a)
|
(alpha, beta, c, op_a)
|
||||||
in (PROPTEST_I32_VALUE_STRATEGY,
|
in (PROPTEST_I32_VALUE_STRATEGY,
|
||||||
PROPTEST_I32_VALUE_STRATEGY,
|
PROPTEST_I32_VALUE_STRATEGY,
|
||||||
csr_strategy(),
|
csr_strategy(),
|
||||||
csr_strategy(),
|
op_strategy(csr_strategy()))
|
||||||
trans_strategy())
|
|
||||||
) {
|
) {
|
||||||
let op_a_rows = if trans_a.to_bool() { a.ncols() } else { a.nrows() };
|
let op_a_rows = match &op_a {
|
||||||
let op_a_cols = if trans_a.to_bool() { a.nrows() } else { a.ncols() };
|
&Op::NoOp(ref a) => a.nrows(),
|
||||||
|
&Op::Transpose(ref a) => a.ncols()
|
||||||
|
};
|
||||||
|
let op_a_cols = match &op_a {
|
||||||
|
&Op::NoOp(ref a) => a.ncols(),
|
||||||
|
&Op::Transpose(ref a) => a.nrows()
|
||||||
|
};
|
||||||
|
|
||||||
let dims_are_compatible = c.nrows() == op_a_rows && c.ncols() == op_a_cols;
|
let dims_are_compatible = c.nrows() == op_a_rows && c.ncols() == op_a_cols;
|
||||||
|
|
||||||
|
@ -417,7 +456,7 @@ proptest! {
|
||||||
|
|
||||||
let result = catch_unwind(|| {
|
let result = catch_unwind(|| {
|
||||||
let mut spmm_result = c.clone();
|
let mut spmm_result = c.clone();
|
||||||
spadd_csr(&mut spmm_result, beta, alpha, trans_a, &a).unwrap();
|
spadd_csr(&mut spmm_result, beta, alpha, op_a.as_ref()).unwrap();
|
||||||
});
|
});
|
||||||
|
|
||||||
prop_assert!(result.is_err(),
|
prop_assert!(result.is_err(),
|
||||||
|
|
Loading…
Reference in New Issue