Take the output matrix by mutable reference instead of ownership in HStack/VStack to facilitate implementing lazy stacking.

This commit is contained in:
Avi Weinstock 2023-02-02 16:18:16 -05:00
parent be45282263
commit aec3ae2d53

View File

@ -95,12 +95,13 @@ mod vstack_impl {
}
}
pub struct VStack<T, R, C, S, R2> {
out: Matrix<T, R, C, S>,
pub struct VStack<'a, T, R, C, S, R2> {
out: &'a mut Matrix<T, R, C, S>,
current_row: R2,
}
impl<
'a,
T: Scalar,
R1: Dim + DimAdd<Const<R2>>,
C1: Dim,
@ -109,17 +110,14 @@ mod vstack_impl {
S2: RawStorage<T, Const<R2>, C2>,
R3: Dim + DimAdd<Const<R2>>,
const R2: usize,
> Visitor<&Matrix<T, Const<R2>, C2, S2>> for VStack<T, R1, C1, S1, R3>
> Visitor<&Matrix<T, Const<R2>, C2, S2>> for VStack<'a, T, R1, C1, S1, R3>
where
ShapeConstraint: SameNumberOfColumns<C1, C2>,
{
type Output = VStack<T, R1, C1, S1, DimSum<R3, Const<R2>>>;
type Output = VStack<'a, T, R1, C1, S1, DimSum<R3, Const<R2>>>;
fn visit(self, x: &Matrix<T, Const<R2>, C2, S2>) -> Self::Output {
let (r2, _) = x.shape_generic();
let VStack {
mut out,
current_row,
} = self;
let VStack { out, current_row } = self;
out.fixed_rows_mut::<{ R2 }>(current_row.value())
.copy_from::<Const<R2>, C2, S2>(x);
let current_row = current_row.add(r2);
@ -127,6 +125,7 @@ mod vstack_impl {
}
}
impl<
'a,
T: Scalar,
R1: Dim + DimAdd<Dyn>,
C1: Dim,
@ -134,17 +133,14 @@ mod vstack_impl {
C2: Dim,
S2: RawStorage<T, Dyn, C2>,
R3: Dim + DimAdd<Dyn>,
> Visitor<&Matrix<T, Dyn, C2, S2>> for VStack<T, R1, C1, S1, R3>
> Visitor<&Matrix<T, Dyn, C2, S2>> for VStack<'a, T, R1, C1, S1, R3>
where
ShapeConstraint: SameNumberOfColumns<C1, C2>,
{
type Output = VStack<T, R1, C1, S1, DimSum<R3, Dyn>>;
type Output = VStack<'a, T, R1, C1, S1, DimSum<R3, Dyn>>;
fn visit(self, x: &Matrix<T, Dyn, C2, S2>) -> Self::Output {
let (r2, _) = x.shape_generic();
let VStack {
mut out,
current_row,
} = self;
let VStack { out, current_row } = self;
out.rows_mut(current_row.value(), r2.value())
.copy_from::<Dyn, C2, S2>(x);
let current_row = current_row.add(r2);
@ -161,9 +157,9 @@ mod vstack_impl {
C: Dim,
X: Copy
+ VisitTuple<VStackShapeInit, Output = VStackShape<R, C>>
+ VisitTuple<
VStack<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, Const<0>>,
Output = VStack<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, R>,
+ for<'a> VisitTuple<
VStack<'a, T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, Const<0>>,
Output = VStack<'a, T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, R>,
>,
>(
x: X,
@ -173,12 +169,13 @@ mod vstack_impl {
{
let vstack_shape = VStackShapeInit;
let vstack_shape = <X as VisitTuple<_>>::visit(vstack_shape, x);
let mut out = Matrix::zeros_generic(vstack_shape.r, vstack_shape.c);
let vstack_visitor = VStack {
out: Matrix::zeros_generic(vstack_shape.r, vstack_shape.c),
out: &mut out,
current_row: Const,
};
let vstack_visitor = <X as VisitTuple<_>>::visit(vstack_visitor, x);
vstack_visitor.out
let _ = <X as VisitTuple<_>>::visit(vstack_visitor, x);
out
}
}
pub use vstack_impl::vstack;
@ -219,12 +216,13 @@ mod hstack_impl {
}
}
pub struct HStack<T, R, C, S, C2> {
out: Matrix<T, R, C, S>,
pub struct HStack<'a, T, R, C, S, C2> {
out: &'a mut Matrix<T, R, C, S>,
current_col: C2,
}
impl<
'a,
T: Scalar,
R1: Dim,
C1: Dim + DimAdd<Const<C2>>,
@ -233,17 +231,14 @@ mod hstack_impl {
S2: RawStorage<T, R2, Const<C2>>,
C3: Dim + DimAdd<Const<C2>>,
const C2: usize,
> Visitor<&Matrix<T, R2, Const<C2>, S2>> for HStack<T, R1, C1, S1, C3>
> Visitor<&Matrix<T, R2, Const<C2>, S2>> for HStack<'a, T, R1, C1, S1, C3>
where
ShapeConstraint: SameNumberOfRows<R1, R2>,
{
type Output = HStack<T, R1, C1, S1, DimSum<C3, Const<C2>>>;
type Output = HStack<'a, T, R1, C1, S1, DimSum<C3, Const<C2>>>;
fn visit(self, x: &Matrix<T, R2, Const<C2>, S2>) -> Self::Output {
let (_, c2) = x.shape_generic();
let HStack {
mut out,
current_col,
} = self;
let HStack { out, current_col } = self;
out.fixed_columns_mut::<{ C2 }>(current_col.value())
.copy_from::<R2, Const<C2>, S2>(x);
let current_col = current_col.add(c2);
@ -251,6 +246,7 @@ mod hstack_impl {
}
}
impl<
'a,
T: Scalar,
R1: Dim,
C1: Dim + DimAdd<Dyn>,
@ -258,17 +254,14 @@ mod hstack_impl {
R2: Dim,
S2: RawStorage<T, R2, Dyn>,
C3: Dim + DimAdd<Dyn>,
> Visitor<&Matrix<T, R2, Dyn, S2>> for HStack<T, R1, C1, S1, C3>
> Visitor<&Matrix<T, R2, Dyn, S2>> for HStack<'a, T, R1, C1, S1, C3>
where
ShapeConstraint: SameNumberOfRows<R1, R2>,
{
type Output = HStack<T, R1, C1, S1, DimSum<C3, Dyn>>;
type Output = HStack<'a, T, R1, C1, S1, DimSum<C3, Dyn>>;
fn visit(self, x: &Matrix<T, R2, Dyn, S2>) -> Self::Output {
let (_, c2) = x.shape_generic();
let HStack {
mut out,
current_col,
} = self;
let HStack { out, current_col } = self;
out.columns_mut(current_col.value(), c2.value())
.copy_from::<R2, Dyn, S2>(x);
let current_col = current_col.add(c2);
@ -285,9 +278,9 @@ mod hstack_impl {
C: Dim,
X: Copy
+ VisitTuple<HStackShapeInit, Output = HStackShape<R, C>>
+ VisitTuple<
HStack<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, Const<0>>,
Output = HStack<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, C>,
+ for<'a> VisitTuple<
HStack<'a, T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, Const<0>>,
Output = HStack<'a, T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, C>,
>,
>(
x: X,
@ -297,12 +290,13 @@ mod hstack_impl {
{
let hstack_shape = HStackShapeInit;
let hstack_shape = <X as VisitTuple<_>>::visit(hstack_shape, x);
let mut out = Matrix::zeros_generic(hstack_shape.r, hstack_shape.c);
let hstack_visitor = HStack {
out: Matrix::zeros_generic(hstack_shape.r, hstack_shape.c),
out: &mut out,
current_col: Const,
};
let hstack_visitor = <X as VisitTuple<_>>::visit(hstack_visitor, x);
hstack_visitor.out
let _ = <X as VisitTuple<_>>::visit(hstack_visitor, x);
out
}
}
pub use hstack_impl::hstack;