From 6ae273cfd7bf34385e0b44f646f7df4e6dfc4a2c Mon Sep 17 00:00:00 2001 From: Patiga Date: Thu, 18 Aug 2022 02:06:34 +0200 Subject: [PATCH] Add `try_[zip_[zip_]map` methods to Matrix Useful for functions that can fail - integer `_checked` methods - fallible conversions --- src/base/matrix.rs | 122 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/src/base/matrix.rs b/src/base/matrix.rs index d9294c9e..1c99fc2a 100644 --- a/src/base/matrix.rs +++ b/src/base/matrix.rs @@ -718,6 +718,35 @@ impl> Matrix { unsafe { res.assume_init() } } + /// Returns a matrix containing the `Some` result of `f` applied to each of its entries. + /// If `f` returns `None` for any entry, this function returns `None`. + #[inline] + #[must_use] + pub fn try_map Option>( + &self, + mut f: F, + ) -> Option> + where + T: Scalar, + DefaultAllocator: Allocator, + { + let (nrows, ncols) = self.shape_generic(); + let mut res = Matrix::uninit(nrows, ncols); + + for j in 0..ncols.value() { + for i in 0..nrows.value() { + // Safety: all indices are in range. + unsafe { + let a = self.data.get_unchecked(i, j).clone(); + *res.data.get_unchecked_mut(i, j) = MaybeUninit::new(f(a)?); + } + } + } + + // Safety: res is now fully initialized. + Some(unsafe { res.assume_init() }) + } + /// Cast the components of `self` to another type. /// /// # Example @@ -824,6 +853,48 @@ impl> Matrix { unsafe { res.assume_init() } } + /// Returns a matrix containing the `Some` result of `f` applied to entries of `self` + /// and `rhs`. + /// If `f` returns `None` for any entry, this function returns `None`. + #[inline] + #[must_use] + pub fn try_zip_map( + &self, + rhs: &Matrix, + mut f: F, + ) -> Option> + where + T: Scalar, + T2: Scalar, + N3: Scalar, + S2: RawStorage, + F: FnMut(T, T2) -> Option, + DefaultAllocator: Allocator, + { + let (nrows, ncols) = self.shape_generic(); + let mut res = Matrix::uninit(nrows, ncols); + + assert_eq!( + (nrows.value(), ncols.value()), + rhs.shape(), + "Matrix simultaneous traversal error: dimension mismatch." + ); + + for j in 0..ncols.value() { + for i in 0..nrows.value() { + // Safety: all indices are in range. + unsafe { + let a = self.data.get_unchecked(i, j).clone(); + let b = rhs.data.get_unchecked(i, j).clone(); + *res.data.get_unchecked_mut(i, j) = MaybeUninit::new(f(a, b)?) + } + } + } + + // Safety: res is now fully initialized. + Some(unsafe { res.assume_init() }) + } + /// Returns a matrix containing the result of `f` applied to each entries of `self` and /// `b`, and `c`. #[inline] @@ -874,6 +945,57 @@ impl> Matrix { unsafe { res.assume_init() } } + /// Returns a matrix containing the `Some` result of `f` applied to entries of `self` + /// and `b`, and `c`. + /// If `f` returns `None` for any entry, this function returns `None`. + #[inline] + #[must_use] + pub fn try_zip_zip_map( + &self, + b: &Matrix, + c: &Matrix, + mut f: F, + ) -> Option> + where + T: Scalar, + T2: Scalar, + N3: Scalar, + N4: Scalar, + S2: RawStorage, + S3: RawStorage, + F: FnMut(T, T2, N3) -> Option, + DefaultAllocator: Allocator, + { + let (nrows, ncols) = self.shape_generic(); + let mut res = Matrix::uninit(nrows, ncols); + + assert_eq!( + (nrows.value(), ncols.value()), + b.shape(), + "Matrix simultaneous traversal error: dimension mismatch." + ); + assert_eq!( + (nrows.value(), ncols.value()), + c.shape(), + "Matrix simultaneous traversal error: dimension mismatch." + ); + + for j in 0..ncols.value() { + for i in 0..nrows.value() { + // Safety: all indices are in range. + unsafe { + let a = self.data.get_unchecked(i, j).clone(); + let b = b.data.get_unchecked(i, j).clone(); + let c = c.data.get_unchecked(i, j).clone(); + *res.data.get_unchecked_mut(i, j) = MaybeUninit::new(f(a, b, c)?) + } + } + } + + // Safety: res is now fully initialized. + Some(unsafe { res.assume_init() }) + } + /// Folds a function `f` on each entry of `self`. #[inline] #[must_use]