Add pointwise addition and subtraction for `DMat`.

Fix #132.
This commit is contained in:
Sébastien Crozet 2015-06-06 12:53:40 +02:00
parent 981bc85e2a
commit 410c3c9566
2 changed files with 100 additions and 0 deletions

View File

@ -663,6 +663,24 @@ impl<N: Copy + Add<N, Output = N>> Add<N> for DMat<N> {
} }
} }
impl<N: Copy + Add<N, Output = N>> Add<DMat<N>> for DMat<N> {
type Output = DMat<N>;
#[inline]
fn add(self, right: DMat<N>) -> DMat<N> {
assert!(self.nrows == right.nrows && self.ncols == right.ncols,
"Unable to add matrices with different dimensions.");
let mut res = self;
for (mij, right_ij) in res.mij.iter_mut().zip(right.mij.iter()) {
*mij = *mij + *right_ij;
}
res
}
}
impl<N: Copy + Sub<N, Output = N>> Sub<N> for DMat<N> { impl<N: Copy + Sub<N, Output = N>> Sub<N> for DMat<N> {
type Output = DMat<N>; type Output = DMat<N>;
@ -678,6 +696,24 @@ impl<N: Copy + Sub<N, Output = N>> Sub<N> for DMat<N> {
} }
} }
impl<N: Copy + Sub<N, Output = N>> Sub<DMat<N>> for DMat<N> {
type Output = DMat<N>;
#[inline]
fn sub(self, right: DMat<N>) -> DMat<N> {
assert!(self.nrows == right.nrows && self.ncols == right.ncols,
"Unable to subtract matrices with different dimensions.");
let mut res = self;
for (mij, right_ij) in res.mij.iter_mut().zip(right.mij.iter()) {
*mij = *mij - *right_ij;
}
res
}
}
#[cfg(feature="arbitrary")] #[cfg(feature="arbitrary")]
impl<N: Arbitrary> Arbitrary for DMat<N> { impl<N: Arbitrary> Arbitrary for DMat<N> {
fn arbitrary<G: Gen>(g: &mut G) -> DMat<N> { fn arbitrary<G: Gen>(g: &mut G) -> DMat<N> {

View File

@ -290,6 +290,70 @@ fn test_dmat_from_vec() {
assert!(mat1 == mat2); assert!(mat1 == mat2);
} }
#[test]
fn test_dmat_addition() {
let mat1 = DMat::from_row_vec(
2,
2,
&[
1.0, 2.0,
3.0, 4.0
]
);
let mat2 = DMat::from_row_vec(
2,
2,
&[
10.0, 20.0,
30.0, 40.0
]
);
let res = DMat::from_row_vec(
2,
2,
&[
11.0, 22.0,
33.0, 44.0
]
);
assert!((mat1 + mat2) == res);
}
#[test]
fn test_dmat_subtraction() {
let mat1 = DMat::from_row_vec(
2,
2,
&[
1.0, 2.0,
3.0, 4.0
]
);
let mat2 = DMat::from_row_vec(
2,
2,
&[
10.0, 20.0,
30.0, 40.0
]
);
let res = DMat::from_row_vec(
2,
2,
&[
-09.0, -18.0,
-27.0, -36.0
]
);
assert!((mat1 - mat2) == res);
}
/* FIXME: review qr decomposition to make it work with DMat. /* FIXME: review qr decomposition to make it work with DMat.
#[test] #[test]
fn test_qr() { fn test_qr() {