Take the output matrix by mutable reference instead of ownership in `HStack`/`VStack` to facilitate implementing lazy stacking.
This commit is contained in:
parent
be45282263
commit
aec3ae2d53
|
@ -95,12 +95,13 @@ mod vstack_impl {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct VStack<T, R, C, S, R2> {
|
||||
out: Matrix<T, R, C, S>,
|
||||
pub struct VStack<'a, T, R, C, S, R2> {
|
||||
out: &'a mut Matrix<T, R, C, S>,
|
||||
current_row: R2,
|
||||
}
|
||||
|
||||
impl<
|
||||
'a,
|
||||
T: Scalar,
|
||||
R1: Dim + DimAdd<Const<R2>>,
|
||||
C1: Dim,
|
||||
|
@ -109,17 +110,14 @@ mod vstack_impl {
|
|||
S2: RawStorage<T, Const<R2>, C2>,
|
||||
R3: Dim + DimAdd<Const<R2>>,
|
||||
const R2: usize,
|
||||
> Visitor<&Matrix<T, Const<R2>, C2, S2>> for VStack<T, R1, C1, S1, R3>
|
||||
> Visitor<&Matrix<T, Const<R2>, C2, S2>> for VStack<'a, T, R1, C1, S1, R3>
|
||||
where
|
||||
ShapeConstraint: SameNumberOfColumns<C1, C2>,
|
||||
{
|
||||
type Output = VStack<T, R1, C1, S1, DimSum<R3, Const<R2>>>;
|
||||
type Output = VStack<'a, T, R1, C1, S1, DimSum<R3, Const<R2>>>;
|
||||
fn visit(self, x: &Matrix<T, Const<R2>, C2, S2>) -> Self::Output {
|
||||
let (r2, _) = x.shape_generic();
|
||||
let VStack {
|
||||
mut out,
|
||||
current_row,
|
||||
} = self;
|
||||
let VStack { out, current_row } = self;
|
||||
out.fixed_rows_mut::<{ R2 }>(current_row.value())
|
||||
.copy_from::<Const<R2>, C2, S2>(x);
|
||||
let current_row = current_row.add(r2);
|
||||
|
@ -127,6 +125,7 @@ mod vstack_impl {
|
|||
}
|
||||
}
|
||||
impl<
|
||||
'a,
|
||||
T: Scalar,
|
||||
R1: Dim + DimAdd<Dyn>,
|
||||
C1: Dim,
|
||||
|
@ -134,17 +133,14 @@ mod vstack_impl {
|
|||
C2: Dim,
|
||||
S2: RawStorage<T, Dyn, C2>,
|
||||
R3: Dim + DimAdd<Dyn>,
|
||||
> Visitor<&Matrix<T, Dyn, C2, S2>> for VStack<T, R1, C1, S1, R3>
|
||||
> Visitor<&Matrix<T, Dyn, C2, S2>> for VStack<'a, T, R1, C1, S1, R3>
|
||||
where
|
||||
ShapeConstraint: SameNumberOfColumns<C1, C2>,
|
||||
{
|
||||
type Output = VStack<T, R1, C1, S1, DimSum<R3, Dyn>>;
|
||||
type Output = VStack<'a, T, R1, C1, S1, DimSum<R3, Dyn>>;
|
||||
fn visit(self, x: &Matrix<T, Dyn, C2, S2>) -> Self::Output {
|
||||
let (r2, _) = x.shape_generic();
|
||||
let VStack {
|
||||
mut out,
|
||||
current_row,
|
||||
} = self;
|
||||
let VStack { out, current_row } = self;
|
||||
out.rows_mut(current_row.value(), r2.value())
|
||||
.copy_from::<Dyn, C2, S2>(x);
|
||||
let current_row = current_row.add(r2);
|
||||
|
@ -161,9 +157,9 @@ mod vstack_impl {
|
|||
C: Dim,
|
||||
X: Copy
|
||||
+ VisitTuple<VStackShapeInit, Output = VStackShape<R, C>>
|
||||
+ VisitTuple<
|
||||
VStack<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, Const<0>>,
|
||||
Output = VStack<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, R>,
|
||||
+ for<'a> VisitTuple<
|
||||
VStack<'a, T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, Const<0>>,
|
||||
Output = VStack<'a, T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, R>,
|
||||
>,
|
||||
>(
|
||||
x: X,
|
||||
|
@ -173,12 +169,13 @@ mod vstack_impl {
|
|||
{
|
||||
let vstack_shape = VStackShapeInit;
|
||||
let vstack_shape = <X as VisitTuple<_>>::visit(vstack_shape, x);
|
||||
let mut out = Matrix::zeros_generic(vstack_shape.r, vstack_shape.c);
|
||||
let vstack_visitor = VStack {
|
||||
out: Matrix::zeros_generic(vstack_shape.r, vstack_shape.c),
|
||||
out: &mut out,
|
||||
current_row: Const,
|
||||
};
|
||||
let vstack_visitor = <X as VisitTuple<_>>::visit(vstack_visitor, x);
|
||||
vstack_visitor.out
|
||||
let _ = <X as VisitTuple<_>>::visit(vstack_visitor, x);
|
||||
out
|
||||
}
|
||||
}
|
||||
pub use vstack_impl::vstack;
|
||||
|
@ -219,12 +216,13 @@ mod hstack_impl {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct HStack<T, R, C, S, C2> {
|
||||
out: Matrix<T, R, C, S>,
|
||||
pub struct HStack<'a, T, R, C, S, C2> {
|
||||
out: &'a mut Matrix<T, R, C, S>,
|
||||
current_col: C2,
|
||||
}
|
||||
|
||||
impl<
|
||||
'a,
|
||||
T: Scalar,
|
||||
R1: Dim,
|
||||
C1: Dim + DimAdd<Const<C2>>,
|
||||
|
@ -233,17 +231,14 @@ mod hstack_impl {
|
|||
S2: RawStorage<T, R2, Const<C2>>,
|
||||
C3: Dim + DimAdd<Const<C2>>,
|
||||
const C2: usize,
|
||||
> Visitor<&Matrix<T, R2, Const<C2>, S2>> for HStack<T, R1, C1, S1, C3>
|
||||
> Visitor<&Matrix<T, R2, Const<C2>, S2>> for HStack<'a, T, R1, C1, S1, C3>
|
||||
where
|
||||
ShapeConstraint: SameNumberOfRows<R1, R2>,
|
||||
{
|
||||
type Output = HStack<T, R1, C1, S1, DimSum<C3, Const<C2>>>;
|
||||
type Output = HStack<'a, T, R1, C1, S1, DimSum<C3, Const<C2>>>;
|
||||
fn visit(self, x: &Matrix<T, R2, Const<C2>, S2>) -> Self::Output {
|
||||
let (_, c2) = x.shape_generic();
|
||||
let HStack {
|
||||
mut out,
|
||||
current_col,
|
||||
} = self;
|
||||
let HStack { out, current_col } = self;
|
||||
out.fixed_columns_mut::<{ C2 }>(current_col.value())
|
||||
.copy_from::<R2, Const<C2>, S2>(x);
|
||||
let current_col = current_col.add(c2);
|
||||
|
@ -251,6 +246,7 @@ mod hstack_impl {
|
|||
}
|
||||
}
|
||||
impl<
|
||||
'a,
|
||||
T: Scalar,
|
||||
R1: Dim,
|
||||
C1: Dim + DimAdd<Dyn>,
|
||||
|
@ -258,17 +254,14 @@ mod hstack_impl {
|
|||
R2: Dim,
|
||||
S2: RawStorage<T, R2, Dyn>,
|
||||
C3: Dim + DimAdd<Dyn>,
|
||||
> Visitor<&Matrix<T, R2, Dyn, S2>> for HStack<T, R1, C1, S1, C3>
|
||||
> Visitor<&Matrix<T, R2, Dyn, S2>> for HStack<'a, T, R1, C1, S1, C3>
|
||||
where
|
||||
ShapeConstraint: SameNumberOfRows<R1, R2>,
|
||||
{
|
||||
type Output = HStack<T, R1, C1, S1, DimSum<C3, Dyn>>;
|
||||
type Output = HStack<'a, T, R1, C1, S1, DimSum<C3, Dyn>>;
|
||||
fn visit(self, x: &Matrix<T, R2, Dyn, S2>) -> Self::Output {
|
||||
let (_, c2) = x.shape_generic();
|
||||
let HStack {
|
||||
mut out,
|
||||
current_col,
|
||||
} = self;
|
||||
let HStack { out, current_col } = self;
|
||||
out.columns_mut(current_col.value(), c2.value())
|
||||
.copy_from::<R2, Dyn, S2>(x);
|
||||
let current_col = current_col.add(c2);
|
||||
|
@ -285,9 +278,9 @@ mod hstack_impl {
|
|||
C: Dim,
|
||||
X: Copy
|
||||
+ VisitTuple<HStackShapeInit, Output = HStackShape<R, C>>
|
||||
+ VisitTuple<
|
||||
HStack<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, Const<0>>,
|
||||
Output = HStack<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, C>,
|
||||
+ for<'a> VisitTuple<
|
||||
HStack<'a, T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, Const<0>>,
|
||||
Output = HStack<'a, T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, C>,
|
||||
>,
|
||||
>(
|
||||
x: X,
|
||||
|
@ -297,12 +290,13 @@ mod hstack_impl {
|
|||
{
|
||||
let hstack_shape = HStackShapeInit;
|
||||
let hstack_shape = <X as VisitTuple<_>>::visit(hstack_shape, x);
|
||||
let mut out = Matrix::zeros_generic(hstack_shape.r, hstack_shape.c);
|
||||
let hstack_visitor = HStack {
|
||||
out: Matrix::zeros_generic(hstack_shape.r, hstack_shape.c),
|
||||
out: &mut out,
|
||||
current_col: Const,
|
||||
};
|
||||
let hstack_visitor = <X as VisitTuple<_>>::visit(hstack_visitor, x);
|
||||
hstack_visitor.out
|
||||
let _ = <X as VisitTuple<_>>::visit(hstack_visitor, x);
|
||||
out
|
||||
}
|
||||
}
|
||||
pub use hstack_impl::hstack;
|
||||
|
|
Loading…
Reference in New Issue