Merge pull request #139 from mitchmindtree/master
Allow for non-consuming std operations on DMat. Added DMat multiplication test.
This commit is contained in:
commit
8dfb8ee7b9
|
@ -265,7 +265,35 @@ impl<N> IndexMut<(usize, usize)> for DMat<N> {
|
||||||
impl<N: Copy + Mul<N, Output = N> + Add<N, Output = N> + Zero> Mul<DMat<N>> for DMat<N> {
|
impl<N: Copy + Mul<N, Output = N> + Add<N, Output = N> + Zero> Mul<DMat<N>> for DMat<N> {
|
||||||
type Output = DMat<N>;
|
type Output = DMat<N>;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
fn mul(self, right: DMat<N>) -> DMat<N> {
|
fn mul(self, right: DMat<N>) -> DMat<N> {
|
||||||
|
(&self) * (&right)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, N: Copy + Mul<N, Output = N> + Add<N, Output = N> + Zero> Mul<&'a DMat<N>> for DMat<N> {
|
||||||
|
type Output = DMat<N>;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn mul(self, right: &'a DMat<N>) -> DMat<N> {
|
||||||
|
(&self) * right
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, N: Copy + Mul<N, Output = N> + Add<N, Output = N> + Zero> Mul<DMat<N>> for &'a DMat<N> {
|
||||||
|
type Output = DMat<N>;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn mul(self, right: DMat<N>) -> DMat<N> {
|
||||||
|
right * self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, N: Copy + Mul<N, Output = N> + Add<N, Output = N> + Zero> Mul<&'a DMat<N>> for &'a DMat<N> {
|
||||||
|
type Output = DMat<N>;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn mul(self, right: &DMat<N>) -> DMat<N> {
|
||||||
assert!(self.ncols == right.nrows);
|
assert!(self.ncols == right.nrows);
|
||||||
|
|
||||||
let mut res = unsafe { DMat::new_uninitialized(self.nrows, right.ncols) };
|
let mut res = unsafe { DMat::new_uninitialized(self.nrows, right.ncols) };
|
||||||
|
@ -668,6 +696,24 @@ impl<N: Copy + Add<N, Output = N>> Add<DMat<N>> for DMat<N> {
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn add(self, right: DMat<N>) -> DMat<N> {
|
fn add(self, right: DMat<N>) -> DMat<N> {
|
||||||
|
self + (&right)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, N: Copy + Add<N, Output = N>> Add<DMat<N>> for &'a DMat<N> {
|
||||||
|
type Output = DMat<N>;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn add(self, right: DMat<N>) -> DMat<N> {
|
||||||
|
right + self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, N: Copy + Add<N, Output = N>> Add<&'a DMat<N>> for DMat<N> {
|
||||||
|
type Output = DMat<N>;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn add(self, right: &'a DMat<N>) -> DMat<N> {
|
||||||
assert!(self.nrows == right.nrows && self.ncols == right.ncols,
|
assert!(self.nrows == right.nrows && self.ncols == right.ncols,
|
||||||
"Unable to add matrices with different dimensions.");
|
"Unable to add matrices with different dimensions.");
|
||||||
|
|
||||||
|
@ -701,6 +747,24 @@ impl<N: Copy + Sub<N, Output = N>> Sub<DMat<N>> for DMat<N> {
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn sub(self, right: DMat<N>) -> DMat<N> {
|
fn sub(self, right: DMat<N>) -> DMat<N> {
|
||||||
|
self - (&right)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, N: Copy + Sub<N, Output = N>> Sub<DMat<N>> for &'a DMat<N> {
|
||||||
|
type Output = DMat<N>;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn sub(self, right: DMat<N>) -> DMat<N> {
|
||||||
|
right - self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, N: Copy + Sub<N, Output = N>> Sub<&'a DMat<N>> for DMat<N> {
|
||||||
|
type Output = DMat<N>;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn sub(self, right: &'a DMat<N>) -> DMat<N> {
|
||||||
assert!(self.nrows == right.nrows && self.ncols == right.ncols,
|
assert!(self.nrows == right.nrows && self.ncols == right.ncols,
|
||||||
"Unable to subtract matrices with different dimensions.");
|
"Unable to subtract matrices with different dimensions.");
|
||||||
|
|
||||||
|
|
32
tests/mat.rs
32
tests/mat.rs
|
@ -322,6 +322,38 @@ fn test_dmat_addition() {
|
||||||
assert!((mat1 + mat2) == res);
|
assert!((mat1 + mat2) == res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_dmat_multiplication() {
|
||||||
|
let mat1 = DMat::from_row_vec(
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
&[
|
||||||
|
1.0, 2.0,
|
||||||
|
3.0, 4.0
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
let mat2 = DMat::from_row_vec(
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
&[
|
||||||
|
10.0, 20.0,
|
||||||
|
30.0, 40.0
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
let res = DMat::from_row_vec(
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
&[
|
||||||
|
70.0, 100.0,
|
||||||
|
150.0, 220.0
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!((mat1 * mat2) == res);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_dmat_subtraction() {
|
fn test_dmat_subtraction() {
|
||||||
let mat1 = DMat::from_row_vec(
|
let mat1 = DMat::from_row_vec(
|
||||||
|
|
Loading…
Reference in New Issue