From 317aef574a6af8207e7b248f7e566e2f6157ec8c Mon Sep 17 00:00:00 2001 From: Avi Weinstock Date: Mon, 30 Jan 2023 23:14:20 -0500 Subject: [PATCH] Implement `hstack` and `vstack`. The implementation uses a trait to fold over tuples, summing the dimensions in one direction and checking for equality in the other, and then uses `fixed_{rows,columns}_mut` if the dimensions are static, or `{rows,columns}_mut` if the dimensions are dynamic, together with `copy_from` to construct the output matrix. --- nalgebra-macros/tests/tests.rs | 43 +++++ src/base/mod.rs | 2 + src/base/stacking.rs | 304 +++++++++++++++++++++++++++++++++ 3 files changed, 349 insertions(+) create mode 100644 src/base/stacking.rs diff --git a/nalgebra-macros/tests/tests.rs b/nalgebra-macros/tests/tests.rs index 0e52da1f..f58c6733 100644 --- a/nalgebra-macros/tests/tests.rs +++ b/nalgebra-macros/tests/tests.rs @@ -305,3 +305,46 @@ fn dvector_arbitrary_expressions() { let a_expected = DVector::from_column_slice(&[1 + 2, 2 * 3, 4 * f(5 + 6), 7 - 8 * 9]); assert_eq_and_type!(a, a_expected); } + +#[test] +fn test_stacking() { + use nalgebra::{hstack, vstack, RowVector3, RowVector4}; + assert_eq_and_type!( + vstack((&RowVector3::new(1, 2, 3), &RowVector3::new(4, 5, 6))), + Matrix2x3::new(1, 2, 3, 4, 5, 6) + ); + assert_eq_and_type!( + vstack(( + &RowVector3::new(1, 2, 3), + &RowVector3::new(4, 5, 6), + &RowVector3::new(7, 8, 9) + )), + Matrix3::new(1, 2, 3, 4, 5, 6, 7, 8, 9) + ); + assert_eq_and_type!( + hstack((&Vector3::new(1, 2, 3), &Vector3::new(4, 5, 6))), + Matrix3x2::new(1, 4, 2, 5, 3, 6) + ); + assert_eq_and_type!( + hstack(( + &Vector3::new(1, 2, 3), + &Vector3::new(4, 5, 6), + &Vector3::new(7, 8, 9) + )), + Matrix3::new(1, 4, 7, 2, 5, 8, 3, 6, 9) + ); + assert_eq_and_type!( + vstack(( + &hstack((&DMatrix::identity(3, 3), &Vector3::new(2, 2, 2))), + &RowVector4::new(3, 3, 3, 3) + )), + matrix![1, 0, 0, 2; 0, 1, 0, 2; 0, 0, 1, 2; 3, 3, 3, 3] + ); + assert_eq_and_type!( + vstack(( + &hstack((&DMatrix::identity(3, 3), &dvector![2, 2, 2],)), + &dvector![3, 3, 3, 3].transpose(), + )), + dmatrix![1, 0, 0, 2; 0, 1, 0, 2; 0, 0, 1, 2; 3, 3, 3, 3] + ); +} diff --git a/src/base/mod.rs b/src/base/mod.rs index 0f09cc33..c2b17d17 100644 --- a/src/base/mod.rs +++ b/src/base/mod.rs @@ -8,6 +8,7 @@ pub mod default_allocator; pub mod dimension; pub mod iter; mod ops; +mod stacking; pub mod storage; mod alias; @@ -61,6 +62,7 @@ pub use self::alias_slice::*; pub use self::alias_view::*; pub use self::array_storage::*; pub use self::matrix_view::*; +pub use self::stacking::{hstack, vstack}; pub use self::storage::*; #[cfg(any(feature = "std", feature = "alloc"))] pub use self::vec_storage::*; diff --git a/src/base/stacking.rs b/src/base/stacking.rs new file mode 100644 index 00000000..3e249152 --- /dev/null +++ b/src/base/stacking.rs @@ -0,0 +1,304 @@ +///! Utilities for stacking matrices horizontally and vertically. +use crate::{ + base::allocator::Allocator, + constraint::{DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint}, + Const, DefaultAllocator, Dim, DimAdd, DimSum, Dyn, Matrix, RawStorage, RawStorageMut, Scalar, +}; +use num_traits::Zero; + +/// A visitor for each folding over each element of a tuple. +pub trait Visitor { + /// The output type of this step. + type Output; + /// Visits an element of a tuple. + fn visit(self, x: A) -> Self::Output; +} + +/// The driver for visiting each element of a tuple. +pub trait VisitTuple { + /// The output type of the fold. + type Output; + /// Visits each element of a tuple. + fn visit(visitor: F, x: Self) -> Self::Output; +} + +macro_rules! impl_visit_tuple { + ($($is:ident),*) => { + impl_visit_tuple!(__GENERATE_TAILS, [$($is),*]); + }; + (__GENERATE_TAILS, [$i:ident]) => { + impl_visit_tuple!(__GENERATE_CLAUSE, [$i]); + }; + (__GENERATE_TAILS, [$i:ident, $($is:ident),*]) => { + impl_visit_tuple!(__GENERATE_CLAUSE, [$i, $($is),*]); + impl_visit_tuple!(__GENERATE_TAILS, [$($is),*]); + }; + (__GENERATE_CLAUSE, [$i:ident]) => { + impl<$i, Func: Visitor<$i>> VisitTuple for ($i,) { + type Output = >::Output; + #[allow(non_snake_case)] + fn visit(visitor: Func, ($i,): Self) -> Self::Output { + visitor.visit($i) + } + } + }; + (__GENERATE_CLAUSE, [$i:ident, $($is:ident),*]) => { + impl<$i, $($is,)* Func: Visitor<$i>> VisitTuple for ($i, $($is),*) + where ($($is,)*): VisitTuple<>::Output> + { + type Output = <($($is,)*) as VisitTuple<>::Output>>::Output; + #[allow(non_snake_case)] + fn visit(visitor: Func, ($i, $($is),*): Self) -> Self::Output { + VisitTuple::visit(visitor.visit($i), ($($is,)*)) + } + } + }; +} + +impl_visit_tuple!(H, G, F, E, D, C, B, A); + +mod vstack_impl { + use super::*; + #[derive(Clone, Copy)] + pub struct VStackShapeInit; + + #[derive(Clone, Copy)] + pub struct VStackShape { + r: R, + c: C, + } + + impl> Visitor<&Matrix> + for VStackShapeInit + { + type Output = VStackShape; + fn visit(self, x: &Matrix) -> Self::Output { + let (r, c) = x.shape_generic(); + VStackShape { r, c } + } + } + impl, C1: Dim, R2: Dim, C2: Dim, S2: RawStorage> + Visitor<&Matrix> for VStackShape + where + ShapeConstraint: SameNumberOfColumns, + { + type Output = + VStackShape, >::Representative>; + fn visit(self, x: &Matrix) -> Self::Output { + let (r, c) = x.shape_generic(); + VStackShape { + r: self.r.add(r), + c: >::Representative::from_usize(c.value()), + } + } + } + + pub struct VStack { + out: Matrix, + current_row: R2, + } + + impl< + T: Scalar, + R1: Dim + DimAdd>, + C1: Dim, + S1: RawStorageMut, + const R2: usize, + C2: Dim, + S2: RawStorage, C2>, + R3: Dim + DimAdd>, + > Visitor<&Matrix, C2, S2>> for VStack + where + ShapeConstraint: SameNumberOfColumns, + { + type Output = VStack>>; + fn visit(self, x: &Matrix, C2, S2>) -> Self::Output { + let (r2, _) = x.shape_generic(); + let VStack { + mut out, + current_row, + } = self; + out.fixed_rows_mut::<{ R2 }>(current_row.value()) + .copy_from::, C2, S2>(x); + let current_row = current_row.add(r2); + VStack { out, current_row } + } + } + impl< + T: Scalar, + R1: Dim + DimAdd, + C1: Dim, + S1: RawStorageMut, + C2: Dim, + S2: RawStorage, + R3: Dim + DimAdd, + > Visitor<&Matrix> for VStack + where + ShapeConstraint: SameNumberOfColumns, + { + type Output = VStack>; + fn visit(self, x: &Matrix) -> Self::Output { + let (r2, _) = x.shape_generic(); + let VStack { + mut out, + current_row, + } = self; + out.rows_mut(current_row.value(), r2.value()) + .copy_from::(x); + let current_row = current_row.add(r2); + VStack { out, current_row } + } + } + + /// Stack a tuple of references to matrices with equal column counts vertically, yielding a + /// matrix with every row of the input matrices. + pub fn vstack< + T: Scalar + Zero, + R: Dim, + C: Dim, + X: Copy + + VisitTuple> + + VisitTuple< + VStack>::Buffer, Const<0>>, + Output = VStack>::Buffer, R>, + >, + >( + x: X, + ) -> Matrix>::Buffer> + where + DefaultAllocator: Allocator, + { + let vstack_shape = VStackShapeInit; + let vstack_shape = >::visit(vstack_shape, x); + let vstack_visitor = VStack { + out: Matrix::zeros_generic(vstack_shape.r, vstack_shape.c), + current_row: Const, + }; + let vstack_visitor = >::visit(vstack_visitor, x); + vstack_visitor.out + } +} +pub use vstack_impl::vstack; + +mod hstack_impl { + use super::*; + #[derive(Clone, Copy)] + pub struct HStackShapeInit; + + #[derive(Clone, Copy)] + pub struct HStackShape { + r: R, + c: C, + } + + impl> Visitor<&Matrix> + for HStackShapeInit + { + type Output = HStackShape; + fn visit(self, x: &Matrix) -> Self::Output { + let (r, c) = x.shape_generic(); + HStackShape { r, c } + } + } + impl, R2: Dim, C2: Dim, S2: RawStorage> + Visitor<&Matrix> for HStackShape + where + ShapeConstraint: SameNumberOfRows, + { + type Output = + HStackShape<>::Representative, DimSum>; + fn visit(self, x: &Matrix) -> Self::Output { + let (r, c) = x.shape_generic(); + HStackShape { + r: >::Representative::from_usize(r.value()), + c: self.c.add(c), + } + } + } + + pub struct HStack { + out: Matrix, + current_col: C2, + } + + impl< + T: Scalar, + R1: Dim, + C1: Dim + DimAdd>, + S1: RawStorageMut, + R2: Dim, + const C2: usize, + S2: RawStorage>, + C3: Dim + DimAdd>, + > Visitor<&Matrix, S2>> for HStack + where + ShapeConstraint: SameNumberOfRows, + { + type Output = HStack>>; + fn visit(self, x: &Matrix, S2>) -> Self::Output { + let (_, c2) = x.shape_generic(); + let HStack { + mut out, + current_col, + } = self; + out.fixed_columns_mut::<{ C2 }>(current_col.value()) + .copy_from::, S2>(x); + let current_col = current_col.add(c2); + HStack { out, current_col } + } + } + impl< + T: Scalar, + R1: Dim, + C1: Dim + DimAdd, + S1: RawStorageMut, + R2: Dim, + S2: RawStorage, + C3: Dim + DimAdd, + > Visitor<&Matrix> for HStack + where + ShapeConstraint: SameNumberOfRows, + { + type Output = HStack>; + fn visit(self, x: &Matrix) -> Self::Output { + let (_, c2) = x.shape_generic(); + let HStack { + mut out, + current_col, + } = self; + out.columns_mut(current_col.value(), c2.value()) + .copy_from::(x); + let current_col = current_col.add(c2); + HStack { out, current_col } + } + } + + /// Stack a tuple of references to matrices with equal row counts horizontally, yielding a + /// matrix with every column of the input matrices. + pub fn hstack< + T: Scalar + Zero, + R: Dim, + C: Dim, + X: Copy + + VisitTuple> + + VisitTuple< + HStack>::Buffer, Const<0>>, + Output = HStack>::Buffer, C>, + >, + >( + x: X, + ) -> Matrix>::Buffer> + where + DefaultAllocator: Allocator, + { + let hstack_shape = HStackShapeInit; + let hstack_shape = >::visit(hstack_shape, x); + let hstack_visitor = HStack { + out: Matrix::zeros_generic(hstack_shape.r, hstack_shape.c), + current_col: Const, + }; + let hstack_visitor = >::visit(hstack_visitor, x); + hstack_visitor.out + } +} +pub use hstack_impl::hstack;