Take the output matrix by mutable reference instead of ownership in `HStack`/`VStack` to facilitate implementing lazy stacking.
This commit is contained in:
parent
be45282263
commit
aec3ae2d53
|
@ -95,12 +95,13 @@ mod vstack_impl {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct VStack<T, R, C, S, R2> {
|
pub struct VStack<'a, T, R, C, S, R2> {
|
||||||
out: Matrix<T, R, C, S>,
|
out: &'a mut Matrix<T, R, C, S>,
|
||||||
current_row: R2,
|
current_row: R2,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<
|
impl<
|
||||||
|
'a,
|
||||||
T: Scalar,
|
T: Scalar,
|
||||||
R1: Dim + DimAdd<Const<R2>>,
|
R1: Dim + DimAdd<Const<R2>>,
|
||||||
C1: Dim,
|
C1: Dim,
|
||||||
|
@ -109,17 +110,14 @@ mod vstack_impl {
|
||||||
S2: RawStorage<T, Const<R2>, C2>,
|
S2: RawStorage<T, Const<R2>, C2>,
|
||||||
R3: Dim + DimAdd<Const<R2>>,
|
R3: Dim + DimAdd<Const<R2>>,
|
||||||
const R2: usize,
|
const R2: usize,
|
||||||
> Visitor<&Matrix<T, Const<R2>, C2, S2>> for VStack<T, R1, C1, S1, R3>
|
> Visitor<&Matrix<T, Const<R2>, C2, S2>> for VStack<'a, T, R1, C1, S1, R3>
|
||||||
where
|
where
|
||||||
ShapeConstraint: SameNumberOfColumns<C1, C2>,
|
ShapeConstraint: SameNumberOfColumns<C1, C2>,
|
||||||
{
|
{
|
||||||
type Output = VStack<T, R1, C1, S1, DimSum<R3, Const<R2>>>;
|
type Output = VStack<'a, T, R1, C1, S1, DimSum<R3, Const<R2>>>;
|
||||||
fn visit(self, x: &Matrix<T, Const<R2>, C2, S2>) -> Self::Output {
|
fn visit(self, x: &Matrix<T, Const<R2>, C2, S2>) -> Self::Output {
|
||||||
let (r2, _) = x.shape_generic();
|
let (r2, _) = x.shape_generic();
|
||||||
let VStack {
|
let VStack { out, current_row } = self;
|
||||||
mut out,
|
|
||||||
current_row,
|
|
||||||
} = self;
|
|
||||||
out.fixed_rows_mut::<{ R2 }>(current_row.value())
|
out.fixed_rows_mut::<{ R2 }>(current_row.value())
|
||||||
.copy_from::<Const<R2>, C2, S2>(x);
|
.copy_from::<Const<R2>, C2, S2>(x);
|
||||||
let current_row = current_row.add(r2);
|
let current_row = current_row.add(r2);
|
||||||
|
@ -127,6 +125,7 @@ mod vstack_impl {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
impl<
|
impl<
|
||||||
|
'a,
|
||||||
T: Scalar,
|
T: Scalar,
|
||||||
R1: Dim + DimAdd<Dyn>,
|
R1: Dim + DimAdd<Dyn>,
|
||||||
C1: Dim,
|
C1: Dim,
|
||||||
|
@ -134,17 +133,14 @@ mod vstack_impl {
|
||||||
C2: Dim,
|
C2: Dim,
|
||||||
S2: RawStorage<T, Dyn, C2>,
|
S2: RawStorage<T, Dyn, C2>,
|
||||||
R3: Dim + DimAdd<Dyn>,
|
R3: Dim + DimAdd<Dyn>,
|
||||||
> Visitor<&Matrix<T, Dyn, C2, S2>> for VStack<T, R1, C1, S1, R3>
|
> Visitor<&Matrix<T, Dyn, C2, S2>> for VStack<'a, T, R1, C1, S1, R3>
|
||||||
where
|
where
|
||||||
ShapeConstraint: SameNumberOfColumns<C1, C2>,
|
ShapeConstraint: SameNumberOfColumns<C1, C2>,
|
||||||
{
|
{
|
||||||
type Output = VStack<T, R1, C1, S1, DimSum<R3, Dyn>>;
|
type Output = VStack<'a, T, R1, C1, S1, DimSum<R3, Dyn>>;
|
||||||
fn visit(self, x: &Matrix<T, Dyn, C2, S2>) -> Self::Output {
|
fn visit(self, x: &Matrix<T, Dyn, C2, S2>) -> Self::Output {
|
||||||
let (r2, _) = x.shape_generic();
|
let (r2, _) = x.shape_generic();
|
||||||
let VStack {
|
let VStack { out, current_row } = self;
|
||||||
mut out,
|
|
||||||
current_row,
|
|
||||||
} = self;
|
|
||||||
out.rows_mut(current_row.value(), r2.value())
|
out.rows_mut(current_row.value(), r2.value())
|
||||||
.copy_from::<Dyn, C2, S2>(x);
|
.copy_from::<Dyn, C2, S2>(x);
|
||||||
let current_row = current_row.add(r2);
|
let current_row = current_row.add(r2);
|
||||||
|
@ -161,9 +157,9 @@ mod vstack_impl {
|
||||||
C: Dim,
|
C: Dim,
|
||||||
X: Copy
|
X: Copy
|
||||||
+ VisitTuple<VStackShapeInit, Output = VStackShape<R, C>>
|
+ VisitTuple<VStackShapeInit, Output = VStackShape<R, C>>
|
||||||
+ VisitTuple<
|
+ for<'a> VisitTuple<
|
||||||
VStack<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, Const<0>>,
|
VStack<'a, T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, Const<0>>,
|
||||||
Output = VStack<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, R>,
|
Output = VStack<'a, T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, R>,
|
||||||
>,
|
>,
|
||||||
>(
|
>(
|
||||||
x: X,
|
x: X,
|
||||||
|
@ -173,12 +169,13 @@ mod vstack_impl {
|
||||||
{
|
{
|
||||||
let vstack_shape = VStackShapeInit;
|
let vstack_shape = VStackShapeInit;
|
||||||
let vstack_shape = <X as VisitTuple<_>>::visit(vstack_shape, x);
|
let vstack_shape = <X as VisitTuple<_>>::visit(vstack_shape, x);
|
||||||
|
let mut out = Matrix::zeros_generic(vstack_shape.r, vstack_shape.c);
|
||||||
let vstack_visitor = VStack {
|
let vstack_visitor = VStack {
|
||||||
out: Matrix::zeros_generic(vstack_shape.r, vstack_shape.c),
|
out: &mut out,
|
||||||
current_row: Const,
|
current_row: Const,
|
||||||
};
|
};
|
||||||
let vstack_visitor = <X as VisitTuple<_>>::visit(vstack_visitor, x);
|
let _ = <X as VisitTuple<_>>::visit(vstack_visitor, x);
|
||||||
vstack_visitor.out
|
out
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pub use vstack_impl::vstack;
|
pub use vstack_impl::vstack;
|
||||||
|
@ -219,12 +216,13 @@ mod hstack_impl {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct HStack<T, R, C, S, C2> {
|
pub struct HStack<'a, T, R, C, S, C2> {
|
||||||
out: Matrix<T, R, C, S>,
|
out: &'a mut Matrix<T, R, C, S>,
|
||||||
current_col: C2,
|
current_col: C2,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<
|
impl<
|
||||||
|
'a,
|
||||||
T: Scalar,
|
T: Scalar,
|
||||||
R1: Dim,
|
R1: Dim,
|
||||||
C1: Dim + DimAdd<Const<C2>>,
|
C1: Dim + DimAdd<Const<C2>>,
|
||||||
|
@ -233,17 +231,14 @@ mod hstack_impl {
|
||||||
S2: RawStorage<T, R2, Const<C2>>,
|
S2: RawStorage<T, R2, Const<C2>>,
|
||||||
C3: Dim + DimAdd<Const<C2>>,
|
C3: Dim + DimAdd<Const<C2>>,
|
||||||
const C2: usize,
|
const C2: usize,
|
||||||
> Visitor<&Matrix<T, R2, Const<C2>, S2>> for HStack<T, R1, C1, S1, C3>
|
> Visitor<&Matrix<T, R2, Const<C2>, S2>> for HStack<'a, T, R1, C1, S1, C3>
|
||||||
where
|
where
|
||||||
ShapeConstraint: SameNumberOfRows<R1, R2>,
|
ShapeConstraint: SameNumberOfRows<R1, R2>,
|
||||||
{
|
{
|
||||||
type Output = HStack<T, R1, C1, S1, DimSum<C3, Const<C2>>>;
|
type Output = HStack<'a, T, R1, C1, S1, DimSum<C3, Const<C2>>>;
|
||||||
fn visit(self, x: &Matrix<T, R2, Const<C2>, S2>) -> Self::Output {
|
fn visit(self, x: &Matrix<T, R2, Const<C2>, S2>) -> Self::Output {
|
||||||
let (_, c2) = x.shape_generic();
|
let (_, c2) = x.shape_generic();
|
||||||
let HStack {
|
let HStack { out, current_col } = self;
|
||||||
mut out,
|
|
||||||
current_col,
|
|
||||||
} = self;
|
|
||||||
out.fixed_columns_mut::<{ C2 }>(current_col.value())
|
out.fixed_columns_mut::<{ C2 }>(current_col.value())
|
||||||
.copy_from::<R2, Const<C2>, S2>(x);
|
.copy_from::<R2, Const<C2>, S2>(x);
|
||||||
let current_col = current_col.add(c2);
|
let current_col = current_col.add(c2);
|
||||||
|
@ -251,6 +246,7 @@ mod hstack_impl {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
impl<
|
impl<
|
||||||
|
'a,
|
||||||
T: Scalar,
|
T: Scalar,
|
||||||
R1: Dim,
|
R1: Dim,
|
||||||
C1: Dim + DimAdd<Dyn>,
|
C1: Dim + DimAdd<Dyn>,
|
||||||
|
@ -258,17 +254,14 @@ mod hstack_impl {
|
||||||
R2: Dim,
|
R2: Dim,
|
||||||
S2: RawStorage<T, R2, Dyn>,
|
S2: RawStorage<T, R2, Dyn>,
|
||||||
C3: Dim + DimAdd<Dyn>,
|
C3: Dim + DimAdd<Dyn>,
|
||||||
> Visitor<&Matrix<T, R2, Dyn, S2>> for HStack<T, R1, C1, S1, C3>
|
> Visitor<&Matrix<T, R2, Dyn, S2>> for HStack<'a, T, R1, C1, S1, C3>
|
||||||
where
|
where
|
||||||
ShapeConstraint: SameNumberOfRows<R1, R2>,
|
ShapeConstraint: SameNumberOfRows<R1, R2>,
|
||||||
{
|
{
|
||||||
type Output = HStack<T, R1, C1, S1, DimSum<C3, Dyn>>;
|
type Output = HStack<'a, T, R1, C1, S1, DimSum<C3, Dyn>>;
|
||||||
fn visit(self, x: &Matrix<T, R2, Dyn, S2>) -> Self::Output {
|
fn visit(self, x: &Matrix<T, R2, Dyn, S2>) -> Self::Output {
|
||||||
let (_, c2) = x.shape_generic();
|
let (_, c2) = x.shape_generic();
|
||||||
let HStack {
|
let HStack { out, current_col } = self;
|
||||||
mut out,
|
|
||||||
current_col,
|
|
||||||
} = self;
|
|
||||||
out.columns_mut(current_col.value(), c2.value())
|
out.columns_mut(current_col.value(), c2.value())
|
||||||
.copy_from::<R2, Dyn, S2>(x);
|
.copy_from::<R2, Dyn, S2>(x);
|
||||||
let current_col = current_col.add(c2);
|
let current_col = current_col.add(c2);
|
||||||
|
@ -285,9 +278,9 @@ mod hstack_impl {
|
||||||
C: Dim,
|
C: Dim,
|
||||||
X: Copy
|
X: Copy
|
||||||
+ VisitTuple<HStackShapeInit, Output = HStackShape<R, C>>
|
+ VisitTuple<HStackShapeInit, Output = HStackShape<R, C>>
|
||||||
+ VisitTuple<
|
+ for<'a> VisitTuple<
|
||||||
HStack<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, Const<0>>,
|
HStack<'a, T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, Const<0>>,
|
||||||
Output = HStack<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, C>,
|
Output = HStack<'a, T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer, C>,
|
||||||
>,
|
>,
|
||||||
>(
|
>(
|
||||||
x: X,
|
x: X,
|
||||||
|
@ -297,12 +290,13 @@ mod hstack_impl {
|
||||||
{
|
{
|
||||||
let hstack_shape = HStackShapeInit;
|
let hstack_shape = HStackShapeInit;
|
||||||
let hstack_shape = <X as VisitTuple<_>>::visit(hstack_shape, x);
|
let hstack_shape = <X as VisitTuple<_>>::visit(hstack_shape, x);
|
||||||
|
let mut out = Matrix::zeros_generic(hstack_shape.r, hstack_shape.c);
|
||||||
let hstack_visitor = HStack {
|
let hstack_visitor = HStack {
|
||||||
out: Matrix::zeros_generic(hstack_shape.r, hstack_shape.c),
|
out: &mut out,
|
||||||
current_col: Const,
|
current_col: Const,
|
||||||
};
|
};
|
||||||
let hstack_visitor = <X as VisitTuple<_>>::visit(hstack_visitor, x);
|
let _ = <X as VisitTuple<_>>::visit(hstack_visitor, x);
|
||||||
hstack_visitor.out
|
out
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pub use hstack_impl::hstack;
|
pub use hstack_impl::hstack;
|
||||||
|
|
Loading…
Reference in New Issue