cleanups and add tests

This commit is contained in:
geo-ant 2022-10-20 19:07:22 +02:00 committed by Sébastien Crozet
parent f850ed535e
commit 7ac536be07
3 changed files with 56 additions and 58 deletions

View File

@ -475,23 +475,26 @@ impl<'a, T: Scalar, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> DoubleEndedI
} }
} }
impl<'a, T: Scalar, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> Producer for ColumnIterMut<'a,T,R,C,S> impl<'a, T: Scalar, R: Dim, C: Dim, S: 'a + RawStorageMut<T, R, C>> Producer
where T : Send + Sync + Debug + PartialEq + Clone, for ColumnIterMut<'a, T, R, C, S>
S: Send + Sync { where
T: Send + Sync + Debug + PartialEq + Clone,
S: Send + Sync,
{
type Item = MatrixSliceMut<'a, T, R, U1, S::RStride, S::CStride>; type Item = MatrixSliceMut<'a, T, R, U1, S::RStride, S::CStride>;
type IntoIter = ColumnIterMut<'a,T,R,C,S>; type IntoIter = ColumnIterMut<'a, T, R, C, S>;
fn into_iter(self) -> Self::IntoIter { fn into_iter(self) -> Self::IntoIter {
self self
} }
fn split_at(self, index: usize) -> (Self, Self) { fn split_at(self, index: usize) -> (Self, Self) {
// the index is relative to the size of this current iterator // the index is relative to the size of this current iterator
// it will always start at zero // it will always start at zero
let pmat : * mut _ = self.mat; let pmat: *mut _ = self.mat;
let left = Self { let left = Self {
mat: unsafe {&mut *pmat}, mat: unsafe { &mut *pmat },
range: self.range.start..(self.range.start + index), range: self.range.start..(self.range.start + index),
}; };
@ -502,9 +505,3 @@ where T : Send + Sync + Debug + PartialEq + Clone,
(left, right) (left, right)
} }
} }
fn test_send<T: Send>(_: T) {}
fn something(mut matrix: DMatrix<f32>) {
test_send(matrix.column_iter_mut());
}

View File

@ -3,24 +3,18 @@
use core::{ use core::{
fmt::Debug, fmt::Debug,
iter::{Skip, Take},
marker::PhantomData,
ops::Range,
}; };
use std::os::unix::prelude::AsRawFd;
use rayon::{ use rayon::{
iter::plumbing::{bridge, Producer}, iter::plumbing::{bridge},
prelude::*, prelude::*,
}; };
use crate::{ use crate::{
iter::{ColumnIter, ColumnIterMut}, Const, DMatrix, Dim, Dynamic, Matrix, MatrixSlice, MatrixSliceMut, iter::{ColumnIter, ColumnIterMut}, DMatrix, Dim, Matrix, MatrixSlice, MatrixSliceMut,
RawStorage, RawStorageMut, U1, SliceStorageMut, RawStorage, RawStorageMut, U1,
}; };
use super::conversion;
/// a rayon parallel iterator over the columns of a matrix /// a rayon parallel iterator over the columns of a matrix
pub struct ParColumnIter<'a, T, R: Dim, Cols: Dim, S: RawStorage<T, R, Cols>> { pub struct ParColumnIter<'a, T, R: Dim, Cols: Dim, S: RawStorage<T, R, Cols>> {
mat: &'a Matrix<T, R, Cols, S>, mat: &'a Matrix<T, R, Cols, S>,
@ -146,20 +140,3 @@ where
ParColumnIterMut::new(self) ParColumnIterMut::new(self)
} }
} }
#[test]
fn test_mut_parallel_iter() {
let mut matrix = DMatrix::<f32>::zeros(4, 3);
matrix.par_column_iter_mut().enumerate().for_each(|(idx,mut col)| col[idx]=1f32);
let identity = DMatrix::<f32>::identity(4, 3);
assert_eq!(matrix,identity);
}
fn try_some_stuff() {
let mut mat = DMatrix::<f32>::zeros(3, 4);
let _left = mat.columns_mut(0, 1);
let _right = mat.columns_mut(1, 3);
}

View File

