Improve dimension assertions for spadd_csr
This commit is contained in:
parent
2d534a6133
commit
d9cfe5cb3e
@ -90,14 +90,7 @@ pub fn spadd_csr<T>(c: &mut CsrMatrix<T>,
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
{
|
||||
// TODO: Proper error messages
|
||||
if trans_a.to_bool() {
|
||||
assert_eq!(c.nrows(), a.ncols());
|
||||
assert_eq!(c.ncols(), a.nrows());
|
||||
} else {
|
||||
assert_eq!(c.nrows(), a.nrows());
|
||||
assert_eq!(c.ncols(), a.ncols());
|
||||
}
|
||||
assert_compatible_spadd_dims!(c, a, trans_a);
|
||||
|
||||
// TODO: Change CsrMatrix::pattern() to return `&Arc` instead of `Arc`
|
||||
if Arc::ptr_eq(&c.pattern(), &a.pattern()) {
|
||||
|
@ -30,6 +30,24 @@ macro_rules! assert_compatible_spmm_dims {
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_use]
|
||||
macro_rules! assert_compatible_spadd_dims {
|
||||
($c:expr, $a:expr, $trans_a:expr) => {
|
||||
use crate::ops::Transpose;
|
||||
match $trans_a {
|
||||
Transpose(false) => {
|
||||
assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()");
|
||||
assert_eq!($c.ncols(), $a.ncols(), "C.ncols() != A.ncols()");
|
||||
},
|
||||
Transpose(true) => {
|
||||
assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()");
|
||||
assert_eq!($c.ncols(), $a.nrows(), "C.ncols() != A.nrows()");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
mod csr;
|
||||
mod pattern;
|
||||
|
||||
|
@ -398,4 +398,31 @@ proptest! {
|
||||
prop_assert!(result.is_err(),
|
||||
"The SPMM kernel executed successfully despite mismatch dimensions");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spadd_csr_panics_on_dim_mismatch(
|
||||
(alpha, beta, c, a, trans_a)
|
||||
in (PROPTEST_I32_VALUE_STRATEGY,
|
||||
PROPTEST_I32_VALUE_STRATEGY,
|
||||
csr_strategy(),
|
||||
csr_strategy(),
|
||||
trans_strategy())
|
||||
) {
|
||||
let op_a_rows = if trans_a.to_bool() { a.ncols() } else { a.nrows() };
|
||||
let op_a_cols = if trans_a.to_bool() { a.nrows() } else { a.ncols() };
|
||||
|
||||
let dims_are_compatible = c.nrows() == op_a_rows && c.ncols() == op_a_cols;
|
||||
|
||||
// If the dimensions randomly happen to be compatible, then of course we need to
|
||||
// skip the test, so we assume that they are not.
|
||||
prop_assume!(!dims_are_compatible);
|
||||
|
||||
let result = catch_unwind(|| {
|
||||
let mut spmm_result = c.clone();
|
||||
spadd_csr(&mut spmm_result, beta, alpha, trans_a, &a).unwrap();
|
||||
});
|
||||
|
||||
prop_assert!(result.is_err(),
|
||||
"The SPMM kernel executed successfully despite mismatch dimensions");
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user