diff --git a/src/structs/dmat.rs b/src/structs/dmat.rs index b55227fb..9a2bc2ca 100644 --- a/src/structs/dmat.rs +++ b/src/structs/dmat.rs @@ -285,11 +285,11 @@ impl<'a, N: Copy + Mul + Add + Zero> Mul> #[inline] fn mul(self, right: DMat) -> DMat { - right * self + self * (&right) } } -impl<'a, N: Copy + Mul + Add + Zero> Mul<&'a DMat> for &'a DMat { +impl<'a, 'b, N: Copy + Mul + Add + Zero> Mul<&'b DMat> for &'a DMat { type Output = DMat; #[inline] diff --git a/tests/mat.rs b/tests/mat.rs index b4fc2161..88f98e0d 100644 --- a/tests/mat.rs +++ b/tests/mat.rs @@ -354,6 +354,40 @@ fn test_dmat_multiplication() { assert!((mat1 * mat2) == res); } +// Tests multiplication of rectangular (non-square) matrices. +#[test] +fn test_dmat_multiplication_rect() { + let mat1 = DMat::from_row_vec( + 1, + 2, + &[ + 1.0, 2.0, + ] + ); + + let mat2 = DMat::from_row_vec( + 2, + 3, + &[ + 3.0, 4.0, 5.0, + 6.0, 7.0, 8.0, + ] + ); + + let res = DMat::from_row_vec( + 1, + 3, + &[ + 15.0, 18.0, 21.0, + ] + ); + + assert!((mat1.clone() * mat2.clone()) == res); + assert!((&mat1 * mat2.clone()) == res); + assert!((mat1.clone() * &mat2) == res); + assert!((&mat1 * &mat2) == res); +} + #[test] fn test_dmat_subtraction() { let mat1 = DMat::from_row_vec(