diff --git a/src/structs/dmat.rs b/src/structs/dmat.rs index 84c6c7f0..3c8b2c53 100644 --- a/src/structs/dmat.rs +++ b/src/structs/dmat.rs @@ -265,7 +265,35 @@ impl IndexMut<(usize, usize)> for DMat { impl + Add + Zero> Mul> for DMat { type Output = DMat; + #[inline] fn mul(self, right: DMat) -> DMat { + (&self) * (&right) + } +} + +impl<'a, N: Copy + Mul + Add + Zero> Mul<&'a DMat> for DMat { + type Output = DMat; + + #[inline] + fn mul(self, right: &'a DMat) -> DMat { + (&self) * right + } +} + +impl<'a, N: Copy + Mul + Add + Zero> Mul> for &'a DMat { + type Output = DMat; + + #[inline] + fn mul(self, right: DMat) -> DMat { + right * self + } +} + +impl<'a, N: Copy + Mul + Add + Zero> Mul<&'a DMat> for &'a DMat { + type Output = DMat; + + #[inline] + fn mul(self, right: &DMat) -> DMat { assert!(self.ncols == right.nrows); let mut res = unsafe { DMat::new_uninitialized(self.nrows, right.ncols) }; @@ -668,6 +696,15 @@ impl> Add> for DMat { #[inline] fn add(self, right: DMat) -> DMat { + self + (&right) + } +} + +impl<'a, N: Copy + Add> Add<&'a DMat> for DMat { + type Output = DMat; + + #[inline] + fn add(self, right: &'a DMat) -> DMat { assert!(self.nrows == right.nrows && self.ncols == right.ncols, "Unable to add matrices with different dimensions."); @@ -701,6 +738,15 @@ impl> Sub> for DMat { #[inline] fn sub(self, right: DMat) -> DMat { + self - (&right) + } +} + +impl<'a, N: Copy + Sub> Sub<&'a DMat> for DMat { + type Output = DMat; + + #[inline] + fn sub(self, right: &'a DMat) -> DMat { assert!(self.nrows == right.nrows && self.ncols == right.ncols, "Unable to subtract matrices with different dimensions."); diff --git a/tests/mat.rs b/tests/mat.rs index 95e2dd71..b4fc2161 100644 --- a/tests/mat.rs +++ b/tests/mat.rs @@ -322,6 +322,38 @@ fn test_dmat_addition() { assert!((mat1 + mat2) == res); } +#[test] +fn test_dmat_multiplication() { + let mat1 = DMat::from_row_vec( + 2, + 2, + &[ + 1.0, 2.0, + 3.0, 4.0 + ] + ); + + let mat2 = DMat::from_row_vec( + 2, + 2, + &[ + 10.0, 20.0, + 30.0, 40.0 + ] + ); + + let res = DMat::from_row_vec( + 2, + 2, + &[ + 70.0, 100.0, + 150.0, 220.0 + ] + ); + + assert!((mat1 * mat2) == res); +} + #[test] fn test_dmat_subtraction() { let mat1 = DMat::from_row_vec(