From aec3ae2d53e415c7c025152d9536263a0d7a203e Mon Sep 17 00:00:00 2001 From: Avi Weinstock Date: Thu, 2 Feb 2023 16:18:16 -0500 Subject: [PATCH] Take the output matrix by mutable reference instead of ownership in `HStack`/`VStack` to facilitate implementing lazy stacking. --- src/base/stacking.rs | 74 ++++++++++++++++++++------------------------ 1 file changed, 34 insertions(+), 40 deletions(-) diff --git a/src/base/stacking.rs b/src/base/stacking.rs index 6d3432ce..0585249c 100644 --- a/src/base/stacking.rs +++ b/src/base/stacking.rs @@ -95,12 +95,13 @@ mod vstack_impl { } } - pub struct VStack { - out: Matrix, + pub struct VStack<'a, T, R, C, S, R2> { + out: &'a mut Matrix, current_row: R2, } impl< + 'a, T: Scalar, R1: Dim + DimAdd>, C1: Dim, @@ -109,17 +110,14 @@ mod vstack_impl { S2: RawStorage, C2>, R3: Dim + DimAdd>, const R2: usize, - > Visitor<&Matrix, C2, S2>> for VStack + > Visitor<&Matrix, C2, S2>> for VStack<'a, T, R1, C1, S1, R3> where ShapeConstraint: SameNumberOfColumns, { - type Output = VStack>>; + type Output = VStack<'a, T, R1, C1, S1, DimSum>>; fn visit(self, x: &Matrix, C2, S2>) -> Self::Output { let (r2, _) = x.shape_generic(); - let VStack { - mut out, - current_row, - } = self; + let VStack { out, current_row } = self; out.fixed_rows_mut::<{ R2 }>(current_row.value()) .copy_from::, C2, S2>(x); let current_row = current_row.add(r2); @@ -127,6 +125,7 @@ mod vstack_impl { } } impl< + 'a, T: Scalar, R1: Dim + DimAdd, C1: Dim, @@ -134,17 +133,14 @@ mod vstack_impl { C2: Dim, S2: RawStorage, R3: Dim + DimAdd, - > Visitor<&Matrix> for VStack + > Visitor<&Matrix> for VStack<'a, T, R1, C1, S1, R3> where ShapeConstraint: SameNumberOfColumns, { - type Output = VStack>; + type Output = VStack<'a, T, R1, C1, S1, DimSum>; fn visit(self, x: &Matrix) -> Self::Output { let (r2, _) = x.shape_generic(); - let VStack { - mut out, - current_row, - } = self; + let VStack { out, current_row } = self; out.rows_mut(current_row.value(), r2.value()) .copy_from::(x); let current_row = current_row.add(r2); @@ -161,9 +157,9 @@ mod vstack_impl { C: Dim, X: Copy + VisitTuple> - + VisitTuple< - VStack>::Buffer, Const<0>>, - Output = VStack>::Buffer, R>, + + for<'a> VisitTuple< + VStack<'a, T, R, C, >::Buffer, Const<0>>, + Output = VStack<'a, T, R, C, >::Buffer, R>, >, >( x: X, @@ -173,12 +169,13 @@ mod vstack_impl { { let vstack_shape = VStackShapeInit; let vstack_shape = >::visit(vstack_shape, x); + let mut out = Matrix::zeros_generic(vstack_shape.r, vstack_shape.c); let vstack_visitor = VStack { - out: Matrix::zeros_generic(vstack_shape.r, vstack_shape.c), + out: &mut out, current_row: Const, }; - let vstack_visitor = >::visit(vstack_visitor, x); - vstack_visitor.out + let _ = >::visit(vstack_visitor, x); + out } } pub use vstack_impl::vstack; @@ -219,12 +216,13 @@ mod hstack_impl { } } - pub struct HStack { - out: Matrix, + pub struct HStack<'a, T, R, C, S, C2> { + out: &'a mut Matrix, current_col: C2, } impl< + 'a, T: Scalar, R1: Dim, C1: Dim + DimAdd>, @@ -233,17 +231,14 @@ mod hstack_impl { S2: RawStorage>, C3: Dim + DimAdd>, const C2: usize, - > Visitor<&Matrix, S2>> for HStack + > Visitor<&Matrix, S2>> for HStack<'a, T, R1, C1, S1, C3> where ShapeConstraint: SameNumberOfRows, { - type Output = HStack>>; + type Output = HStack<'a, T, R1, C1, S1, DimSum>>; fn visit(self, x: &Matrix, S2>) -> Self::Output { let (_, c2) = x.shape_generic(); - let HStack { - mut out, - current_col, - } = self; + let HStack { out, current_col } = self; out.fixed_columns_mut::<{ C2 }>(current_col.value()) .copy_from::, S2>(x); let current_col = current_col.add(c2); @@ -251,6 +246,7 @@ mod hstack_impl { } } impl< + 'a, T: Scalar, R1: Dim, C1: Dim + DimAdd, @@ -258,17 +254,14 @@ mod hstack_impl { R2: Dim, S2: RawStorage, C3: Dim + DimAdd, - > Visitor<&Matrix> for HStack + > Visitor<&Matrix> for HStack<'a, T, R1, C1, S1, C3> where ShapeConstraint: SameNumberOfRows, { - type Output = HStack>; + type Output = HStack<'a, T, R1, C1, S1, DimSum>; fn visit(self, x: &Matrix) -> Self::Output { let (_, c2) = x.shape_generic(); - let HStack { - mut out, - current_col, - } = self; + let HStack { out, current_col } = self; out.columns_mut(current_col.value(), c2.value()) .copy_from::(x); let current_col = current_col.add(c2); @@ -285,9 +278,9 @@ mod hstack_impl { C: Dim, X: Copy + VisitTuple> - + VisitTuple< - HStack>::Buffer, Const<0>>, - Output = HStack>::Buffer, C>, + + for<'a> VisitTuple< + HStack<'a, T, R, C, >::Buffer, Const<0>>, + Output = HStack<'a, T, R, C, >::Buffer, C>, >, >( x: X, @@ -297,12 +290,13 @@ mod hstack_impl { { let hstack_shape = HStackShapeInit; let hstack_shape = >::visit(hstack_shape, x); + let mut out = Matrix::zeros_generic(hstack_shape.r, hstack_shape.c); let hstack_visitor = HStack { - out: Matrix::zeros_generic(hstack_shape.r, hstack_shape.c), + out: &mut out, current_col: Const, }; - let hstack_visitor = >::visit(hstack_visitor, x); - hstack_visitor.out + let _ = >::visit(hstack_visitor, x); + out } } pub use hstack_impl::hstack;