ReshapableStorage for slices + tests for owned reshape

In the process of implementing ReshapbleStorage for SliceStorage(Mut),
I discovered that there appears to be no tests for the existing
reshape_generic functionality on owned matrices.
This commit is contained in:
Andreas Longva 2022-11-24 14:14:49 +01:00 committed by Sébastien Crozet
parent 4221c44a2b
commit afabf4bad2
3 changed files with 126 additions and 0 deletions

View File

@ -9,6 +9,7 @@ use crate::base::iter::MatrixIter;
use crate::base::storage::{IsContiguous, Owned, RawStorage, RawStorageMut, Storage};
use crate::base::{Matrix, Scalar};
use crate::constraint::{DimEq, ShapeConstraint};
use crate::ReshapableStorage;
macro_rules! view_storage_impl (
($doc: expr; $Storage: ident as $SRef: ty; $legacy_name:ident => $T: ident.$get_addr: ident ($Ptr: ty as $Ref: ty)) => {
@ -1186,3 +1187,47 @@ where
self.into()
}
}
// TODO: Arbitrary strides?
impl<'a, T, R1, C1, R2, C2> ReshapableStorage<T, R1, C1, R2, C2>
for ViewStorage<'a, T, R1, C1, U1, R1>
where
T: Scalar,
R1: Dim,
C1: Dim,
R2: Dim,
C2: Dim,
{
type Output = ViewStorage<'a, T, R2, C2, U1, R2>;
fn reshape_generic(self, nrows: R2, ncols: C2) -> Self::Output {
let (r1, c1) = self.shape();
assert_eq!(nrows.value() * ncols.value(), r1.value() * c1.value());
let ptr = self.ptr();
let new_shape = (nrows, ncols);
let strides = (U1::name(), nrows);
unsafe { ViewStorage::from_raw_parts(ptr, new_shape, strides) }
}
}
// TODO: Arbitrary strides?
impl<'a, T, R1, C1, R2, C2> ReshapableStorage<T, R1, C1, R2, C2>
for ViewStorageMut<'a, T, R1, C1, U1, R1>
where
T: Scalar,
R1: Dim,
C1: Dim,
R2: Dim,
C2: Dim,
{
type Output = ViewStorageMut<'a, T, R2, C2, U1, R2>;
fn reshape_generic(mut self, nrows: R2, ncols: C2) -> Self::Output {
let (r1, c1) = self.shape();
assert_eq!(nrows.value() * ncols.value(), r1.value() * c1.value());
let ptr = self.ptr_mut();
let new_shape = (nrows, ncols);
let strides = (U1::name(), nrows);
unsafe { ViewStorageMut::from_raw_parts(ptr, new_shape, strides) }
}
}

View File

@ -7,6 +7,7 @@ mod matrix;
mod matrix_view;
#[cfg(feature = "mint")]
mod mint;
mod reshape;
#[cfg(feature = "rkyv-serialize-no-std")]
mod rkyv;
mod serde;

80
tests/core/reshape.rs Normal file
View File

