Add zip_apply and zip_zip_apply.

This commit is contained in:
sebcrozet 2018-12-09 16:56:09 +01:00
parent 904000ce27
commit d1391592a0

View File

@ -706,8 +706,7 @@ impl<N: Scalar, R: Dim, C: Dim, S: StorageMut<N, R, C>> Matrix<N, R, C, S> {
/// Replaces each component of `self` by the result of a closure `f` applied on it. /// Replaces each component of `self` by the result of a closure `f` applied on it.
#[inline] #[inline]
pub fn apply<F: FnMut(N) -> N>(&mut self, mut f: F) pub fn apply<F: FnMut(N) -> N>(&mut self, mut f: F) {
where DefaultAllocator: Allocator<N, R, C> {
let (nrows, ncols) = self.shape(); let (nrows, ncols) = self.shape();
for j in 0..ncols { for j in 0..ncols {
@ -719,6 +718,71 @@ impl<N: Scalar, R: Dim, C: Dim, S: StorageMut<N, R, C>> Matrix<N, R, C, S> {
} }
} }
} }
/// Replaces each component of `self` by the result of a closure `f` applied on its components
/// joined with the components from `rhs`.
#[inline]
pub fn zip_apply<N2, R2, C2, S2>(&mut self, rhs: Matrix<N2, R2, C2, S2>, mut f: impl FnMut(N, N2) -> N)
where N2: Scalar,
R2: Dim,
C2: Dim,
S2: Storage<N2, R2, C2>,
ShapeConstraint: SameNumberOfRows<R, R2> + SameNumberOfColumns<C, C2> {
let (nrows, ncols) = self.shape();
assert!(
(nrows, ncols) == rhs.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
for j in 0..ncols {
for i in 0..nrows {
unsafe {
let e = self.data.get_unchecked_mut(i, j);
let rhs = rhs.get_unchecked(i, j);
*e = f(*e, *rhs)
}
}
}
}
/// Replaces each component of `self` by the result of a closure `f` applied on its components
/// joined with the components from `b` and `c`.
#[inline]
pub fn zip_zip_apply<N2, R2, C2, S2, N3, R3, C3, S3>(&mut self, b: Matrix<N2, R2, C2, S2>, c: Matrix<N3, R3, C3, S3>, mut f: impl FnMut(N, N2, N3) -> N)
where N2: Scalar,
R2: Dim,
C2: Dim,
S2: Storage<N2, R2, C2>,
N3: Scalar,
R3: Dim,
C3: Dim,
S3: Storage<N3, R3, C3>,
ShapeConstraint: SameNumberOfRows<R, R2> + SameNumberOfColumns<C, C2>,
ShapeConstraint: SameNumberOfRows<R, R2> + SameNumberOfColumns<C, C2> {
let (nrows, ncols) = self.shape();
assert!(
(nrows, ncols) == b.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
assert!(
(nrows, ncols) == c.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
for j in 0..ncols {
for i in 0..nrows {
unsafe {
let e = self.data.get_unchecked_mut(i, j);
let b = b.get_unchecked(i, j);
let c = c.get_unchecked(i, j);
*e = f(*e, *b, *c)
}
}
}
}
} }
impl<N: Scalar, D: Dim, S: Storage<N, D>> Vector<N, D, S> { impl<N: Scalar, D: Dim, S: Storage<N, D>> Vector<N, D, S> {