From 1195eadd1aaa581db8e9e80405426b414638a666 Mon Sep 17 00:00:00 2001 From: Yotam Ofek Date: Sun, 12 Nov 2023 07:14:52 +0000 Subject: [PATCH] Allow creating matrix iter with an owned view --- src/base/conversion.rs | 24 +++++++ src/base/iter.rs | 152 ++++++++++++++++++++++++++++++++--------- tests/core/matrix.rs | 148 +++++++++++++++++++++++++-------------- 3 files changed, 242 insertions(+), 82 deletions(-) diff --git a/src/base/conversion.rs b/src/base/conversion.rs index c29535ef..783e6d9e 100644 --- a/src/base/conversion.rs +++ b/src/base/conversion.rs @@ -98,6 +98,18 @@ impl<'a, T: Scalar, R: Dim, C: Dim, S: RawStorage> IntoIterator } } +impl<'a, T: Scalar, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoIterator + for Matrix> +{ + type Item = &'a T; + type IntoIter = MatrixIter<'a, T, R, C, ViewStorage<'a, T, R, C, RStride, CStride>>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + MatrixIter::new_owned(self.data) + } +} + impl<'a, T: Scalar, R: Dim, C: Dim, S: RawStorageMut> IntoIterator for &'a mut Matrix { @@ -110,6 +122,18 @@ impl<'a, T: Scalar, R: Dim, C: Dim, S: RawStorageMut> IntoIterator } } +impl<'a, T: Scalar, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoIterator + for Matrix> +{ + type Item = &'a mut T; + type IntoIter = MatrixIterMut<'a, T, R, C, ViewStorageMut<'a, T, R, C, RStride, CStride>>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + MatrixIterMut::new_owned_mut(self.data) + } +} + impl From<[T; D]> for SVector { #[inline] fn from(arr: [T; D]) -> Self { diff --git a/src/base/iter.rs b/src/base/iter.rs index c2b1f58a..5da43721 100644 --- a/src/base/iter.rs +++ b/src/base/iter.rs @@ -12,26 +12,29 @@ use std::mem; use crate::base::dimension::{Dim, U1}; use crate::base::storage::{RawStorage, RawStorageMut}; -use crate::base::{Matrix, MatrixView, MatrixViewMut, Scalar}; +use crate::base::{Matrix, MatrixView, MatrixViewMut, Scalar, ViewStorage, ViewStorageMut}; + +#[derive(Clone, Debug)] +struct RawIter { + ptr: Ptr, + inner_ptr: Ptr, + inner_end: Ptr, + size: usize, + strides: (RStride, CStride), + _phantoms: PhantomData<(fn() -> T, R, C)>, +} macro_rules! iterator { (struct $Name:ident for $Storage:ident.$ptr: ident -> $Ptr:ty, $Ref:ty, $SRef: ty, $($derives:ident),* $(,)?) => { - /// An iterator through a dense matrix with arbitrary strides matrix. - #[derive($($derives),*)] - pub struct $Name<'a, T, R: Dim, C: Dim, S: 'a + $Storage> { - ptr: $Ptr, - inner_ptr: $Ptr, - inner_end: $Ptr, - size: usize, // We can't use an end pointer here because a stride might be zero. - strides: (S::RStride, S::CStride), - _phantoms: PhantomData<($Ref, R, C, S)>, - } - // TODO: we need to specialize for the case where the matrix storage is owned (in which // case the iterator is trivial because it does not have any stride). - impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage> $Name<'a, T, R, C, S> { + impl + RawIter<$Ptr, T, R, C, RStride, CStride> + { /// Creates a new iterator for the given matrix storage. - pub fn new(storage: $SRef) -> $Name<'a, T, R, C, S> { + fn new<'a, S: $Storage>( + storage: $SRef, + ) -> Self { let shape = storage.shape(); let strides = storage.strides(); let inner_offset = shape.0.value() * strides.0.value(); @@ -55,7 +58,7 @@ macro_rules! iterator { unsafe { ptr.add(inner_offset) } }; - $Name { + RawIter { ptr, inner_ptr: ptr, inner_end, @@ -66,11 +69,13 @@ macro_rules! iterator { } } - impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage> Iterator for $Name<'a, T, R, C, S> { - type Item = $Ref; + impl Iterator + for RawIter<$Ptr, T, R, C, RStride, CStride> + { + type Item = $Ptr; #[inline] - fn next(&mut self) -> Option<$Ref> { + fn next(&mut self) -> Option { unsafe { if self.size == 0 { None @@ -102,10 +107,7 @@ macro_rules! iterator { self.ptr = self.ptr.add(stride); } - // We want either `& *last` or `&mut *last` here, depending - // on the mutability of `$Ref`. - #[allow(clippy::transmute_ptr_to_ref)] - Some(mem::transmute(old)) + Some(old) } } } @@ -121,11 +123,11 @@ macro_rules! iterator { } } - impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage> DoubleEndedIterator - for $Name<'a, T, R, C, S> + impl DoubleEndedIterator + for RawIter<$Ptr, T, R, C, RStride, CStride> { #[inline] - fn next_back(&mut self) -> Option<$Ref> { + fn next_back(&mut self) -> Option { unsafe { if self.size == 0 { None @@ -152,21 +154,85 @@ macro_rules! iterator { .ptr .add((outer_remaining * outer_stride + inner_remaining * inner_stride)); - // We want either `& *last` or `&mut *last` here, depending - // on the mutability of `$Ref`. - #[allow(clippy::transmute_ptr_to_ref)] - Some(mem::transmute(last)) + Some(last) } } } } + impl ExactSizeIterator + for RawIter<$Ptr, T, R, C, RStride, CStride> + { + #[inline] + fn len(&self) -> usize { + self.size + } + } + + impl FusedIterator + for RawIter<$Ptr, T, R, C, RStride, CStride> + { + } + + /// An iterator through a dense matrix with arbitrary strides matrix. + #[derive($($derives),*)] + pub struct $Name<'a, T, R: Dim, C: Dim, S: 'a + $Storage> { + inner: RawIter<$Ptr, T, R, C, S::RStride, S::CStride>, + _marker: PhantomData<$Ref>, + } + + impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage> $Name<'a, T, R, C, S> { + /// Creates a new iterator for the given matrix storage. + pub fn new(storage: $SRef) -> Self { + Self { + inner: RawIter::<$Ptr, T, R, C, S::RStride, S::CStride>::new(storage), + _marker: PhantomData, + } + } + } + + impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage> Iterator for $Name<'a, T, R, C, S> { + type Item = $Ref; + + #[inline(always)] + fn next(&mut self) -> Option { + // We want either `& *last` or `&mut *last` here, depending + // on the mutability of `$Ref`. + #[allow(clippy::transmute_ptr_to_ref)] + self.inner.next().map(|ptr| unsafe { mem::transmute(ptr) }) + } + + #[inline(always)] + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } + + #[inline(always)] + fn count(self) -> usize { + self.inner.count() + } + } + + impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage> DoubleEndedIterator + for $Name<'a, T, R, C, S> + { + #[inline(always)] + fn next_back(&mut self) -> Option { + // We want either `& *last` or `&mut *last` here, depending + // on the mutability of `$Ref`. + #[allow(clippy::transmute_ptr_to_ref)] + self.inner + .next_back() + .map(|ptr| unsafe { mem::transmute(ptr) }) + } + } + impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage> ExactSizeIterator for $Name<'a, T, R, C, S> { - #[inline] + #[inline(always)] fn len(&self) -> usize { - self.size + self.inner.len() } } @@ -180,6 +246,30 @@ macro_rules! iterator { iterator!(struct MatrixIter for RawStorage.ptr -> *const T, &'a T, &'a S, Clone, Debug); iterator!(struct MatrixIterMut for RawStorageMut.ptr_mut -> *mut T, &'a mut T, &'a mut S, Debug); +impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> + MatrixIter<'a, T, R, C, ViewStorage<'a, T, R, C, RStride, CStride>> +{ + /// Creates a new iterator for the given matrix storage view. + pub fn new_owned(storage: ViewStorage<'a, T, R, C, RStride, CStride>) -> Self { + Self { + inner: RawIter::<*const T, T, R, C, RStride, CStride>::new(&storage), + _marker: PhantomData, + } + } +} + +impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> + MatrixIterMut<'a, T, R, C, ViewStorageMut<'a, T, R, C, RStride, CStride>> +{ + /// Creates a new iterator for the given matrix storage view. + pub fn new_owned_mut(mut storage: ViewStorageMut<'a, T, R, C, RStride, CStride>) -> Self { + Self { + inner: RawIter::<*mut T, T, R, C, RStride, CStride>::new(&mut storage), + _marker: PhantomData, + } + } +} + /* * * Row iterators. diff --git a/tests/core/matrix.rs b/tests/core/matrix.rs index 27926a27..501a0566 100644 --- a/tests/core/matrix.rs +++ b/tests/core/matrix.rs @@ -1,80 +1,126 @@ +use na::iter::MatrixIter; use num::{One, Zero}; use std::cmp::Ordering; use na::dimension::{U15, U8}; use na::{ self, Const, DMatrix, DVector, Matrix2, Matrix2x3, Matrix2x4, Matrix3, Matrix3x2, Matrix3x4, - Matrix4, Matrix4x3, Matrix4x5, Matrix5, Matrix6, OMatrix, RowVector3, RowVector4, RowVector5, - Vector1, Vector2, Vector3, Vector4, Vector5, Vector6, + Matrix4, Matrix4x3, Matrix4x5, Matrix5, Matrix6, MatrixView2x3, MatrixViewMut2x3, OMatrix, + RowVector3, RowVector4, RowVector5, Vector1, Vector2, Vector3, Vector4, Vector5, Vector6, }; #[test] fn iter() { let a = Matrix2x3::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0); + let view: MatrixView2x3<_> = (&a).into(); - let mut it = a.iter(); - assert_eq!(*it.next().unwrap(), 1.0); - assert_eq!(*it.next().unwrap(), 4.0); - assert_eq!(*it.next().unwrap(), 2.0); - assert_eq!(*it.next().unwrap(), 5.0); - assert_eq!(*it.next().unwrap(), 3.0); - assert_eq!(*it.next().unwrap(), 6.0); - assert!(it.next().is_none()); + fn test<'a, F: Fn() -> I, I: Iterator + DoubleEndedIterator>(it: F) { + { + let mut it = it(); + assert_eq!(*it.next().unwrap(), 1.0); + assert_eq!(*it.next().unwrap(), 4.0); + assert_eq!(*it.next().unwrap(), 2.0); + assert_eq!(*it.next().unwrap(), 5.0); + assert_eq!(*it.next().unwrap(), 3.0); + assert_eq!(*it.next().unwrap(), 6.0); + assert!(it.next().is_none()); + } - let mut it = a.iter(); - assert_eq!(*it.next().unwrap(), 1.0); - assert_eq!(*it.next_back().unwrap(), 6.0); - assert_eq!(*it.next_back().unwrap(), 3.0); - assert_eq!(*it.next_back().unwrap(), 5.0); - assert_eq!(*it.next().unwrap(), 4.0); - assert_eq!(*it.next().unwrap(), 2.0); - assert!(it.next().is_none()); + { + let mut it = it(); + assert_eq!(*it.next().unwrap(), 1.0); + assert_eq!(*it.next_back().unwrap(), 6.0); + assert_eq!(*it.next_back().unwrap(), 3.0); + assert_eq!(*it.next_back().unwrap(), 5.0); + assert_eq!(*it.next().unwrap(), 4.0); + assert_eq!(*it.next().unwrap(), 2.0); + assert!(it.next().is_none()); + } + { + let mut it = it().rev(); + assert_eq!(*it.next().unwrap(), 6.0); + assert_eq!(*it.next().unwrap(), 3.0); + assert_eq!(*it.next().unwrap(), 5.0); + assert_eq!(*it.next().unwrap(), 2.0); + assert_eq!(*it.next().unwrap(), 4.0); + assert_eq!(*it.next().unwrap(), 1.0); + assert!(it.next().is_none()); + } + } - let mut it = a.iter().rev(); - assert_eq!(*it.next().unwrap(), 6.0); - assert_eq!(*it.next().unwrap(), 3.0); - assert_eq!(*it.next().unwrap(), 5.0); - assert_eq!(*it.next().unwrap(), 2.0); - assert_eq!(*it.next().unwrap(), 4.0); - assert_eq!(*it.next().unwrap(), 1.0); - assert!(it.next().is_none()); + test(|| a.iter()); + test(|| view.into_iter()); let row = a.row(0); - let mut it = row.iter(); - assert_eq!(*it.next().unwrap(), 1.0); - assert_eq!(*it.next().unwrap(), 2.0); - assert_eq!(*it.next().unwrap(), 3.0); - assert!(it.next().is_none()); + let row_test = |mut it: MatrixIter<_, _, _, _>| { + assert_eq!(*it.next().unwrap(), 1.0); + assert_eq!(*it.next().unwrap(), 2.0); + assert_eq!(*it.next().unwrap(), 3.0); + assert!(it.next().is_none()); + }; + row_test(row.iter()); + row_test(row.into_iter()); let row = a.row(1); - let mut it = row.iter(); - assert_eq!(*it.next().unwrap(), 4.0); - assert_eq!(*it.next().unwrap(), 5.0); - assert_eq!(*it.next().unwrap(), 6.0); - assert!(it.next().is_none()); + let row_test = |mut it: MatrixIter<_, _, _, _>| { + assert_eq!(*it.next().unwrap(), 4.0); + assert_eq!(*it.next().unwrap(), 5.0); + assert_eq!(*it.next().unwrap(), 6.0); + assert!(it.next().is_none()); + }; + row_test(row.iter()); + row_test(row.into_iter()); let m22 = row.column(1); - let mut it = m22.iter(); - assert_eq!(*it.next().unwrap(), 5.0); - assert!(it.next().is_none()); + let m22_test = |mut it: MatrixIter<_, _, _, _>| { + assert_eq!(*it.next().unwrap(), 5.0); + assert!(it.next().is_none()); + }; + m22_test(m22.iter()); + m22_test(m22.into_iter()); let col = a.column(0); - let mut it = col.iter(); - assert_eq!(*it.next().unwrap(), 1.0); - assert_eq!(*it.next().unwrap(), 4.0); - assert!(it.next().is_none()); + let col_test = |mut it: MatrixIter<_, _, _, _>| { + assert_eq!(*it.next().unwrap(), 1.0); + assert_eq!(*it.next().unwrap(), 4.0); + assert!(it.next().is_none()); + }; + col_test(col.iter()); + col_test(col.into_iter()); let col = a.column(1); - let mut it = col.iter(); - assert_eq!(*it.next().unwrap(), 2.0); - assert_eq!(*it.next().unwrap(), 5.0); - assert!(it.next().is_none()); + let col_test = |mut it: MatrixIter<_, _, _, _>| { + assert_eq!(*it.next().unwrap(), 2.0); + assert_eq!(*it.next().unwrap(), 5.0); + assert!(it.next().is_none()); + }; + col_test(col.iter()); + col_test(col.into_iter()); let col = a.column(2); - let mut it = col.iter(); - assert_eq!(*it.next().unwrap(), 3.0); - assert_eq!(*it.next().unwrap(), 6.0); - assert!(it.next().is_none()); + let col_test = |mut it: MatrixIter<_, _, _, _>| { + assert_eq!(*it.next().unwrap(), 3.0); + assert_eq!(*it.next().unwrap(), 6.0); + assert!(it.next().is_none()); + }; + col_test(col.iter()); + col_test(col.into_iter()); +} + +#[test] +fn iter_mut() { + let mut a = Matrix2x3::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0); + + for v in a.iter_mut() { + *v *= 2.0; + } + assert_eq!(a, Matrix2x3::new(2.0, 4.0, 6.0, 8.0, 10.0, 12.0)); + + let view: MatrixViewMut2x3<_> = MatrixViewMut2x3::from(&mut a); + for v in view.into_iter() { + *v *= 2.0; + } + assert_eq!(a, Matrix2x3::new(4.0, 8.0, 12.0, 16.0, 20.0, 24.0)); } #[test]