From 51381ff84d755f2401f747543c9a377aa90a3eba Mon Sep 17 00:00:00 2001 From: mitchmindtree Date: Sun, 21 Jun 2015 00:20:39 +1000 Subject: [PATCH 1/2] Allow for non-consuming std operations on DMat. Added DMat multiplication test. --- src/structs/dmat.rs | 46 +++++++++++++++++++++++++++++++++++++++++++++ tests/mat.rs | 32 +++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/src/structs/dmat.rs b/src/structs/dmat.rs index 84c6c7f0..3c8b2c53 100644 --- a/src/structs/dmat.rs +++ b/src/structs/dmat.rs @@ -265,7 +265,35 @@ impl IndexMut<(usize, usize)> for DMat { impl + Add + Zero> Mul> for DMat { type Output = DMat; + #[inline] fn mul(self, right: DMat) -> DMat { + (&self) * (&right) + } +} + +impl<'a, N: Copy + Mul + Add + Zero> Mul<&'a DMat> for DMat { + type Output = DMat; + + #[inline] + fn mul(self, right: &'a DMat) -> DMat { + (&self) * right + } +} + +impl<'a, N: Copy + Mul + Add + Zero> Mul> for &'a DMat { + type Output = DMat; + + #[inline] + fn mul(self, right: DMat) -> DMat { + right * self + } +} + +impl<'a, N: Copy + Mul + Add + Zero> Mul<&'a DMat> for &'a DMat { + type Output = DMat; + + #[inline] + fn mul(self, right: &DMat) -> DMat { assert!(self.ncols == right.nrows); let mut res = unsafe { DMat::new_uninitialized(self.nrows, right.ncols) }; @@ -668,6 +696,15 @@ impl> Add> for DMat { #[inline] fn add(self, right: DMat) -> DMat { + self + (&right) + } +} + +impl<'a, N: Copy + Add> Add<&'a DMat> for DMat { + type Output = DMat; + + #[inline] + fn add(self, right: &'a DMat) -> DMat { assert!(self.nrows == right.nrows && self.ncols == right.ncols, "Unable to add matrices with different dimensions."); @@ -701,6 +738,15 @@ impl> Sub> for DMat { #[inline] fn sub(self, right: DMat) -> DMat { + self - (&right) + } +} + +impl<'a, N: Copy + Sub> Sub<&'a DMat> for DMat { + type Output = DMat; + + #[inline] + fn sub(self, right: &'a DMat) -> DMat { assert!(self.nrows == right.nrows && self.ncols == right.ncols, "Unable to subtract matrices with different dimensions."); diff --git a/tests/mat.rs b/tests/mat.rs index 95e2dd71..b4fc2161 100644 --- a/tests/mat.rs +++ b/tests/mat.rs @@ -322,6 +322,38 @@ fn test_dmat_addition() { assert!((mat1 + mat2) == res); } +#[test] +fn test_dmat_multiplication() { + let mat1 = DMat::from_row_vec( + 2, + 2, + &[ + 1.0, 2.0, + 3.0, 4.0 + ] + ); + + let mat2 = DMat::from_row_vec( + 2, + 2, + &[ + 10.0, 20.0, + 30.0, 40.0 + ] + ); + + let res = DMat::from_row_vec( + 2, + 2, + &[ + 70.0, 100.0, + 150.0, 220.0 + ] + ); + + assert!((mat1 * mat2) == res); +} + #[test] fn test_dmat_subtraction() { let mat1 = DMat::from_row_vec( From 2efb30876e9f7f2b3c1e7c07223610097387ae2b Mon Sep 17 00:00:00 2001 From: mitchmindtree Date: Sun, 21 Jun 2015 01:08:23 +1000 Subject: [PATCH 2/2] Added missing ops implementations for DMat --- src/structs/dmat.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/structs/dmat.rs b/src/structs/dmat.rs index 3c8b2c53..b55227fb 100644 --- a/src/structs/dmat.rs +++ b/src/structs/dmat.rs @@ -700,6 +700,15 @@ impl> Add> for DMat { } } +impl<'a, N: Copy + Add> Add> for &'a DMat { + type Output = DMat; + + #[inline] + fn add(self, right: DMat) -> DMat { + right + self + } +} + impl<'a, N: Copy + Add> Add<&'a DMat> for DMat { type Output = DMat; @@ -742,6 +751,15 @@ impl> Sub> for DMat { } } +impl<'a, N: Copy + Sub> Sub> for &'a DMat { + type Output = DMat; + + #[inline] + fn sub(self, right: DMat) -> DMat { + right - self + } +} + impl<'a, N: Copy + Sub> Sub<&'a DMat> for DMat { type Output = DMat;