@ -0,0 +1,80 @@
use na::{
Const, DMatrix, DMatrixSlice, DMatrixSliceMut, Dyn, Dynamic, Matrix, MatrixSlice,
MatrixSliceMut, SMatrix, SMatrixSlice, SMatrixSliceMut, VecStorage, U3, U4,
};
use nalgebra_macros::matrix;
use simba::scalar::SupersetOf;
const MATRIX: SMatrix<i32, 4, 3> = matrix![
1, 2, 3;
4, 5, 6;
7, 8, 9;
10, 11, 12
];
const RESHAPED_MATRIX: SMatrix<i32, 3, 4> = matrix![
1, 10, 8, 6;
4, 2, 11, 9;
7, 5, 3, 12
];
// Helper alias for making it easier to specify dynamically allocated matrices with
// different dimension types (unlike DMatrix)
type GenericDMatrix<T, R, C> = Matrix<T, R, C, VecStorage<T, R, C>>;
#[test]
fn reshape_owned() {
macro_rules! test_reshape {
($in_matrix:ty => $out_matrix:ty, $rows:expr, $cols:expr) => {{
// This is a pretty weird way to convert, but Matrix implements SubsetOf
let matrix: $in_matrix = MATRIX.to_subset().unwrap();
let reshaped: $out_matrix = matrix.reshape_generic($rows, $cols);
assert_eq!(reshaped, RESHAPED_MATRIX);
}};
}
test_reshape!(SMatrix<_, 4, 3> => SMatrix<_, 3, 4>, U3, U4);
test_reshape!(GenericDMatrix<_, U4, Dyn> => GenericDMatrix<_, Dyn, Dyn>, Dyn(3), Dyn(4));
test_reshape!(GenericDMatrix<_, U4, Dyn> => GenericDMatrix<_, U3, Dyn>, U3, Dyn(4));
test_reshape!(GenericDMatrix<_, U4, Dyn> => GenericDMatrix<_, Dyn, U4>, Dyn(3), U4);
test_reshape!(DMatrix<_> => DMatrix<_>, Dyn(3), Dyn(4));
}
#[test]
fn reshape_slice() {
macro_rules! test_reshape {
($in_slice:ty => $out_slice:ty, $rows:expr, $cols:expr) => {
// We test both that types check out by being explicit about types
// and the actual contents of the matrix
{
// By constructing the slice from a mutable reference we can obtain *either*
// an immutable slice or a mutable slice, which simplifies the testing of both
// types of mutability
let mut source_matrix = MATRIX.clone();
let slice: $in_slice = Matrix::from(&mut source_matrix);
let reshaped: $out_slice = slice.reshape_generic($rows, $cols);
assert_eq!(reshaped, RESHAPED_MATRIX);
}
};
}
// Static "source slice"
test_reshape!(SMatrixSlice<_, 4, 3> => SMatrixSlice<_, 3, 4>, U3, U4);
test_reshape!(SMatrixSlice<_, 4, 3> => DMatrixSlice<_>, Dynamic::new(3), Dynamic::new(4));
test_reshape!(SMatrixSlice<_, 4, 3> => MatrixSlice<_, Const<3>, Dynamic>, U3, Dynamic::new(4));
test_reshape!(SMatrixSlice<_, 4, 3> => MatrixSlice<_, Dynamic, Const<4>>, Dynamic::new(3), U4);
test_reshape!(SMatrixSliceMut<_, 4, 3> => SMatrixSliceMut<_, 3, 4>, U3, U4);
test_reshape!(SMatrixSliceMut<_, 4, 3> => DMatrixSliceMut<_>, Dynamic::new(3), Dynamic::new(4));
test_reshape!(SMatrixSliceMut<_, 4, 3> => MatrixSliceMut<_, Const<3>, Dynamic>, U3, Dynamic::new(4));
test_reshape!(SMatrixSliceMut<_, 4, 3> => MatrixSliceMut<_, Dynamic, Const<4>>, Dynamic::new(3), U4);
// Dynamic "source slice"
test_reshape!(DMatrixSlice<_> => SMatrixSlice<_, 3, 4>, U3, U4);
test_reshape!(DMatrixSlice<_> => DMatrixSlice<_>, Dynamic::new(3), Dynamic::new(4));
test_reshape!(DMatrixSlice<_> => MatrixSlice<_, Const<3>, Dynamic>, U3, Dynamic::new(4));
test_reshape!(DMatrixSlice<_> => MatrixSlice<_, Dynamic, Const<4>>, Dynamic::new(3), U4);
test_reshape!(DMatrixSliceMut<_> => SMatrixSliceMut<_, 3, 4>, U3, U4);
test_reshape!(DMatrixSliceMut<_> => DMatrixSliceMut<_>, Dynamic::new(3), Dynamic::new(4));
test_reshape!(DMatrixSliceMut<_> => MatrixSliceMut<_, Const<3>, Dynamic>, U3, Dynamic::new(4));
test_reshape!(DMatrixSliceMut<_> => MatrixSliceMut<_, Dynamic, Const<4>>, Dynamic::new(3), U4);
}