From 410c3c9566edc2e1bb66a6a7904b2a60fbd06758 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Crozet?= Date: Sat, 6 Jun 2015 12:53:40 +0200 Subject: [PATCH] Add pointwise addition and subtraction for `DMat`. Fix #132. --- src/structs/dmat.rs | 36 +++++++++++++++++++++++++ tests/mat.rs | 64 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+) diff --git a/src/structs/dmat.rs b/src/structs/dmat.rs index 6d23b842..84c6c7f0 100644 --- a/src/structs/dmat.rs +++ b/src/structs/dmat.rs @@ -663,6 +663,24 @@ impl> Add for DMat { } } +impl> Add> for DMat { + type Output = DMat; + + #[inline] + fn add(self, right: DMat) -> DMat { + assert!(self.nrows == right.nrows && self.ncols == right.ncols, + "Unable to add matrices with different dimensions."); + + let mut res = self; + + for (mij, right_ij) in res.mij.iter_mut().zip(right.mij.iter()) { + *mij = *mij + *right_ij; + } + + res + } +} + impl> Sub for DMat { type Output = DMat; @@ -678,6 +696,24 @@ impl> Sub for DMat { } } +impl> Sub> for DMat { + type Output = DMat; + + #[inline] + fn sub(self, right: DMat) -> DMat { + assert!(self.nrows == right.nrows && self.ncols == right.ncols, + "Unable to subtract matrices with different dimensions."); + + let mut res = self; + + for (mij, right_ij) in res.mij.iter_mut().zip(right.mij.iter()) { + *mij = *mij - *right_ij; + } + + res + } +} + #[cfg(feature="arbitrary")] impl Arbitrary for DMat { fn arbitrary(g: &mut G) -> DMat { diff --git a/tests/mat.rs b/tests/mat.rs index 311b8906..95e2dd71 100644 --- a/tests/mat.rs +++ b/tests/mat.rs @@ -290,6 +290,70 @@ fn test_dmat_from_vec() { assert!(mat1 == mat2); } +#[test] +fn test_dmat_addition() { + let mat1 = DMat::from_row_vec( + 2, + 2, + &[ + 1.0, 2.0, + 3.0, 4.0 + ] + ); + + let mat2 = DMat::from_row_vec( + 2, + 2, + &[ + 10.0, 20.0, + 30.0, 40.0 + ] + ); + + let res = DMat::from_row_vec( + 2, + 2, + &[ + 11.0, 22.0, + 33.0, 44.0 + ] + ); + + assert!((mat1 + mat2) == res); +} + +#[test] +fn test_dmat_subtraction() { + let mat1 = DMat::from_row_vec( + 2, + 2, + &[ + 1.0, 2.0, + 3.0, 4.0 + ] + ); + + let mat2 = DMat::from_row_vec( + 2, + 2, + &[ + 10.0, 20.0, + 30.0, 40.0 + ] + ); + + let res = DMat::from_row_vec( + 2, + 2, + &[ + -09.0, -18.0, + -27.0, -36.0 + ] + ); + + assert!((mat1 - mat2) == res); +} + /* FIXME: review qr decomposition to make it work with DMat. #[test] fn test_qr() {