Improve dimension assertions for spadd_csr
This commit is contained in:
parent
2d534a6133
commit
d9cfe5cb3e
|
@ -90,14 +90,7 @@ pub fn spadd_csr<T>(c: &mut CsrMatrix<T>,
|
||||||
where
|
where
|
||||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||||
{
|
{
|
||||||
// TODO: Proper error messages
|
assert_compatible_spadd_dims!(c, a, trans_a);
|
||||||
if trans_a.to_bool() {
|
|
||||||
assert_eq!(c.nrows(), a.ncols());
|
|
||||||
assert_eq!(c.ncols(), a.nrows());
|
|
||||||
} else {
|
|
||||||
assert_eq!(c.nrows(), a.nrows());
|
|
||||||
assert_eq!(c.ncols(), a.ncols());
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Change CsrMatrix::pattern() to return `&Arc` instead of `Arc`
|
// TODO: Change CsrMatrix::pattern() to return `&Arc` instead of `Arc`
|
||||||
if Arc::ptr_eq(&c.pattern(), &a.pattern()) {
|
if Arc::ptr_eq(&c.pattern(), &a.pattern()) {
|
||||||
|
|
|
@ -30,6 +30,24 @@ macro_rules! assert_compatible_spmm_dims {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[macro_use]
|
||||||
|
macro_rules! assert_compatible_spadd_dims {
|
||||||
|
($c:expr, $a:expr, $trans_a:expr) => {
|
||||||
|
use crate::ops::Transpose;
|
||||||
|
match $trans_a {
|
||||||
|
Transpose(false) => {
|
||||||
|
assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()");
|
||||||
|
assert_eq!($c.ncols(), $a.ncols(), "C.ncols() != A.ncols()");
|
||||||
|
},
|
||||||
|
Transpose(true) => {
|
||||||
|
assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()");
|
||||||
|
assert_eq!($c.ncols(), $a.nrows(), "C.ncols() != A.nrows()");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
mod csr;
|
mod csr;
|
||||||
mod pattern;
|
mod pattern;
|
||||||
|
|
||||||
|
|
|
@ -398,4 +398,31 @@ proptest! {
|
||||||
prop_assert!(result.is_err(),
|
prop_assert!(result.is_err(),
|
||||||
"The SPMM kernel executed successfully despite mismatch dimensions");
|
"The SPMM kernel executed successfully despite mismatch dimensions");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spadd_csr_panics_on_dim_mismatch(
|
||||||
|
(alpha, beta, c, a, trans_a)
|
||||||
|
in (PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
PROPTEST_I32_VALUE_STRATEGY,
|
||||||
|
csr_strategy(),
|
||||||
|
csr_strategy(),
|
||||||
|
trans_strategy())
|
||||||
|
) {
|
||||||
|
let op_a_rows = if trans_a.to_bool() { a.ncols() } else { a.nrows() };
|
||||||
|
let op_a_cols = if trans_a.to_bool() { a.nrows() } else { a.ncols() };
|
||||||
|
|
||||||
|
let dims_are_compatible = c.nrows() == op_a_rows && c.ncols() == op_a_cols;
|
||||||
|
|
||||||
|
// If the dimensions randomly happen to be compatible, then of course we need to
|
||||||
|
// skip the test, so we assume that they are not.
|
||||||
|
prop_assume!(!dims_are_compatible);
|
||||||
|
|
||||||
|
let result = catch_unwind(|| {
|
||||||
|
let mut spmm_result = c.clone();
|
||||||
|
spadd_csr(&mut spmm_result, beta, alpha, trans_a, &a).unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
prop_assert!(result.is_err(),
|
||||||
|
"The SPMM kernel executed successfully despite mismatch dimensions");
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue