Allow for non-consuming std operations on DMat. Added DMat multiplication test.

This commit is contained in:
mitchmindtree 2015-06-21 00:20:39 +10:00
parent 792d7fda7a
commit 51381ff84d
2 changed files with 78 additions and 0 deletions

View File

@ -265,7 +265,35 @@ impl<N> IndexMut<(usize, usize)> for DMat<N> {
impl<N: Copy + Mul<N, Output = N> + Add<N, Output = N> + Zero> Mul<DMat<N>> for DMat<N> { impl<N: Copy + Mul<N, Output = N> + Add<N, Output = N> + Zero> Mul<DMat<N>> for DMat<N> {
type Output = DMat<N>; type Output = DMat<N>;
#[inline]
fn mul(self, right: DMat<N>) -> DMat<N> { fn mul(self, right: DMat<N>) -> DMat<N> {
(&self) * (&right)
}
}
impl<'a, N: Copy + Mul<N, Output = N> + Add<N, Output = N> + Zero> Mul<&'a DMat<N>> for DMat<N> {
type Output = DMat<N>;
#[inline]
fn mul(self, right: &'a DMat<N>) -> DMat<N> {
(&self) * right
}
}
impl<'a, N: Copy + Mul<N, Output = N> + Add<N, Output = N> + Zero> Mul<DMat<N>> for &'a DMat<N> {
type Output = DMat<N>;
#[inline]
fn mul(self, right: DMat<N>) -> DMat<N> {
right * self
}
}
impl<'a, N: Copy + Mul<N, Output = N> + Add<N, Output = N> + Zero> Mul<&'a DMat<N>> for &'a DMat<N> {
type Output = DMat<N>;
#[inline]
fn mul(self, right: &DMat<N>) -> DMat<N> {
assert!(self.ncols == right.nrows); assert!(self.ncols == right.nrows);
let mut res = unsafe { DMat::new_uninitialized(self.nrows, right.ncols) }; let mut res = unsafe { DMat::new_uninitialized(self.nrows, right.ncols) };
@ -668,6 +696,15 @@ impl<N: Copy + Add<N, Output = N>> Add<DMat<N>> for DMat<N> {
#[inline] #[inline]
fn add(self, right: DMat<N>) -> DMat<N> { fn add(self, right: DMat<N>) -> DMat<N> {
self + (&right)
}
}
impl<'a, N: Copy + Add<N, Output = N>> Add<&'a DMat<N>> for DMat<N> {
type Output = DMat<N>;
#[inline]
fn add(self, right: &'a DMat<N>) -> DMat<N> {
assert!(self.nrows == right.nrows && self.ncols == right.ncols, assert!(self.nrows == right.nrows && self.ncols == right.ncols,
"Unable to add matrices with different dimensions."); "Unable to add matrices with different dimensions.");
@ -701,6 +738,15 @@ impl<N: Copy + Sub<N, Output = N>> Sub<DMat<N>> for DMat<N> {
#[inline] #[inline]
fn sub(self, right: DMat<N>) -> DMat<N> { fn sub(self, right: DMat<N>) -> DMat<N> {
self - (&right)
}
}
impl<'a, N: Copy + Sub<N, Output = N>> Sub<&'a DMat<N>> for DMat<N> {
type Output = DMat<N>;
#[inline]
fn sub(self, right: &'a DMat<N>) -> DMat<N> {
assert!(self.nrows == right.nrows && self.ncols == right.ncols, assert!(self.nrows == right.nrows && self.ncols == right.ncols,
"Unable to subtract matrices with different dimensions."); "Unable to subtract matrices with different dimensions.");

View File

@ -322,6 +322,38 @@ fn test_dmat_addition() {
assert!((mat1 + mat2) == res); assert!((mat1 + mat2) == res);
} }
#[test]
fn test_dmat_multiplication() {
let mat1 = DMat::from_row_vec(
2,
2,
&[
1.0, 2.0,
3.0, 4.0
]
);
let mat2 = DMat::from_row_vec(
2,
2,
&[
10.0, 20.0,
30.0, 40.0
]
);
let res = DMat::from_row_vec(
2,
2,
&[
70.0, 100.0,
150.0, 220.0
]
);
assert!((mat1 * mat2) == res);
}
#[test] #[test]
fn test_dmat_subtraction() { fn test_dmat_subtraction() {
let mat1 = DMat::from_row_vec( let mat1 = DMat::from_row_vec(