@ -1145,19 +1145,41 @@ fn column_iteration() {
23,24,25; 23,24,25;
33,34,35; 33,34,35;
]; ];
// not using enumerate on purpose let mut col_iter = dmat.column_iter();
let mut idx = 0; assert_eq!(col_iter.next(),Some(dmat.column(0)));
for col in dmat.column_iter() { assert_eq!(col_iter.next(),Some(dmat.column(1)));
assert_eq!(dmat.column(idx),col); assert_eq!(col_iter.next(),Some(dmat.column(2)));
idx += 1; assert_eq!(col_iter.next(),None);
}
// statically sized matrix // statically sized matrix
let smat: nalgebra::SMatrix<f64, 2, 2> = nalgebra::matrix![1.0, 2.0; 3.0, 4.0]; let smat: nalgebra::SMatrix<f64, 2, 2> = nalgebra::matrix![1.0, 2.0; 3.0, 4.0];
let mut idx = 0; let mut col_iter = smat.column_iter();
for col in smat.column_iter() { assert_eq!(col_iter.next(),Some(smat.column(0)));
assert_eq!(smat.column(idx),col); assert_eq!(col_iter.next(),Some(smat.column(1)));
idx += 1; assert_eq!(col_iter.next(),None);
} }
#[test]
fn column_iteration_mut() {
let mut dmat = nalgebra::dmatrix![
13,14,15;
23,24,25;
33,34,35;
];
let mut cloned = dmat.clone();
let mut col_iter = dmat.column_iter_mut();
assert_eq!(col_iter.next(),Some(cloned.column_mut(0)));
assert_eq!(col_iter.next(),Some(cloned.column_mut(1)));
assert_eq!(col_iter.next(),Some(cloned.column_mut(2)));
assert_eq!(col_iter.next(),None);
// statically sized matrix
let mut smat: nalgebra::SMatrix<f64, 2, 2> = nalgebra::matrix![1.0, 2.0; 3.0, 4.0];
let mut cloned = smat.clone();
let mut col_iter = smat.column_iter_mut();
assert_eq!(col_iter.next(),Some(cloned.column_mut(0)));
assert_eq!(col_iter.next(),Some(cloned.column_mut(1)));
assert_eq!(col_iter.next(),None);
} }
#[test] #[test]
@ -1181,7 +1203,7 @@ fn column_iteration_double_ended() {
fn parallel_column_iteration() { fn parallel_column_iteration() {
use rayon::prelude::*; use rayon::prelude::*;
use nalgebra::{dmatrix,dvector}; use nalgebra::{dmatrix,dvector};
let dmat = dmatrix![ let dmat : DMatrix<f64> = dmatrix![
13.,14.; 13.,14.;
23.,24.; 23.,24.;
33.,34.; 33.,34.;
@ -1193,15 +1215,11 @@ fn parallel_column_iteration() {
}); });
// test that a more complex expression produces the same // test that a more complex expression produces the same
// result as the serial equivalent // result as the serial equivalent
let par_result :f64 = dmat.par_column_iter().map(|col| col.norm()).sum(); let par_result : f64 = dmat.par_column_iter().map(|col| col.norm()).sum();
let ser_result = dmat.column_iter().map(|col| col.norm()).sum(); let ser_result : f64= dmat.column_iter().map(|col| col.norm()).sum();
assert_eq!(par_result,ser_result); assert_eq!(par_result,ser_result);
} }
#[test]
fn column_iteration_mut() {
todo!();
}
#[test] #[test]
fn colum_iteration_mut_double_ended() { fn colum_iteration_mut_double_ended() {
@ -1223,5 +1241,11 @@ fn colum_iteration_mut_double_ended() {
#[test] #[test]
fn parallel_column_iteration_mut() { fn parallel_column_iteration_mut() {
todo!() use rayon::prelude::*;
let mut first = DMatrix::<f32>::zeros(400,300);
let mut second = DMatrix::<f32>::zeros(400,300);
first.column_iter_mut().enumerate().for_each(|(idx,mut col)|col[idx]=1.);
second.par_column_iter_mut().enumerate().for_each(|(idx,mut col)| col[idx]=1.);
assert_eq!(first,second);
assert_eq!(second,DMatrix::identity(400,300));
} }