Improve dimension assertions for spadd_csr

This commit is contained in:
Andreas Longva 2020-12-16 14:21:35 +01:00
parent 2d534a6133
commit d9cfe5cb3e
3 changed files with 46 additions and 8 deletions

View File

@ -90,14 +90,7 @@ pub fn spadd_csr<T>(c: &mut CsrMatrix<T>,
where where
T: Scalar + ClosedAdd + ClosedMul + Zero + One T: Scalar + ClosedAdd + ClosedMul + Zero + One
{ {
// TODO: Proper error messages assert_compatible_spadd_dims!(c, a, trans_a);
if trans_a.to_bool() {
assert_eq!(c.nrows(), a.ncols());
assert_eq!(c.ncols(), a.nrows());
} else {
assert_eq!(c.nrows(), a.nrows());
assert_eq!(c.ncols(), a.ncols());
}
// TODO: Change CsrMatrix::pattern() to return `&Arc` instead of `Arc` // TODO: Change CsrMatrix::pattern() to return `&Arc` instead of `Arc`
if Arc::ptr_eq(&c.pattern(), &a.pattern()) { if Arc::ptr_eq(&c.pattern(), &a.pattern()) {

View File

@ -30,6 +30,24 @@ macro_rules! assert_compatible_spmm_dims {
} }
} }
#[macro_use]
macro_rules! assert_compatible_spadd_dims {
($c:expr, $a:expr, $trans_a:expr) => {
use crate::ops::Transpose;
match $trans_a {
Transpose(false) => {
assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()");
assert_eq!($c.ncols(), $a.ncols(), "C.ncols() != A.ncols()");
},
Transpose(true) => {
assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()");
assert_eq!($c.ncols(), $a.nrows(), "C.ncols() != A.nrows()");
}
}
}
}
mod csr; mod csr;
mod pattern; mod pattern;

View File

@ -398,4 +398,31 @@ proptest! {
prop_assert!(result.is_err(), prop_assert!(result.is_err(),
"The SPMM kernel executed successfully despite mismatch dimensions"); "The SPMM kernel executed successfully despite mismatch dimensions");
} }
#[test]
fn spadd_csr_panics_on_dim_mismatch(
(alpha, beta, c, a, trans_a)
in (PROPTEST_I32_VALUE_STRATEGY,
PROPTEST_I32_VALUE_STRATEGY,
csr_strategy(),
csr_strategy(),
trans_strategy())
) {
let op_a_rows = if trans_a.to_bool() { a.ncols() } else { a.nrows() };
let op_a_cols = if trans_a.to_bool() { a.nrows() } else { a.ncols() };
let dims_are_compatible = c.nrows() == op_a_rows && c.ncols() == op_a_cols;
// If the dimensions randomly happen to be compatible, then of course we need to
// skip the test, so we assume that they are not.
prop_assume!(!dims_are_compatible);
let result = catch_unwind(|| {
let mut spmm_result = c.clone();
spadd_csr(&mut spmm_result, beta, alpha, trans_a, &a).unwrap();
});
prop_assert!(result.is_err(),
"The SPMM kernel executed successfully despite mismatch dimensions");
}
} }