Merge pull request #267 from brendanzab/impl-sum-and-product-traits

Implement the sum and product traits for matrices
This commit is contained in:
Eduard Bopp 2017-07-13 03:06:16 +02:00 committed by GitHub
commit c9d1552966
2 changed files with 84 additions and 3 deletions

View File

@ -1,11 +1,12 @@
use std::iter;
use std::ops::{Add, AddAssign, Sub, SubAssign, Mul, MulAssign, Div, DivAssign, Neg, use std::ops::{Add, AddAssign, Sub, SubAssign, Mul, MulAssign, Div, DivAssign, Neg,
Index, IndexMut}; Index, IndexMut};
use num::Zero; use num::{Zero, One};
use alga::general::{ClosedMul, ClosedDiv, ClosedAdd, ClosedSub, ClosedNeg}; use alga::general::{ClosedMul, ClosedDiv, ClosedAdd, ClosedSub, ClosedNeg};
use core::{Scalar, Matrix, OwnedMatrix, MatrixSum, MatrixMul, MatrixTrMul}; use core::{Scalar, Matrix, OwnedMatrix, SquareMatrix, MatrixSum, MatrixMul, MatrixTrMul};
use core::dimension::{Dim, DimMul, DimProd}; use core::dimension::{Dim, DimMul, DimName, DimProd};
use core::constraint::{ShapeConstraint, SameNumberOfRows, SameNumberOfColumns, AreMultipliable}; use core::constraint::{ShapeConstraint, SameNumberOfRows, SameNumberOfColumns, AreMultipliable};
use core::storage::{Storage, StorageMut, OwnedStorage}; use core::storage::{Storage, StorageMut, OwnedStorage};
use core::allocator::{SameShapeAllocator, Allocator, OwnedAllocator}; use core::allocator::{SameShapeAllocator, Allocator, OwnedAllocator};
@ -231,6 +232,25 @@ macro_rules! componentwise_binop_impl(
componentwise_binop_impl!(Add, add, ClosedAdd; AddAssign, add_assign); componentwise_binop_impl!(Add, add, ClosedAdd; AddAssign, add_assign);
componentwise_binop_impl!(Sub, sub, ClosedSub; SubAssign, sub_assign); componentwise_binop_impl!(Sub, sub, ClosedSub; SubAssign, sub_assign);
impl<N, R: DimName, C: DimName, S> iter::Sum for Matrix<N, R, C, S>
where N: Scalar + ClosedAdd + Zero,
S: OwnedStorage<N, R, C>,
S::Alloc: OwnedAllocator<N, R, C, S>
{
fn sum<I: Iterator<Item = Matrix<N, R, C, S>>>(iter: I) -> Matrix<N, R, C, S> {
iter.fold(Matrix::zero(), |acc, x| acc + x)
}
}
impl<'a, N, R: DimName, C: DimName, S> iter::Sum<&'a Matrix<N, R, C, S>> for Matrix<N, R, C, S>
where N: Scalar + ClosedAdd + Zero,
S: OwnedStorage<N, R, C>,
S::Alloc: OwnedAllocator<N, R, C, S>
{
fn sum<I: Iterator<Item = &'a Matrix<N, R, C, S>>>(iter: I) -> Matrix<N, R, C, S> {
iter.fold(Matrix::zero(), |acc, x| acc + x)
}
}
/* /*
@ -528,3 +548,23 @@ impl<N, R1: Dim, C1: Dim, SA> Matrix<N, R1, C1, SA>
res res
} }
} }
impl<N, D: DimName, S> iter::Product for SquareMatrix<N, D, S>
where N: Scalar + Zero + One + ClosedMul + ClosedAdd,
S: OwnedStorage<N, D, D>,
S::Alloc: OwnedAllocator<N, D, D, S>
{
fn product<I: Iterator<Item = SquareMatrix<N, D, S>>>(iter: I) -> SquareMatrix<N, D, S> {
iter.fold(Matrix::one(), |acc, x| acc * x)
}
}
impl<'a, N, D: DimName, S> iter::Product<&'a SquareMatrix<N, D, S>> for SquareMatrix<N, D, S>
where N: Scalar + Zero + One + ClosedMul + ClosedAdd,
S: OwnedStorage<N, D, D>,
S::Alloc: OwnedAllocator<N, D, D, S>
{
fn product<I: Iterator<Item = &'a SquareMatrix<N, D, S>>>(iter: I) -> SquareMatrix<N, D, S> {
iter.fold(Matrix::one(), |acc, x| acc * x)
}
}

View File

@ -263,6 +263,25 @@ fn simple_add() {
assert_eq!(expected, &c + a); assert_eq!(expected, &c + a);
} }
#[test]
fn simple_sum() {
type M = Matrix2x3<f32>;
let a = M::new(1.0, 2.0, 3.0,
4.0, 5.0, 6.0);
let b = M::new(10.0, 20.0, 30.0,
40.0, 50.0, 60.0);
let c = M::new(100.0, 200.0, 300.0,
400.0, 500.0, 600.0);
assert_eq!(M::zero(), Vec::<M>::new().iter().sum());
assert_eq!(M::zero(), Vec::<M>::new().into_iter().sum());
assert_eq!(a + b, vec![a, b].iter().sum());
assert_eq!(a + b, vec![a, b].into_iter().sum());
assert_eq!(a + b + c, vec![a, b, c].iter().sum());
assert_eq!(a + b + c, vec![a, b, c].into_iter().sum());
}
#[test] #[test]
fn simple_scalar_mul() { fn simple_scalar_mul() {
let a = Matrix2x3::new(1.0, 2.0, 3.0, let a = Matrix2x3::new(1.0, 2.0, 3.0,
@ -295,6 +314,28 @@ fn simple_mul() {
assert_eq!(expected, a * b); assert_eq!(expected, a * b);
} }
#[test]
fn simple_product() {
type M = Matrix3<f32>;
let a = M::new(1.0, 2.0, 3.0,
4.0, 5.0, 6.0,
7.0, 8.0, 9.0);
let b = M::new(10.0, 20.0, 30.0,
40.0, 50.0, 60.0,
70.0, 80.0, 90.0);
let c = M::new(100.0, 200.0, 300.0,
400.0, 500.0, 600.0,
700.0, 800.0, 900.0);
assert_eq!(M::one(), Vec::<M>::new().iter().product());
assert_eq!(M::one(), Vec::<M>::new().into_iter().product());
assert_eq!(a * b, vec![a, b].iter().product());
assert_eq!(a * b, vec![a, b].into_iter().product());
assert_eq!(a * b * c, vec![a, b, c].iter().product());
assert_eq!(a * b * c, vec![a, b, c].into_iter().product());
}
#[test] #[test]
fn simple_scalar_conversion() { fn simple_scalar_conversion() {
let a = Matrix2x3::new(1.0, 2.0, 3.0, let a = Matrix2x3::new(1.0, 2.0, 3.0,