apply fmt

This commit is contained in:
geo-ant 2022-10-21 08:51:41 +02:00 committed by Sébastien Crozet
parent daade1cf5e
commit a4e28a136e
2 changed files with 114 additions and 92 deletions

View File

@ -1,15 +1,12 @@
//! this module implements parallelators to make matrices work with //! this module implements parallelators to make matrices work with
//! the rayon crate seamlessly //! the rayon crate seamlessly
use core::fmt::Debug;
use rayon::{
iter::plumbing::{bridge},
prelude::*,
};
use crate::{ use crate::{
iter::{ColumnIter, ColumnIterMut}, Dim, Matrix, MatrixSlice, MatrixSliceMut, iter::{ColumnIter, ColumnIterMut},
RawStorage, RawStorageMut, U1, Dim, Matrix, MatrixSlice, MatrixSliceMut, RawStorage, RawStorageMut, U1,
}; };
use core::fmt::Debug;
use rayon::{iter::plumbing::bridge, prelude::*};
/// A rayon parallel iterator over the colums of a matrix /// A rayon parallel iterator over the colums of a matrix
pub struct ParColumnIter<'a, T, R: Dim, Cols: Dim, S: RawStorage<T, R, Cols>> { pub struct ParColumnIter<'a, T, R: Dim, Cols: Dim, S: RawStorage<T, R, Cols>> {
@ -78,29 +75,42 @@ where
} }
/// A rayon parallel iterator through the mutable columns of a matrix /// A rayon parallel iterator through the mutable columns of a matrix
pub struct ParColumnIterMut<'a,T,R:Dim ,Cols:Dim, S:RawStorage<T,R,Cols>+RawStorageMut<T,R,Cols>> { pub struct ParColumnIterMut<
mat : &'a mut Matrix<T,R,Cols,S>, 'a,
T,
R: Dim,
Cols: Dim,
S: RawStorage<T, R, Cols> + RawStorageMut<T, R, Cols>,
> {
mat: &'a mut Matrix<T, R, Cols, S>,
} }
impl<'a,T,R,Cols,S> ParColumnIterMut<'a,T,R,Cols,S> impl<'a, T, R, Cols, S> ParColumnIterMut<'a, T, R, Cols, S>
where R: Dim, Cols : Dim, S:RawStorage<T,R,Cols> + RawStorageMut<T,R,Cols> { where
R: Dim,
Cols: Dim,
S: RawStorage<T, R, Cols> + RawStorageMut<T, R, Cols>,
{
/// create a new parallel iterator for the given matrix /// create a new parallel iterator for the given matrix
fn new(mat : &'a mut Matrix<T,R,Cols,S>) -> Self { fn new(mat: &'a mut Matrix<T, R, Cols, S>) -> Self {
Self { Self { mat }
mat,
}
} }
} }
impl<'a,T,R,Cols,S> ParallelIterator for ParColumnIterMut<'a,T,R,Cols,S> impl<'a, T, R, Cols, S> ParallelIterator for ParColumnIterMut<'a, T, R, Cols, S>
where R: Dim, Cols : Dim, S:RawStorage<T,R,Cols> + RawStorageMut<T,R,Cols>, where
T : Send + Sync + Debug + PartialEq + Clone + 'static, R: Dim,
S : Send + Sync { Cols: Dim,
S: RawStorage<T, R, Cols> + RawStorageMut<T, R, Cols>,
T: Send + Sync + Debug + PartialEq + Clone + 'static,
S: Send + Sync,
{
type Item = MatrixSliceMut<'a, T, R, U1, S::RStride, S::CStride>; type Item = MatrixSliceMut<'a, T, R, U1, S::RStride, S::CStride>;
fn drive_unindexed<C>(self, consumer: C) -> C::Result fn drive_unindexed<C>(self, consumer: C) -> C::Result
where where
C: rayon::iter::plumbing::UnindexedConsumer<Self::Item> { C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
bridge(self,consumer) {
bridge(self, consumer)
} }
fn opt_len(&self) -> Option<usize> { fn opt_len(&self) -> Option<usize> {
@ -108,26 +118,33 @@ S : Send + Sync {
} }
} }
impl<'a, T, R, Cols, S> IndexedParallelIterator for ParColumnIterMut<'a, T, R, Cols, S>
impl<'a,T,R,Cols,S> IndexedParallelIterator for ParColumnIterMut<'a,T,R,Cols,S> where
where R: Dim, Cols : Dim, S:RawStorage<T,R,Cols> + RawStorageMut<T,R,Cols>, R: Dim,
T : Send + Sync + Debug + PartialEq + Clone + 'static, Cols: Dim,
S : Send + Sync { S: RawStorage<T, R, Cols> + RawStorageMut<T, R, Cols>,
T: Send + Sync + Debug + PartialEq + Clone + 'static,
S: Send + Sync,
{
fn drive<C: rayon::iter::plumbing::Consumer<Self::Item>>(self, consumer: C) -> C::Result { fn drive<C: rayon::iter::plumbing::Consumer<Self::Item>>(self, consumer: C) -> C::Result {
bridge(self,consumer) bridge(self, consumer)
} }
fn len(&self) -> usize { fn len(&self) -> usize {
self.mat.ncols() self.mat.ncols()
} }
fn with_producer<CB: rayon::iter::plumbing::ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output { fn with_producer<CB: rayon::iter::plumbing::ProducerCallback<Self::Item>>(
self,
callback: CB,
) -> CB::Output {
let producer = ColumnIterMut::new(self.mat); let producer = ColumnIterMut::new(self.mat);
callback.callback(producer) callback.callback(producer)
} }
} }
impl<'a, T, R: Dim, Cols: Dim, S: RawStorage<T, R, Cols> + RawStorageMut<T,R,Cols>> Matrix<T, R, Cols, S> impl<'a, T, R: Dim, Cols: Dim, S: RawStorage<T, R, Cols> + RawStorageMut<T, R, Cols>>
Matrix<T, R, Cols, S>
where where
T: Send + Sync + Clone + Debug + PartialEq + 'static, T: Send + Sync + Clone + Debug + PartialEq + 'static,
S: Sync, S: Sync,

View File

@ -1146,17 +1146,17 @@ fn column_iteration() {
33,34,35; 33,34,35;
]; ];
let mut col_iter = dmat.column_iter(); let mut col_iter = dmat.column_iter();
assert_eq!(col_iter.next(),Some(dmat.column(0))); assert_eq!(col_iter.next(), Some(dmat.column(0)));
assert_eq!(col_iter.next(),Some(dmat.column(1))); assert_eq!(col_iter.next(), Some(dmat.column(1)));
assert_eq!(col_iter.next(),Some(dmat.column(2))); assert_eq!(col_iter.next(), Some(dmat.column(2)));
assert_eq!(col_iter.next(),None); assert_eq!(col_iter.next(), None);
// statically sized matrix // statically sized matrix
let smat: nalgebra::SMatrix<f64, 2, 2> = nalgebra::matrix![1.0, 2.0; 3.0, 4.0]; let smat: nalgebra::SMatrix<f64, 2, 2> = nalgebra::matrix![1.0, 2.0; 3.0, 4.0];
let mut col_iter = smat.column_iter(); let mut col_iter = smat.column_iter();
assert_eq!(col_iter.next(),Some(smat.column(0))); assert_eq!(col_iter.next(), Some(smat.column(0)));
assert_eq!(col_iter.next(),Some(smat.column(1))); assert_eq!(col_iter.next(), Some(smat.column(1)));
assert_eq!(col_iter.next(),None); assert_eq!(col_iter.next(), None);
} }
#[test] #[test]
@ -1168,18 +1168,18 @@ fn column_iteration_mut() {
]; ];
let mut cloned = dmat.clone(); let mut cloned = dmat.clone();
let mut col_iter = dmat.column_iter_mut(); let mut col_iter = dmat.column_iter_mut();
assert_eq!(col_iter.next(),Some(cloned.column_mut(0))); assert_eq!(col_iter.next(), Some(cloned.column_mut(0)));
assert_eq!(col_iter.next(),Some(cloned.column_mut(1))); assert_eq!(col_iter.next(), Some(cloned.column_mut(1)));
assert_eq!(col_iter.next(),Some(cloned.column_mut(2))); assert_eq!(col_iter.next(), Some(cloned.column_mut(2)));
assert_eq!(col_iter.next(),None); assert_eq!(col_iter.next(), None);
// statically sized matrix // statically sized matrix
let mut smat: nalgebra::SMatrix<f64, 2, 2> = nalgebra::matrix![1.0, 2.0; 3.0, 4.0]; let mut smat: nalgebra::SMatrix<f64, 2, 2> = nalgebra::matrix![1.0, 2.0; 3.0, 4.0];
let mut cloned = smat.clone(); let mut cloned = smat.clone();
let mut col_iter = smat.column_iter_mut(); let mut col_iter = smat.column_iter_mut();
assert_eq!(col_iter.next(),Some(cloned.column_mut(0))); assert_eq!(col_iter.next(), Some(cloned.column_mut(0)));
assert_eq!(col_iter.next(),Some(cloned.column_mut(1))); assert_eq!(col_iter.next(), Some(cloned.column_mut(1)));
assert_eq!(col_iter.next(),None); assert_eq!(col_iter.next(), None);
} }
#[test] #[test]
@ -1190,37 +1190,36 @@ fn column_iteration_double_ended() {
33,34,35,36,37; 33,34,35,36,37;
]; ];
let mut col_iter = dmat.column_iter(); let mut col_iter = dmat.column_iter();
assert_eq!(col_iter.next(),Some(dmat.column(0))); assert_eq!(col_iter.next(), Some(dmat.column(0)));
assert_eq!(col_iter.next(),Some(dmat.column(1))); assert_eq!(col_iter.next(), Some(dmat.column(1)));
assert_eq!(col_iter.next_back(),Some(dmat.column(4))); assert_eq!(col_iter.next_back(), Some(dmat.column(4)));
assert_eq!(col_iter.next_back(),Some(dmat.column(3))); assert_eq!(col_iter.next_back(), Some(dmat.column(3)));
assert_eq!(col_iter.next(),Some(dmat.column(2))); assert_eq!(col_iter.next(), Some(dmat.column(2)));
assert_eq!(col_iter.next_back(),None); assert_eq!(col_iter.next_back(), None);
assert_eq!(col_iter.next(),None); assert_eq!(col_iter.next(), None);
} }
#[test] #[test]
fn parallel_column_iteration() { fn parallel_column_iteration() {
use nalgebra::dmatrix;
use rayon::prelude::*; use rayon::prelude::*;
use nalgebra::{dmatrix,dvector}; let dmat: DMatrix<f64> = dmatrix![
let dmat : DMatrix<f64> = dmatrix![
13.,14.; 13.,14.;
23.,24.; 23.,24.;
33.,34.; 33.,34.;
]; ];
let cloned = dmat.clone(); let cloned = dmat.clone();
// test that correct columns are iterated over // test that correct columns are iterated over
dmat.par_column_iter().enumerate().for_each(|(idx,col)| { dmat.par_column_iter().enumerate().for_each(|(idx, col)| {
assert_eq!(col,cloned.column(idx)); assert_eq!(col, cloned.column(idx));
}); });
// test that a more complex expression produces the same // test that a more complex expression produces the same
// result as the serial equivalent // result as the serial equivalent
let par_result : f64 = dmat.par_column_iter().map(|col| col.norm()).sum(); let par_result: f64 = dmat.par_column_iter().map(|col| col.norm()).sum();
let ser_result : f64= dmat.column_iter().map(|col| col.norm()).sum(); let ser_result: f64 = dmat.column_iter().map(|col| col.norm()).sum();
assert_eq!(par_result,ser_result); assert_eq!(par_result, ser_result);
} }
#[test] #[test]
fn colum_iteration_mut_double_ended() { fn colum_iteration_mut_double_ended() {
let dmat = nalgebra::dmatrix![ let dmat = nalgebra::dmatrix![
@ -1230,22 +1229,28 @@ fn colum_iteration_mut_double_ended() {
]; ];
let cloned = dmat.clone(); let cloned = dmat.clone();
let mut col_iter = dmat.column_iter(); let mut col_iter = dmat.column_iter();
assert_eq!(col_iter.next(),Some(cloned.column(0))); assert_eq!(col_iter.next(), Some(cloned.column(0)));
assert_eq!(col_iter.next(),Some(cloned.column(1))); assert_eq!(col_iter.next(), Some(cloned.column(1)));
assert_eq!(col_iter.next_back(),Some(cloned.column(4))); assert_eq!(col_iter.next_back(), Some(cloned.column(4)));
assert_eq!(col_iter.next_back(),Some(cloned.column(3))); assert_eq!(col_iter.next_back(), Some(cloned.column(3)));
assert_eq!(col_iter.next(),Some(cloned.column(2))); assert_eq!(col_iter.next(), Some(cloned.column(2)));
assert_eq!(col_iter.next_back(),None); assert_eq!(col_iter.next_back(), None);
assert_eq!(col_iter.next(),None); assert_eq!(col_iter.next(), None);
} }
#[test] #[test]
fn parallel_column_iteration_mut() { fn parallel_column_iteration_mut() {
use rayon::prelude::*; use rayon::prelude::*;
let mut first = DMatrix::<f32>::zeros(400,300); let mut first = DMatrix::<f32>::zeros(400, 300);
let mut second = DMatrix::<f32>::zeros(400,300); let mut second = DMatrix::<f32>::zeros(400, 300);
first.column_iter_mut().enumerate().for_each(|(idx,mut col)|col[idx]=1.); first
second.par_column_iter_mut().enumerate().for_each(|(idx,mut col)| col[idx]=1.); .column_iter_mut()
assert_eq!(first,second); .enumerate()
assert_eq!(second,DMatrix::identity(400,300)); .for_each(|(idx, mut col)| col[idx] = 1.);
second
.par_column_iter_mut()
.enumerate()
.for_each(|(idx, mut col)| col[idx] = 1.);
assert_eq!(first, second);
assert_eq!(second, DMatrix::identity(400, 300));
} }