diff --git a/nalgebra-sparse/src/ops/serial/csr.rs b/nalgebra-sparse/src/ops/serial/csr.rs index 9f14d8f1..670711cd 100644 --- a/nalgebra-sparse/src/ops/serial/csr.rs +++ b/nalgebra-sparse/src/ops/serial/csr.rs @@ -90,14 +90,7 @@ pub fn spadd_csr(c: &mut CsrMatrix, where T: Scalar + ClosedAdd + ClosedMul + Zero + One { - // TODO: Proper error messages - if trans_a.to_bool() { - assert_eq!(c.nrows(), a.ncols()); - assert_eq!(c.ncols(), a.nrows()); - } else { - assert_eq!(c.nrows(), a.nrows()); - assert_eq!(c.ncols(), a.ncols()); - } + assert_compatible_spadd_dims!(c, a, trans_a); // TODO: Change CsrMatrix::pattern() to return `&Arc` instead of `Arc` if Arc::ptr_eq(&c.pattern(), &a.pattern()) { diff --git a/nalgebra-sparse/src/ops/serial/mod.rs b/nalgebra-sparse/src/ops/serial/mod.rs index 01c31d93..8ac22ac8 100644 --- a/nalgebra-sparse/src/ops/serial/mod.rs +++ b/nalgebra-sparse/src/ops/serial/mod.rs @@ -30,6 +30,24 @@ macro_rules! assert_compatible_spmm_dims { } } +#[macro_use] +macro_rules! assert_compatible_spadd_dims { + ($c:expr, $a:expr, $trans_a:expr) => { + use crate::ops::Transpose; + match $trans_a { + Transpose(false) => { + assert_eq!($c.nrows(), $a.nrows(), "C.nrows() != A.nrows()"); + assert_eq!($c.ncols(), $a.ncols(), "C.ncols() != A.ncols()"); + }, + Transpose(true) => { + assert_eq!($c.nrows(), $a.ncols(), "C.nrows() != A.ncols()"); + assert_eq!($c.ncols(), $a.nrows(), "C.ncols() != A.nrows()"); + } + } + + } +} + mod csr; mod pattern; diff --git a/nalgebra-sparse/tests/unit_tests/ops.rs b/nalgebra-sparse/tests/unit_tests/ops.rs index 60e6cbcd..c1df496f 100644 --- a/nalgebra-sparse/tests/unit_tests/ops.rs +++ b/nalgebra-sparse/tests/unit_tests/ops.rs @@ -398,4 +398,31 @@ proptest! { prop_assert!(result.is_err(), "The SPMM kernel executed successfully despite mismatch dimensions"); } + + #[test] + fn spadd_csr_panics_on_dim_mismatch( + (alpha, beta, c, a, trans_a) + in (PROPTEST_I32_VALUE_STRATEGY, + PROPTEST_I32_VALUE_STRATEGY, + csr_strategy(), + csr_strategy(), + trans_strategy()) + ) { + let op_a_rows = if trans_a.to_bool() { a.ncols() } else { a.nrows() }; + let op_a_cols = if trans_a.to_bool() { a.nrows() } else { a.ncols() }; + + let dims_are_compatible = c.nrows() == op_a_rows && c.ncols() == op_a_cols; + + // If the dimensions randomly happen to be compatible, then of course we need to + // skip the test, so we assume that they are not. + prop_assume!(!dims_are_compatible); + + let result = catch_unwind(|| { + let mut spmm_result = c.clone(); + spadd_csr(&mut spmm_result, beta, alpha, trans_a, &a).unwrap(); + }); + + prop_assert!(result.is_err(), + "The SPMM kernel executed successfully despite mismatch dimensions"); + } } \ No newline at end of file