Merge pull request #1315 from yotamofek/owned-view-iter

Allow creating matrix iter with an owned view
This commit is contained in:
Sébastien Crozet 2023-11-12 23:29:41 +01:00 committed by GitHub
commit a91e3b0d89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 242 additions and 82 deletions

View File

@ -98,6 +98,18 @@ impl<'a, T: Scalar, R: Dim, C: Dim, S: RawStorage<T, R, C>> IntoIterator
}
}
impl<'a, T: Scalar, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoIterator
for Matrix<T, R, C, ViewStorage<'a, T, R, C, RStride, CStride>>
{
type Item = &'a T;
type IntoIter = MatrixIter<'a, T, R, C, ViewStorage<'a, T, R, C, RStride, CStride>>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
MatrixIter::new_owned(self.data)
}
}
impl<'a, T: Scalar, R: Dim, C: Dim, S: RawStorageMut<T, R, C>> IntoIterator
for &'a mut Matrix<T, R, C, S>
{
@ -110,6 +122,18 @@ impl<'a, T: Scalar, R: Dim, C: Dim, S: RawStorageMut<T, R, C>> IntoIterator
}
}
impl<'a, T: Scalar, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoIterator
for Matrix<T, R, C, ViewStorageMut<'a, T, R, C, RStride, CStride>>
{
type Item = &'a mut T;
type IntoIter = MatrixIterMut<'a, T, R, C, ViewStorageMut<'a, T, R, C, RStride, CStride>>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
MatrixIterMut::new_owned_mut(self.data)
}
}
impl<T: Scalar, const D: usize> From<[T; D]> for SVector<T, D> {
#[inline]
fn from(arr: [T; D]) -> Self {

View File

@ -12,26 +12,29 @@ use std::mem;
use crate::base::dimension::{Dim, U1};
use crate::base::storage::{RawStorage, RawStorageMut};
use crate::base::{Matrix, MatrixView, MatrixViewMut, Scalar};
use crate::base::{Matrix, MatrixView, MatrixViewMut, Scalar, ViewStorage, ViewStorageMut};
#[derive(Clone, Debug)]
struct RawIter<Ptr, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> {
ptr: Ptr,
inner_ptr: Ptr,
inner_end: Ptr,
size: usize,
strides: (RStride, CStride),
_phantoms: PhantomData<(fn() -> T, R, C)>,
}
macro_rules! iterator {
(struct $Name:ident for $Storage:ident.$ptr: ident -> $Ptr:ty, $Ref:ty, $SRef: ty, $($derives:ident),* $(,)?) => {
/// An iterator through a dense matrix with arbitrary strides matrix.
#[derive($($derives),*)]
pub struct $Name<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> {
ptr: $Ptr,
inner_ptr: $Ptr,
inner_end: $Ptr,
size: usize, // We can't use an end pointer here because a stride might be zero.
strides: (S::RStride, S::CStride),
_phantoms: PhantomData<($Ref, R, C, S)>,
}
// TODO: we need to specialize for the case where the matrix storage is owned (in which
// case the iterator is trivial because it does not have any stride).
impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> $Name<'a, T, R, C, S> {
impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
RawIter<$Ptr, T, R, C, RStride, CStride>
{
/// Creates a new iterator for the given matrix storage.
pub fn new(storage: $SRef) -> $Name<'a, T, R, C, S> {
fn new<'a, S: $Storage<T, R, C, RStride = RStride, CStride = CStride>>(
storage: $SRef,
) -> Self {
let shape = storage.shape();
let strides = storage.strides();
let inner_offset = shape.0.value() * strides.0.value();
@ -55,7 +58,7 @@ macro_rules! iterator {
unsafe { ptr.add(inner_offset) }
};
$Name {
RawIter {
ptr,
inner_ptr: ptr,
inner_end,
@ -66,11 +69,13 @@ macro_rules! iterator {
}
}
impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> Iterator for $Name<'a, T, R, C, S> {
type Item = $Ref;
impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> Iterator
for RawIter<$Ptr, T, R, C, RStride, CStride>
{
type Item = $Ptr;
#[inline]
fn next(&mut self) -> Option<$Ref> {
fn next(&mut self) -> Option<Self::Item> {
unsafe {
if self.size == 0 {
None
@ -102,10 +107,7 @@ macro_rules! iterator {
self.ptr = self.ptr.add(stride);
}
// We want either `& *last` or `&mut *last` here, depending
// on the mutability of `$Ref`.
#[allow(clippy::transmute_ptr_to_ref)]
Some(mem::transmute(old))
Some(old)
}
}
}
@ -121,11 +123,11 @@ macro_rules! iterator {
}
}
impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> DoubleEndedIterator
for $Name<'a, T, R, C, S>
impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> DoubleEndedIterator
for RawIter<$Ptr, T, R, C, RStride, CStride>
{
#[inline]
fn next_back(&mut self) -> Option<$Ref> {
fn next_back(&mut self) -> Option<Self::Item> {
unsafe {
if self.size == 0 {
None
@ -152,21 +154,85 @@ macro_rules! iterator {
.ptr
.add((outer_remaining * outer_stride + inner_remaining * inner_stride));
// We want either `& *last` or `&mut *last` here, depending
// on the mutability of `$Ref`.
#[allow(clippy::transmute_ptr_to_ref)]
Some(mem::transmute(last))
Some(last)
}
}
}
}
impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> ExactSizeIterator
for RawIter<$Ptr, T, R, C, RStride, CStride>
{
#[inline]
fn len(&self) -> usize {
self.size
}
}
impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> FusedIterator
for RawIter<$Ptr, T, R, C, RStride, CStride>
{
}
/// An iterator through a dense matrix with arbitrary strides matrix.
#[derive($($derives),*)]
pub struct $Name<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> {
inner: RawIter<$Ptr, T, R, C, S::RStride, S::CStride>,
_marker: PhantomData<$Ref>,
}
impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> $Name<'a, T, R, C, S> {
/// Creates a new iterator for the given matrix storage.
pub fn new(storage: $SRef) -> Self {
Self {
inner: RawIter::<$Ptr, T, R, C, S::RStride, S::CStride>::new(storage),
_marker: PhantomData,
}
}
}
impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> Iterator for $Name<'a, T, R, C, S> {
type Item = $Ref;
#[inline(always)]
fn next(&mut self) -> Option<Self::Item> {
// We want either `& *last` or `&mut *last` here, depending
// on the mutability of `$Ref`.
#[allow(clippy::transmute_ptr_to_ref)]
self.inner.next().map(|ptr| unsafe { mem::transmute(ptr) })
}
#[inline(always)]
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
#[inline(always)]
fn count(self) -> usize {
self.inner.count()
}
}
impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> DoubleEndedIterator
for $Name<'a, T, R, C, S>
{
#[inline(always)]
fn next_back(&mut self) -> Option<Self::Item> {
// We want either `& *last` or `&mut *last` here, depending
// on the mutability of `$Ref`.
#[allow(clippy::transmute_ptr_to_ref)]
self.inner
.next_back()
.map(|ptr| unsafe { mem::transmute(ptr) })
}
}
impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> ExactSizeIterator
for $Name<'a, T, R, C, S>
{
#[inline]
#[inline(always)]
fn len(&self) -> usize {
self.size
self.inner.len()
}
}
@ -180,6 +246,30 @@ macro_rules! iterator {
iterator!(struct MatrixIter for RawStorage.ptr -> *const T, &'a T, &'a S, Clone, Debug);
iterator!(struct MatrixIterMut for RawStorageMut.ptr_mut -> *mut T, &'a mut T, &'a mut S, Debug);
impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
MatrixIter<'a, T, R, C, ViewStorage<'a, T, R, C, RStride, CStride>>
{
/// Creates a new iterator for the given matrix storage view.
pub fn new_owned(storage: ViewStorage<'a, T, R, C, RStride, CStride>) -> Self {
Self {
inner: RawIter::<*const T, T, R, C, RStride, CStride>::new(&storage),
_marker: PhantomData,
}
}
}
impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
MatrixIterMut<'a, T, R, C, ViewStorageMut<'a, T, R, C, RStride, CStride>>
{
/// Creates a new iterator for the given matrix storage view.
pub fn new_owned_mut(mut storage: ViewStorageMut<'a, T, R, C, RStride, CStride>) -> Self {
Self {
inner: RawIter::<*mut T, T, R, C, RStride, CStride>::new(&mut storage),
_marker: PhantomData,
}
}
}
/*
*
* Row iterators.

View File

@ -1,80 +1,126 @@
use na::iter::MatrixIter;
use num::{One, Zero};
use std::cmp::Ordering;
use na::dimension::{U15, U8};
use na::{
self, Const, DMatrix, DVector, Matrix2, Matrix2x3, Matrix2x4, Matrix3, Matrix3x2, Matrix3x4,
Matrix4, Matrix4x3, Matrix4x5, Matrix5, Matrix6, OMatrix, RowVector3, RowVector4, RowVector5,
Vector1, Vector2, Vector3, Vector4, Vector5, Vector6,
Matrix4, Matrix4x3, Matrix4x5, Matrix5, Matrix6, MatrixView2x3, MatrixViewMut2x3, OMatrix,
RowVector3, RowVector4, RowVector5, Vector1, Vector2, Vector3, Vector4, Vector5, Vector6,
};
#[test]
fn iter() {
let a = Matrix2x3::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
let view: MatrixView2x3<_> = (&a).into();
let mut it = a.iter();
assert_eq!(*it.next().unwrap(), 1.0);
assert_eq!(*it.next().unwrap(), 4.0);
assert_eq!(*it.next().unwrap(), 2.0);
assert_eq!(*it.next().unwrap(), 5.0);
assert_eq!(*it.next().unwrap(), 3.0);
assert_eq!(*it.next().unwrap(), 6.0);
assert!(it.next().is_none());
fn test<'a, F: Fn() -> I, I: Iterator<Item = &'a f64> + DoubleEndedIterator>(it: F) {
{
let mut it = it();
assert_eq!(*it.next().unwrap(), 1.0);
assert_eq!(*it.next().unwrap(), 4.0);
assert_eq!(*it.next().unwrap(), 2.0);
assert_eq!(*it.next().unwrap(), 5.0);
assert_eq!(*it.next().unwrap(), 3.0);
assert_eq!(*it.next().unwrap(), 6.0);
assert!(it.next().is_none());
}
let mut it = a.iter();
assert_eq!(*it.next().unwrap(), 1.0);
assert_eq!(*it.next_back().unwrap(), 6.0);
assert_eq!(*it.next_back().unwrap(), 3.0);
assert_eq!(*it.next_back().unwrap(), 5.0);
assert_eq!(*it.next().unwrap(), 4.0);
assert_eq!(*it.next().unwrap(), 2.0);
assert!(it.next().is_none());
{
let mut it = it();
assert_eq!(*it.next().unwrap(), 1.0);
assert_eq!(*it.next_back().unwrap(), 6.0);
assert_eq!(*it.next_back().unwrap(), 3.0);
assert_eq!(*it.next_back().unwrap(), 5.0);
assert_eq!(*it.next().unwrap(), 4.0);
assert_eq!(*it.next().unwrap(), 2.0);
assert!(it.next().is_none());
}
{
let mut it = it().rev();
assert_eq!(*it.next().unwrap(), 6.0);
assert_eq!(*it.next().unwrap(), 3.0);
assert_eq!(*it.next().unwrap(), 5.0);
assert_eq!(*it.next().unwrap(), 2.0);
assert_eq!(*it.next().unwrap(), 4.0);
assert_eq!(*it.next().unwrap(), 1.0);
assert!(it.next().is_none());
}
}
let mut it = a.iter().rev();
assert_eq!(*it.next().unwrap(), 6.0);
assert_eq!(*it.next().unwrap(), 3.0);
assert_eq!(*it.next().unwrap(), 5.0);
assert_eq!(*it.next().unwrap(), 2.0);
assert_eq!(*it.next().unwrap(), 4.0);
assert_eq!(*it.next().unwrap(), 1.0);
assert!(it.next().is_none());
test(|| a.iter());
test(|| view.into_iter());
let row = a.row(0);
let mut it = row.iter();
assert_eq!(*it.next().unwrap(), 1.0);
assert_eq!(*it.next().unwrap(), 2.0);
assert_eq!(*it.next().unwrap(), 3.0);
assert!(it.next().is_none());
let row_test = |mut it: MatrixIter<_, _, _, _>| {
assert_eq!(*it.next().unwrap(), 1.0);
assert_eq!(*it.next().unwrap(), 2.0);
assert_eq!(*it.next().unwrap(), 3.0);
assert!(it.next().is_none());
};
row_test(row.iter());
row_test(row.into_iter());
let row = a.row(1);
let mut it = row.iter();
assert_eq!(*it.next().unwrap(), 4.0);
assert_eq!(*it.next().unwrap(), 5.0);
assert_eq!(*it.next().unwrap(), 6.0);
assert!(it.next().is_none());
let row_test = |mut it: MatrixIter<_, _, _, _>| {
assert_eq!(*it.next().unwrap(), 4.0);
assert_eq!(*it.next().unwrap(), 5.0);
assert_eq!(*it.next().unwrap(), 6.0);
assert!(it.next().is_none());
};
row_test(row.iter());
row_test(row.into_iter());
let m22 = row.column(1);
let mut it = m22.iter();
assert_eq!(*it.next().unwrap(), 5.0);
assert!(it.next().is_none());
let m22_test = |mut it: MatrixIter<_, _, _, _>| {
assert_eq!(*it.next().unwrap(), 5.0);
assert!(it.next().is_none());
};
m22_test(m22.iter());
m22_test(m22.into_iter());
let col = a.column(0);
let mut it = col.iter();
assert_eq!(*it.next().unwrap(), 1.0);
assert_eq!(*it.next().unwrap(), 4.0);
assert!(it.next().is_none());
let col_test = |mut it: MatrixIter<_, _, _, _>| {
assert_eq!(*it.next().unwrap(), 1.0);
assert_eq!(*it.next().unwrap(), 4.0);
assert!(it.next().is_none());
};
col_test(col.iter());
col_test(col.into_iter());
let col = a.column(1);
let mut it = col.iter();
assert_eq!(*it.next().unwrap(), 2.0);
assert_eq!(*it.next().unwrap(), 5.0);
assert!(it.next().is_none());
let col_test = |mut it: MatrixIter<_, _, _, _>| {
assert_eq!(*it.next().unwrap(), 2.0);
assert_eq!(*it.next().unwrap(), 5.0);
assert!(it.next().is_none());
};
col_test(col.iter());
col_test(col.into_iter());
let col = a.column(2);
let mut it = col.iter();
assert_eq!(*it.next().unwrap(), 3.0);
assert_eq!(*it.next().unwrap(), 6.0);
assert!(it.next().is_none());
let col_test = |mut it: MatrixIter<_, _, _, _>| {
assert_eq!(*it.next().unwrap(), 3.0);
assert_eq!(*it.next().unwrap(), 6.0);
assert!(it.next().is_none());
};
col_test(col.iter());
col_test(col.into_iter());
}
#[test]
fn iter_mut() {
let mut a = Matrix2x3::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
for v in a.iter_mut() {
*v *= 2.0;
}
assert_eq!(a, Matrix2x3::new(2.0, 4.0, 6.0, 8.0, 10.0, 12.0));
let view: MatrixViewMut2x3<_> = MatrixViewMut2x3::from(&mut a);
for v in view.into_iter() {
*v *= 2.0;
}
assert_eq!(a, Matrix2x3::new(4.0, 8.0, 12.0, 16.0, 20.0, 24.0));
}
#[test]