diff --git a/src/structs/dmatrix_macros.rs b/src/structs/dmatrix_macros.rs index 346d40f2..a50e23c6 100644 --- a/src/structs/dmatrix_macros.rs +++ b/src/structs/dmatrix_macros.rs @@ -516,7 +516,16 @@ macro_rules! dmat_impl( #[inline] fn sub(self, right: $dmatrix) -> $dmatrix { - right - self + assert!(self.nrows == right.nrows && self.ncols == right.ncols, + "Unable to subtract matrices with different dimensions."); + + let mut res = right; + + for (mij, res) in self.mij.iter().zip(res.mij.iter_mut()) { + *res = *mij - *res; + } + + res } } diff --git a/tests/mat.rs b/tests/mat.rs index 6df6fbcc..8d3469a3 100644 --- a/tests/mat.rs +++ b/tests/mat.rs @@ -439,7 +439,9 @@ fn test_dmat_addition() { ] ); - assert!((mat1 + mat2) == res); + assert!((mat1.clone() + mat2.clone()) == res); + assert!((mat1.clone() + &mat2) == res); + assert!((&mat1 + mat2) == res); } #[test] @@ -471,7 +473,10 @@ fn test_dmat_multiplication() { ] ); - assert!((mat1 * mat2) == res); + assert!((mat1.clone() * mat2.clone()) == res); + assert!((&mat1 * mat2.clone()) == res); + assert!((&mat1 * &mat2) == res); + assert!((mat1 * &mat2) == res); } // Tests multiplication of rectangular (non-square) matrices. @@ -537,7 +542,9 @@ fn test_dmat_subtraction() { ] ); - assert!((mat1 - mat2) == res); + assert!((mat1.clone() - mat2.clone()) == res); + assert!((&mat1 - mat2.clone()) == res); + assert!((mat1 - &mat2) == res); } #[test]