diff --git a/src/core/ops.rs b/src/core/ops.rs index a530e066..98c0186f 100644 --- a/src/core/ops.rs +++ b/src/core/ops.rs @@ -1,11 +1,12 @@ +use std::iter; use std::ops::{Add, AddAssign, Sub, SubAssign, Mul, MulAssign, Div, DivAssign, Neg, Index, IndexMut}; -use num::Zero; +use num::{Zero, One}; use alga::general::{ClosedMul, ClosedDiv, ClosedAdd, ClosedSub, ClosedNeg}; -use core::{Scalar, Matrix, OwnedMatrix, MatrixSum, MatrixMul, MatrixTrMul}; -use core::dimension::{Dim, DimMul, DimProd}; +use core::{Scalar, Matrix, OwnedMatrix, SquareMatrix, MatrixSum, MatrixMul, MatrixTrMul}; +use core::dimension::{Dim, DimMul, DimName, DimProd}; use core::constraint::{ShapeConstraint, SameNumberOfRows, SameNumberOfColumns, AreMultipliable}; use core::storage::{Storage, StorageMut, OwnedStorage}; use core::allocator::{SameShapeAllocator, Allocator, OwnedAllocator}; @@ -231,6 +232,25 @@ macro_rules! componentwise_binop_impl( componentwise_binop_impl!(Add, add, ClosedAdd; AddAssign, add_assign); componentwise_binop_impl!(Sub, sub, ClosedSub; SubAssign, sub_assign); +impl iter::Sum for Matrix + where N: Scalar + ClosedAdd + Zero, + S: OwnedStorage, + S::Alloc: OwnedAllocator +{ + fn sum>>(iter: I) -> Matrix { + iter.fold(Matrix::zero(), |acc, x| acc + x) + } +} + +impl<'a, N, R: DimName, C: DimName, S> iter::Sum<&'a Matrix> for Matrix + where N: Scalar + ClosedAdd + Zero, + S: OwnedStorage, + S::Alloc: OwnedAllocator +{ + fn sum>>(iter: I) -> Matrix { + iter.fold(Matrix::zero(), |acc, x| acc + x) + } +} /* @@ -528,3 +548,23 @@ impl Matrix res } } + +impl iter::Product for SquareMatrix + where N: Scalar + Zero + One + ClosedMul + ClosedAdd, + S: OwnedStorage, + S::Alloc: OwnedAllocator +{ + fn product>>(iter: I) -> SquareMatrix { + iter.fold(Matrix::one(), |acc, x| acc * x) + } +} + +impl<'a, N, D: DimName, S> iter::Product<&'a SquareMatrix> for SquareMatrix + where N: Scalar + Zero + One + ClosedMul + ClosedAdd, + S: OwnedStorage, + S::Alloc: OwnedAllocator +{ + fn product>>(iter: I) -> SquareMatrix { + iter.fold(Matrix::one(), |acc, x| acc * x) + } +} diff --git a/tests/matrix.rs b/tests/matrix.rs index 2b319498..b41b1e34 100644 --- a/tests/matrix.rs +++ b/tests/matrix.rs @@ -263,6 +263,25 @@ fn simple_add() { assert_eq!(expected, &c + a); } +#[test] +fn simple_sum() { + type M = Matrix2x3; + + let a = M::new(1.0, 2.0, 3.0, + 4.0, 5.0, 6.0); + let b = M::new(10.0, 20.0, 30.0, + 40.0, 50.0, 60.0); + let c = M::new(100.0, 200.0, 300.0, + 400.0, 500.0, 600.0); + + assert_eq!(M::zero(), Vec::::new().iter().sum()); + assert_eq!(M::zero(), Vec::::new().into_iter().sum()); + assert_eq!(a + b, vec![a, b].iter().sum()); + assert_eq!(a + b, vec![a, b].into_iter().sum()); + assert_eq!(a + b + c, vec![a, b, c].iter().sum()); + assert_eq!(a + b + c, vec![a, b, c].into_iter().sum()); +} + #[test] fn simple_scalar_mul() { let a = Matrix2x3::new(1.0, 2.0, 3.0, @@ -295,6 +314,28 @@ fn simple_mul() { assert_eq!(expected, a * b); } +#[test] +fn simple_product() { + type M = Matrix3; + + let a = M::new(1.0, 2.0, 3.0, + 4.0, 5.0, 6.0, + 7.0, 8.0, 9.0); + let b = M::new(10.0, 20.0, 30.0, + 40.0, 50.0, 60.0, + 70.0, 80.0, 90.0); + let c = M::new(100.0, 200.0, 300.0, + 400.0, 500.0, 600.0, + 700.0, 800.0, 900.0); + + assert_eq!(M::one(), Vec::::new().iter().product()); + assert_eq!(M::one(), Vec::::new().into_iter().product()); + assert_eq!(a * b, vec![a, b].iter().product()); + assert_eq!(a * b, vec![a, b].into_iter().product()); + assert_eq!(a * b * c, vec![a, b, c].iter().product()); + assert_eq!(a * b * c, vec![a, b, c].into_iter().product()); +} + #[test] fn simple_scalar_conversion() { let a = Matrix2x3::new(1.0, 2.0, 3.0,