Take the output matrix by mutable reference instead of ownership in `HStack`/`VStack` to facilitate implementing lazy stacking.

This commit is contained in:
Avi Weinstock 2023-02-02 16:18:16 -05:00
parent be45282263
commit aec3ae2d53
1 changed files with 34 additions and 40 deletions

View File

@ -95,12 +95,13 @@ mod vstack_impl {
} }
} }
pub struct VStack<T, R, C, S, R2> { pub struct VStack<'a, T, R, C, S, R2> {
out: Matrix<T, R, C, S>, out: &'a mut Matrix<T, R, C, S>,
current_row: R2, current_row: R2,
} }
impl< impl<
'a,
T: Scalar, T: Scalar,
R1: Dim + DimAdd<Const<R2>>, R1: Dim + DimAdd<Const<R2>>,
C1: Dim, C1: Dim,
@ -109,17 +110,14 @@ mod vstack_impl {
S2: RawStorage<T, Const<R2>, C2>, S2: RawStorage<T, Const<R2>, C2>,
R3: Dim + DimAdd<Const<R2>>, R3: Dim + DimAdd<Const<R2>>,
const R2: usize, const R2: usize,
> Visitor<&Matrix<T, Const<R2>, C2, S2>> for VStack<T, R1, C1, S1, R3> > Visitor<&Matrix<T, Const<R2>, C2, S2>> for VStack<'a, T, R1, C1, S1, R3>
where where
ShapeConstraint: SameNumberOfColumns<C1, C2>, ShapeConstraint: SameNumberOfColumns<C1, C2>,
{ {
type Output = VStack<T, R1, C1, S1, DimSum<R3, Const<R2>>>; type Output = VStack<'a, T, R1, C1, S1, DimSum<R3, Const<R2>>>;
fn visit(self, x: &Matrix<T, Const<R2>, C2, S2>) -> Self::Output { fn visit(self, x: &Matrix<T, Const<R2>, C2, S2>) -> Self::Output {
let (r2, _) = x.shape_generic(); let (r2, _) = x.shape_generic();
let VStack { let VStack { out, current_row } = self;
mut out,
current_row,
} = self;
out.fixed_rows_mut::<{ R2 }>(current_row.value()) out.fixed_rows_mut::<{ R2 }>(current_row.value())
.copy_from::<Const<R2>, C2, S2>(x); .copy_from::<Const<R2>, C2, S2>(x);
let current_row = current_row.add(r2); let current_row = current_row.add(r2);
@ -127,6 +125,7 @@ mod vstack_impl {
} }
} }
impl< impl<
'a,
T: Scalar, T: Scalar,
R1: Dim + DimAdd<Dyn>, R1: Dim + DimAdd<Dyn>,
C1: Dim, C1: Dim,
@ -134,17 +133,14 @@ mod vstack_impl {
C2: Dim, C2: Dim,
S2: RawStorage<T, Dyn, C2>, S2: RawStorage<T, Dyn, C2>,
R3: Dim + DimAdd<Dyn>, R3: Dim + DimAdd<Dyn>,
> Visitor<&Matrix<T, Dyn, C2, S2>> for VStack<T, R1, C1, S1, R3> > Visitor<&Matrix<T, Dyn, C2, S2>> for VStack<'a, T, R1, C1, S1, R3>
where where
ShapeConstraint: SameNumberOfColumns<C1, C2>, ShapeConstraint: SameNumberOfColumns<C1, C2>,
{ {
type Output = VStack<T, R1, C1, S1, DimSum<R3, Dyn>>; type Output = VStack<'a, T, R1, C1, S1, DimSum<R3, Dyn>>;
fn visit(self, x: &Matrix<T, Dyn, C2, S2>) -> Self::Output { fn visit(self, x: &Matrix<T, Dyn, C2, S2>) -> Self::Output {
let (r2, _) = x.shape_generic(); let (r2, _) = x.shape_generic();
let VStack { let VStack { out, current_row } = self;
mut out,
current_row,
} = self;
out.rows_mut(current_row.value(), r2.value()) out.rows_mut(current_row.value(), r2.value())
.copy_from::<Dyn, C2, S2>(x); .copy_from::<Dyn, C2, S2>(x);
let current_row = current_row.add(r2); let current_row = current_row.add(r2);
@ -161,9 +157,9 @@ mod vstack_impl {
C: Dim, C: Dim,
X: Copy X: Copy
+ VisitTuple<VStackShapeInit, Output = VStackShape<R, C>> + VisitTuple<VStackShapeInit, Output = VStackShape<R, C>>
+ VisitTuple< + for<'a> VisitTuple<
VStack<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, Const<0>>, VStack<'a, T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, Const<0>>,
Output = VStack<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, R>, Output = VStack<'a, T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, R>,
>, >,
>( >(
x: X, x: X,
@ -173,12 +169,13 @@ mod vstack_impl {
{ {
let vstack_shape = VStackShapeInit; let vstack_shape = VStackShapeInit;
let vstack_shape = <X as VisitTuple<_>>::visit(vstack_shape, x); let vstack_shape = <X as VisitTuple<_>>::visit(vstack_shape, x);
let mut out = Matrix::zeros_generic(vstack_shape.r, vstack_shape.c);
let vstack_visitor = VStack { let vstack_visitor = VStack {
out: Matrix::zeros_generic(vstack_shape.r, vstack_shape.c), out: &mut out,
current_row: Const, current_row: Const,
}; };
let vstack_visitor = <X as VisitTuple<_>>::visit(vstack_visitor, x); let _ = <X as VisitTuple<_>>::visit(vstack_visitor, x);
vstack_visitor.out out
} }
} }
pub use vstack_impl::vstack; pub use vstack_impl::vstack;
@ -219,12 +216,13 @@ mod hstack_impl {
} }
} }
pub struct HStack<T, R, C, S, C2> { pub struct HStack<'a, T, R, C, S, C2> {
out: Matrix<T, R, C, S>, out: &'a mut Matrix<T, R, C, S>,
current_col: C2, current_col: C2,
} }
impl< impl<
'a,
T: Scalar, T: Scalar,
R1: Dim, R1: Dim,
C1: Dim + DimAdd<Const<C2>>, C1: Dim + DimAdd<Const<C2>>,
@ -233,17 +231,14 @@ mod hstack_impl {
S2: RawStorage<T, R2, Const<C2>>, S2: RawStorage<T, R2, Const<C2>>,
C3: Dim + DimAdd<Const<C2>>, C3: Dim + DimAdd<Const<C2>>,
const C2: usize, const C2: usize,
> Visitor<&Matrix<T, R2, Const<C2>, S2>> for HStack<T, R1, C1, S1, C3> > Visitor<&Matrix<T, R2, Const<C2>, S2>> for HStack<'a, T, R1, C1, S1, C3>
where where
ShapeConstraint: SameNumberOfRows<R1, R2>, ShapeConstraint: SameNumberOfRows<R1, R2>,
{ {
type Output = HStack<T, R1, C1, S1, DimSum<C3, Const<C2>>>; type Output = HStack<'a, T, R1, C1, S1, DimSum<C3, Const<C2>>>;
fn visit(self, x: &Matrix<T, R2, Const<C2>, S2>) -> Self::Output { fn visit(self, x: &Matrix<T, R2, Const<C2>, S2>) -> Self::Output {
let (_, c2) = x.shape_generic(); let (_, c2) = x.shape_generic();
let HStack { let HStack { out, current_col } = self;
mut out,
current_col,
} = self;
out.fixed_columns_mut::<{ C2 }>(current_col.value()) out.fixed_columns_mut::<{ C2 }>(current_col.value())
.copy_from::<R2, Const<C2>, S2>(x); .copy_from::<R2, Const<C2>, S2>(x);
let current_col = current_col.add(c2); let current_col = current_col.add(c2);
@ -251,6 +246,7 @@ mod hstack_impl {
} }
} }
impl< impl<
'a,
T: Scalar, T: Scalar,
R1: Dim, R1: Dim,
C1: Dim + DimAdd<Dyn>, C1: Dim + DimAdd<Dyn>,
@ -258,17 +254,14 @@ mod hstack_impl {
R2: Dim, R2: Dim,
S2: RawStorage<T, R2, Dyn>, S2: RawStorage<T, R2, Dyn>,
C3: Dim + DimAdd<Dyn>, C3: Dim + DimAdd<Dyn>,
> Visitor<&Matrix<T, R2, Dyn, S2>> for HStack<T, R1, C1, S1, C3> > Visitor<&Matrix<T, R2, Dyn, S2>> for HStack<'a, T, R1, C1, S1, C3>
where where
ShapeConstraint: SameNumberOfRows<R1, R2>, ShapeConstraint: SameNumberOfRows<R1, R2>,
{ {
type Output = HStack<T, R1, C1, S1, DimSum<C3, Dyn>>; type Output = HStack<'a, T, R1, C1, S1, DimSum<C3, Dyn>>;
fn visit(self, x: &Matrix<T, R2, Dyn, S2>) -> Self::Output { fn visit(self, x: &Matrix<T, R2, Dyn, S2>) -> Self::Output {
let (_, c2) = x.shape_generic(); let (_, c2) = x.shape_generic();
let HStack { let HStack { out, current_col } = self;
mut out,
current_col,
} = self;
out.columns_mut(current_col.value(), c2.value()) out.columns_mut(current_col.value(), c2.value())
.copy_from::<R2, Dyn, S2>(x); .copy_from::<R2, Dyn, S2>(x);
let current_col = current_col.add(c2); let current_col = current_col.add(c2);
@ -285,9 +278,9 @@ mod hstack_impl {
C: Dim, C: Dim,
X: Copy X: Copy
+ VisitTuple<HStackShapeInit, Output = HStackShape<R, C>> + VisitTuple<HStackShapeInit, Output = HStackShape<R, C>>
+ VisitTuple< + for<'a> VisitTuple<
HStack<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, Const<0>>, HStack<'a, T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, Const<0>>,
Output = HStack<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, C>, Output = HStack<'a, T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, C>,
>, >,
>( >(
x: X, x: X,
@ -297,12 +290,13 @@ mod hstack_impl {
{ {
let hstack_shape = HStackShapeInit; let hstack_shape = HStackShapeInit;
let hstack_shape = <X as VisitTuple<_>>::visit(hstack_shape, x); let hstack_shape = <X as VisitTuple<_>>::visit(hstack_shape, x);
let mut out = Matrix::zeros_generic(hstack_shape.r, hstack_shape.c);
let hstack_visitor = HStack { let hstack_visitor = HStack {
out: Matrix::zeros_generic(hstack_shape.r, hstack_shape.c), out: &mut out,
current_col: Const, current_col: Const,
}; };
let hstack_visitor = <X as VisitTuple<_>>::visit(hstack_visitor, x); let _ = <X as VisitTuple<_>>::visit(hstack_visitor, x);
hstack_visitor.out out
} }
} }
pub use hstack_impl::hstack; pub use hstack_impl::hstack;