Add `try_[zip_[zip_]map` methods to Matrix

Useful for functions that can fail
- integer `_checked` methods
- fallible conversions
This commit is contained in:
Patiga 2022-08-18 02:06:34 +02:00
parent d09d06858f
commit 6ae273cfd7
1 changed files with 122 additions and 0 deletions

View File

@ -718,6 +718,35 @@ impl<T, R: Dim, C: Dim, S: RawStorage<T, R, C>> Matrix<T, R, C, S> {
unsafe { res.assume_init() } unsafe { res.assume_init() }
} }
/// Returns a matrix containing the `Some` result of `f` applied to each of its entries.
/// If `f` returns `None` for any entry, this function returns `None`.
#[inline]
#[must_use]
pub fn try_map<T2: Scalar, F: FnMut(T) -> Option<T2>>(
&self,
mut f: F,
) -> Option<OMatrix<T2, R, C>>
where
T: Scalar,
DefaultAllocator: Allocator<T2, R, C>,
{
let (nrows, ncols) = self.shape_generic();
let mut res = Matrix::uninit(nrows, ncols);
for j in 0..ncols.value() {
for i in 0..nrows.value() {
// Safety: all indices are in range.
unsafe {
let a = self.data.get_unchecked(i, j).clone();
*res.data.get_unchecked_mut(i, j) = MaybeUninit::new(f(a)?);
}
}
}
// Safety: res is now fully initialized.
Some(unsafe { res.assume_init() })
}
/// Cast the components of `self` to another type. /// Cast the components of `self` to another type.
/// ///
/// # Example /// # Example
@ -824,6 +853,48 @@ impl<T, R: Dim, C: Dim, S: RawStorage<T, R, C>> Matrix<T, R, C, S> {
unsafe { res.assume_init() } unsafe { res.assume_init() }
} }
/// Returns a matrix containing the `Some` result of `f` applied to entries of `self`
/// and `rhs`.
/// If `f` returns `None` for any entry, this function returns `None`.
#[inline]
#[must_use]
pub fn try_zip_map<T2, N3, S2, F>(
&self,
rhs: &Matrix<T2, R, C, S2>,
mut f: F,
) -> Option<OMatrix<N3, R, C>>
where
T: Scalar,
T2: Scalar,
N3: Scalar,
S2: RawStorage<T2, R, C>,
F: FnMut(T, T2) -> Option<N3>,
DefaultAllocator: Allocator<N3, R, C>,
{
let (nrows, ncols) = self.shape_generic();
let mut res = Matrix::uninit(nrows, ncols);
assert_eq!(
(nrows.value(), ncols.value()),
rhs.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
for j in 0..ncols.value() {
for i in 0..nrows.value() {
// Safety: all indices are in range.
unsafe {
let a = self.data.get_unchecked(i, j).clone();
let b = rhs.data.get_unchecked(i, j).clone();
*res.data.get_unchecked_mut(i, j) = MaybeUninit::new(f(a, b)?)
}
}
}
// Safety: res is now fully initialized.
Some(unsafe { res.assume_init() })
}
/// Returns a matrix containing the result of `f` applied to each entries of `self` and /// Returns a matrix containing the result of `f` applied to each entries of `self` and
/// `b`, and `c`. /// `b`, and `c`.
#[inline] #[inline]
@ -874,6 +945,57 @@ impl<T, R: Dim, C: Dim, S: RawStorage<T, R, C>> Matrix<T, R, C, S> {
unsafe { res.assume_init() } unsafe { res.assume_init() }
} }
/// Returns a matrix containing the `Some` result of `f` applied to entries of `self`
/// and `b`, and `c`.
/// If `f` returns `None` for any entry, this function returns `None`.
#[inline]
#[must_use]
pub fn try_zip_zip_map<T2, N3, N4, S2, S3, F>(
&self,
b: &Matrix<T2, R, C, S2>,
c: &Matrix<N3, R, C, S3>,
mut f: F,
) -> Option<OMatrix<N4, R, C>>
where
T: Scalar,
T2: Scalar,
N3: Scalar,
N4: Scalar,
S2: RawStorage<T2, R, C>,
S3: RawStorage<N3, R, C>,
F: FnMut(T, T2, N3) -> Option<N4>,
DefaultAllocator: Allocator<N4, R, C>,
{
let (nrows, ncols) = self.shape_generic();
let mut res = Matrix::uninit(nrows, ncols);
assert_eq!(
(nrows.value(), ncols.value()),
b.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
assert_eq!(
(nrows.value(), ncols.value()),
c.shape(),
"Matrix simultaneous traversal error: dimension mismatch."
);
for j in 0..ncols.value() {
for i in 0..nrows.value() {
// Safety: all indices are in range.
unsafe {
let a = self.data.get_unchecked(i, j).clone();
let b = b.data.get_unchecked(i, j).clone();
let c = c.data.get_unchecked(i, j).clone();
*res.data.get_unchecked_mut(i, j) = MaybeUninit::new(f(a, b, c)?)
}
}
}
// Safety: res is now fully initialized.
Some(unsafe { res.assume_init() })
}
/// Folds a function `f` on each entry of `self`. /// Folds a function `f` on each entry of `self`.
#[inline] #[inline]
#[must_use] #[must_use]