nalgebra/tests/macros/stack.rs

428 lines
13 KiB
Rust
Raw Permalink Normal View History

Improved stack! implementation, tests (#1375) * Add macro for concatenating matrices * Replace DimUnify with DimEq::representative * Add some simple cat macro output generation tests * Fix formatting in cat macro code * Add random prefix to cat macro output * Add simple quote_spanned for cat macro * Use `generic_view_mut` in cat macro * Fix clippy lints in cat macro * Clean up documentation for cat macro * Remove identity literal from cat macro * Allow references in input to cat macro * Rename cat macro to stack * Add more stack macro tests * Add comment to explain reason for prefix in stack! macro * Refactor matrix!, stack! macros into separate modules * Take all blocks by reference in stack! macro * Make empty stack![] invocation well-defined * Fix stack! macro incorrect reference to data * More extensive tests for stack! macro * Move nalgebra-macros tests to nalgebra tests By testing matrix!, stack! macros etc. in nalgebra, we ensure that these macros are used in the same way that users will be using them. * Fix stack! code generation tests * Add back nalgebra as dev-dependency of nalgebra-macros * Fix accidental wrong matrix! macro references in docs * Rewrite stack! documentation for clarity * Formatting * Skip formatting of macro, rustfmt messes it up * Rewrite stack! impl for improved clarity, Span behavior This improves error messages upon dimension mismatch, among other things. I've also tried to make the implementation easier to understand, adding some comments to help the reader understand the individual steps. * Use SameNumberOfRows/Columns instead of DimEq in stack! macro This gives more accurate compiler errors if matrix dimensions are mismatched. * Check that stack! panics at runtime for basic dimension mismatch * Add suggested edge cases from initial PR to tests * stack! impl: use fixed prefix everywhere This ensures that the expected generated code in tests is the actual generated code when used in the wild. 
* nalgebra-macros: Remove clippy pedantic, fix clippy complaints pedantic seems to be mostly intent on wasting the programmer's time * Add stack! sanity tests for built-ins and Complex * Fix formatting in test * Improve readability of format_ident! calls in stack! impl * fix trybuild tests * chore: run tests with a specific rust version * More trybuild fixes --------- Co-authored-by: Birk Tjelmeland <git@birktj.no> Co-authored-by: Sébastien Crozet <sebcrozet@dimforge.com>
2024-06-23 17:29:28 +08:00
use crate::macros::assert_eq_and_type;
use cool_asserts::assert_panics;
use na::VecStorage;
use nalgebra::dimension::U1;
use nalgebra::{dmatrix, matrix, stack};
use nalgebra::{
DMatrix, DMatrixView, Dyn, Matrix, Matrix2, Matrix4, OMatrix, SMatrix, SMatrixView,
SMatrixViewMut, Scalar, U2,
};
use nalgebra_macros::vector;
use num_traits::Zero;
/// Collapses the sizes seen along one block row (or block column) into a single size.
///
/// Each item is the size contributed by one entry: `Some(n)` for an explicit block, `None`
/// for an implicit zero block, which places no constraint on the size. An empty sequence
/// yields `0`.
///
/// # Panics
///
/// Panics with `mismatch_msg` if two explicit blocks disagree on the size, and with
/// `all_zero_msg` if the sequence is non-empty but consists only of `None` entries.
fn unify_block_sizes(
    sizes: impl Iterator<Item = Option<usize>>,
    mismatch_msg: &str,
    all_zero_msg: &str,
) -> usize {
    sizes
        .reduce(|lhs, rhs| match (lhs, rhs) {
            (Some(_), None) => lhs,
            (None, Some(_)) => rhs,
            (None, None) => None,
            (Some(a), Some(b)) if a == b => Some(a),
            _ => panic!("{}", mismatch_msg),
        })
        // `reduce` only returns `None` for an empty sequence, which has size zero.
        .unwrap_or(Some(0))
        .expect(all_zero_msg)
}

/// Simple implementation that stacks dynamic matrices.
///
/// Used for verifying results of the stack! macro. `None` entries are considered to represent
/// a zero block.
///
/// # Panics
///
/// Panics if the explicit blocks within a block row disagree on their number of rows, if the
/// explicit blocks within a block column disagree on their number of columns, or if a
/// non-empty block row/column contains only `None` entries (its size cannot be determined).
fn stack_dyn<T: Scalar + Zero>(blocks: DMatrix<Option<DMatrix<T>>>) -> DMatrix<T> {
    // Height of each block row, determined by the explicit blocks it contains.
    let row_counts: Vec<usize> = blocks
        .row_iter()
        .map(|block_row| {
            unify_block_sizes(
                block_row
                    .iter()
                    .map(|entry| entry.as_ref().map(|block| block.nrows())),
                "Number of rows must be consistent in each block row",
                "Each block row must have at least one entry which is not a zero literal",
            )
        })
        .collect();
    // Width of each block column, determined the same way.
    let col_counts: Vec<usize> = blocks
        .column_iter()
        .map(|block_col| {
            unify_block_sizes(
                block_col
                    .iter()
                    .map(|entry| entry.as_ref().map(|block| block.ncols())),
                "Number of columns must be consistent in each block column",
                "Each block column must have at least one entry which is not a zero literal",
            )
        })
        .collect();

    // Start from an all-zero output so that `None` entries need no further work.
    let mut output = DMatrix::zeros(
        row_counts.iter().sum::<usize>(),
        col_counts.iter().sum::<usize>(),
    );

    // Walk the block grid column by column, tracking where each block starts in the output.
    let mut col_offset = 0;
    for (j, &block_col_width) in col_counts.iter().enumerate() {
        let mut row_offset = 0;
        for (i, &block_row_height) in row_counts.iter().enumerate() {
            if let Some(block) = &blocks[(i, j)] {
                output
                    .view_mut((row_offset, col_offset), block.shape())
                    .copy_from(block);
            }
            row_offset += block_row_height;
        }
        col_offset += block_col_width;
    }
    output
}
/// Converts one `stack!`-style entry into the optional dynamic matrix cell consumed by
/// `stack_dyn`.
///
/// The literal token `0` becomes `None`. Any other expression is viewed with dynamic
/// row and column dimensions and cloned into an owned matrix wrapped in `Some`.
macro_rules! stack_dyn_convert_to_dmatrix_option {
    // Listed first so that a literal `0` is not captured by the general expression arm.
    (0) => {
        None
    };
    ($block:expr) => {
        Some($block.as_view::<Dyn, Dyn, U1, Dyn>().clone_owned())
    };
}
/// Cross-checks the output of stack! against `stack_dyn`, a much simpler implementation
/// that only works with heap-allocated data.
///
/// The two implementations share essentially nothing, so agreement between them is a good
/// sign that the stack! impl is correct.
///
/// Usage: `verify_stack!(ExpectedType; [a, b; c, d])`, where the bracketed part uses the same
/// syntax as stack! (and matrix!).
///
/// NOTE(review): entries are captured as `expr` fragments before being forwarded, so a
/// literal `0` entry presumably cannot match the `(0)` arm of
/// `stack_dyn_convert_to_dmatrix_option!` — confirm before passing zero blocks here.
macro_rules! verify_stack {
    ($matrix_type:ty ; [$($($block:expr),*);*]) => {{
        // Forward the blocks unchanged to stack!, pinning the result to the requested type
        let actual: $matrix_type = stack![$($($block),*);*];
        // Nest the converted blocks into a matrix of matrices with dmatrix! for the reference
        let expected = stack_dyn(
            dmatrix![$($(stack_dyn_convert_to_dmatrix_option!($block)),*);*]
        );
        assert_eq!(actual, expected);
    }};
}
/// Two 2x2 identity blocks on the diagonal (one owned, one by reference) with `0` in the
/// off-diagonal positions must give the 4x4 identity.
#[test]
fn stack_simple() {
    let stacked = stack![
        Matrix2::<usize>::identity(), 0;
        0, &Matrix2::identity();
    ];
    assert_eq_and_type!(stacked, Matrix4::identity());
}
/// Two 2x2 blocks placed in the top-right and bottom-left positions, with `0` entries in the
/// remaining two positions.
#[test]
fn stack_diag() {
    let stacked = stack![
        0, matrix![1, 2; 3, 4;];
        matrix![5, 6; 7, 8;], 0;
    ];
    let expected = matrix![
        0, 0, 1, 2;
        0, 0, 3, 4;
        5, 6, 0, 0;
        7, 8, 0, 0;
    ];
    assert_eq_and_type!(stacked, expected);
}
/// Stacks a `matrix!` block together with a `dmatrix!` block and compares against a
/// `dmatrix!` expectation.
#[test]
fn stack_dynamic() {
    let stacked = stack![
        matrix![ 1, 2; 3, 4; ], 0;
        0, dmatrix![7, 8, 9; 10, 11, 12; ];
    ];
    let expected = dmatrix![
        1, 2, 0, 0, 0;
        3, 4, 0, 0, 0;
        0, 0, 7, 8, 9;
        0, 0, 10, 11, 12;
    ];
    assert_eq_and_type!(stacked, expected);
}
/// The results of inner stack! invocations can themselves be used as blocks of an outer one.
#[test]
fn stack_nested() {
    let stacked = stack![
        stack![ matrix![1, 2; 3, 4;]; matrix![5, 6;]],
        stack![ matrix![7;9;10;], matrix![11; 12; 13;] ];
    ];
    let expected = matrix![
        1, 2, 7, 11;
        3, 4, 9, 12;
        5, 6, 10, 13;
    ];
    assert_eq_and_type!(stacked, expected);
}
/// Stacking a single block must reproduce that block.
#[test]
fn stack_single() {
    let block = matrix![1, 2; 3, 4];
    let stacked = stack![block];
    assert_eq_and_type!(block, stacked);
}
/// A single block row: the same 2x2 block placed twice side by side.
#[test]
fn stack_single_row() {
    let block = matrix![1, 2; 3, 4];
    let stacked = stack![block, block];
    let expected = matrix![
        1, 2, 1, 2;
        3, 4, 3, 4;
    ];
    assert_eq_and_type!(stacked, expected);
}
/// A single block column: the same 2x2 block placed twice, one above the other.
#[test]
fn stack_single_col() {
    let block = matrix![1, 2; 3, 4];
    let stacked = stack![block; block];
    let expected = matrix![
        1, 2;
        3, 4;
        1, 2;
        3, 4;
    ];
    assert_eq_and_type!(stacked, expected);
}
/// Blocks may be arbitrary expressions, here a sum and a scaled difference of two matrices.
#[test]
#[rustfmt::skip]
fn stack_expr() {
    let lhs = matrix![1, 2; 3, 4];
    let rhs = matrix![5, 6; 7, 8];
    let stacked = stack![lhs + rhs; 2i32 * rhs - lhs];
    let expected = matrix![
         6,  8;
        10, 12;
         9, 10;
        11, 12;
    ];
    assert_eq_and_type!(stacked, expected);
}
/// Edge cases: the empty invocation, and a dynamic block that has zero columns.
#[test]
fn stack_edge_cases() {
    {
        // An empty stack![] must yield a 0x0 matrix of whatever element type is requested
        let _: SMatrix<i32, 0, 0> = stack![];
        let _: SMatrix<f64, 0, 0> = stack![];
    }

    {
        // Case suggested by @tpdickso: https://github.com/dimforge/nalgebra/pull/1080#discussion_r1435871752
        let filled = matrix![1, 2;
                             3, 4];
        // Dynamic matrix with 2 rows and 0 columns
        let zero_cols = DMatrix::from_data(VecStorage::new(Dyn(2), Dyn(0), vec![]));
        assert_eq!(
            stack![filled, 0;
                   0, zero_cols],
            matrix![1, 2;
                    3, 4;
                    0, 0;
                    0, 0]
        );
    }
}
/// Runs a broad set of block layouts through `verify_stack!`, which compares stack! against
/// the reference implementation `stack_dyn`.
#[rustfmt::skip]
#[test]
fn stack_many_tests() {
    // Naming scheme: `s_RxC` is a statically sized R x C matrix, `d_RxC` a dynamically sized one.

    // Statically sized inputs
    let s_0x0: SMatrix<i32, 0, 0> = matrix![];
    let s_0x1: SMatrix<i32, 0, 1> = Matrix::default();
    let s_1x0: SMatrix<i32, 1, 0> = Matrix::default();
    let s_1x1: SMatrix<i32, 1, 1> = matrix![1];
    let s_2x2: SMatrix<i32, 2, 2> = matrix![6, 7; 8, 9];
    let s_2x3: SMatrix<i32, 2, 3> = matrix![16, 17, 18; 19, 20, 21];
    let s_3x3: SMatrix<i32, 3, 3> = matrix![28, 29, 30; 31, 32, 33; 34, 35, 36];

    // Dynamically sized inputs
    let d_0x0: DMatrix<i32> = dmatrix![];
    let d_1x2: DMatrix<i32> = dmatrix![9, 10];
    let d_2x2: DMatrix<i32> = dmatrix![5, 6; 7, 8];
    let d_4x4: DMatrix<i32> = dmatrix![10, 11, 12, 13; 14, 15, 16, 17; 18, 19, 20, 21; 22, 23, 24, 25];

    // Matrices with zero rows and/or zero columns
    verify_stack!(SMatrix<_, 0, 0>; [s_0x0]);
    verify_stack!(SMatrix<_, 0, 1>; [s_0x1]);
    verify_stack!(SMatrix<_, 1, 0>; [s_1x0]);
    verify_stack!(SMatrix<_, 0, 0>; [s_0x0; s_0x0]);
    verify_stack!(SMatrix<_, 0, 0>; [s_0x0, s_0x0; s_0x0, s_0x0]);
    verify_stack!(SMatrix<_, 0, 2>; [s_0x1, s_0x1]);
    verify_stack!(SMatrix<_, 2, 0>; [s_1x0; s_1x0]);
    verify_stack!(SMatrix<_, 1, 0>; [s_1x0, s_1x0]);
    verify_stack!(DMatrix<_>; [d_0x0]);

    // Side-by-side (horizontal) stacking
    verify_stack!(SMatrix<_, 1, 2>; [s_1x1, s_1x1]);
    verify_stack!(SMatrix<_, 2, 4>; [s_2x2, s_2x2]);
    verify_stack!(DMatrix<_>; [d_1x2, d_1x2]);

    // Top-to-bottom (vertical) stacking
    verify_stack!(SMatrix<_, 2, 1>; [s_1x1; s_1x1]);
    verify_stack!(SMatrix<_, 4, 2>; [s_2x2; s_2x2]);
    verify_stack!(DMatrix<_>; [d_2x2; d_2x2]);

    // Static and dynamic blocks combined
    verify_stack!(OMatrix<_, U2, Dyn>; [s_2x2, d_2x2]);
    verify_stack!(OMatrix<_, Dyn, U2>; [s_2x2; d_1x2]);

    // Three blocks in one row
    verify_stack!(SMatrix<_, 1, 3>; [s_1x1, s_1x1, s_1x1]);
    verify_stack!(DMatrix<_>; [d_1x2, d_1x2, d_1x2]);

    // Somewhat larger dimensions
    verify_stack!(SMatrix<_, 3, 6>; [s_3x3, s_3x3]);
    verify_stack!(DMatrix<_>; [d_4x4; d_4x4]);
    verify_stack!(SMatrix<_, 4, 7>; [s_2x2, s_2x3, d_2x2;
                                     d_2x2, s_2x3, s_2x2]);

    // Owned blocks mixed with references
    verify_stack!(OMatrix<_, Dyn, U2>; [&s_2x2; &d_1x2]);
    verify_stack!(SMatrix<_, 4, 7>; [ s_2x2, &s_2x3,  d_2x2;
                                     &d_2x2,  s_2x3, &s_2x2]);

    // Immutable and mutable views as blocks
    let s_2x2_view: SMatrixView<_, 2, 2> = s_2x2.as_view();
    let s_2x3_view: SMatrixView<_, 2, 3> = s_2x3.as_view();
    let d_2x2_view: DMatrixView<_> = d_2x2.as_view();
    // The mutable views borrow from separate copies of the static matrices
    let mut s_2x2_buf = s_2x2.clone();
    let s_2x2_view_mut: SMatrixViewMut<_, 2, 2> = s_2x2_buf.as_view_mut();
    let mut s_2x3_buf = s_2x3.clone();
    let s_2x3_view_mut: SMatrixViewMut<_, 2, 3> = s_2x3_buf.as_view_mut();
    verify_stack!(SMatrix<_, 4, 7>; [ s_2x2_view_mut, &s_2x3_view_mut,  d_2x2_view;
                                     &d_2x2_view,      s_2x3_view,     &s_2x2_view]);

    // Arbitrary expressions as blocks
    let triple_entries = |input: &DMatrix<_>| input.map(|value| value * 3);
    verify_stack!(SMatrix<_, 2, 5>; [ 2 * s_2x2 - 3 * &d_2x2, s_2x3 + 2 * s_2x3]);
    verify_stack!(DMatrix<_>; [ 2 * triple_entries(&d_2x2) ]);
    verify_stack!(SMatrix<_, 2, 5>; [ (|matrix| 4 * matrix)(s_2x2), s_2x3 ]);
}
/// Registers the stack! source files that are expected to fail compilation with trybuild.
#[test]
fn stack_trybuild_tests() {
    let cases = trybuild::TestCases::new();
    let expected_failures = [
        // Verify error message when a row or column only contains a zero entry
        "tests/macros/trybuild/stack_empty_row.rs",
        "tests/macros/trybuild/stack_empty_col.rs",
        // Blocks whose dimensions do not fit together
        "tests/macros/trybuild/stack_incompatible_block_dimensions.rs",
        "tests/macros/trybuild/stack_incompatible_block_dimensions2.rs",
    ];
    for path in expected_failures {
        cases.compile_fail(path);
    }
}
/// When dynamic blocks are involved, mismatched dimensions must panic at runtime with a
/// message that names the offending block row or block column.
#[test]
fn stack_mismatched_dimensions_runtime_panics() {
    // `fixed_` marks a statically sized matrix, `dyn_` a dynamically sized one
    let fixed_2x2 = matrix![1, 2; 3, 4];
    let dyn_2x3 = dmatrix![5, 6, 7; 8, 9, 10];
    let dyn_1x2 = dmatrix![11, 12];
    let dyn_1x3 = dmatrix![13, 14, 15];

    // 2 rows next to 1 row
    assert_panics!(
        stack![fixed_2x2, dyn_1x2],
        includes("All blocks in block row 0 must have the same number of rows")
    );

    // 2 columns above 3 columns
    assert_panics!(
        stack![fixed_2x2; dyn_2x3],
        includes("All blocks in block column 0 must have the same number of columns")
    );

    // Second block row: 1 row next to 2 rows
    assert_panics!(
        stack![fixed_2x2, fixed_2x2; dyn_1x2, dyn_2x3],
        includes("All blocks in block row 1 must have the same number of rows")
    );

    // Second block column: 2 columns above 3 columns
    assert_panics!(
        stack![fixed_2x2, fixed_2x2; dyn_1x2, dyn_1x3],
        includes("All blocks in block column 1 must have the same number of columns")
    );

    assert_panics!(
        {
            // Edge case suggested by @tpdickso: https://github.com/dimforge/nalgebra/pull/1080#discussion_r1435871752
            let dyn_3x0 = DMatrix::from_data(VecStorage::new(Dyn(3), Dyn(0), Vec::<i32>::new()));
            stack![fixed_2x2, dyn_3x0]
        },
        includes("All blocks in block row 0 must have the same number of rows")
    );
}
/// Sanity check that stack! works with the common built-in numeric types.
#[test]
fn stack_test_builtin_types() {
    // Other than T: Zero, there's nothing type-specific in the logic for stack!
    // These tests are just sanity tests, to make sure it works with the common built-in types
    let upper_left = matrix![1, 2; 3, 4];
    let upper_right = vector![5, 6];
    let lower_left = matrix![7, 8];

    let expected = matrix![ 1, 2, 5;
                            3, 4, 6;
                            7, 8, 0 ];

    // Runs the same stacking check once for every listed element type
    macro_rules! check_builtin {
        ($($T:ty),+ $(,)?) => {{
            $({
                // Cannot use .cast::<$T> because we cannot convert between unsigned and signed
                let stacked = stack![upper_left.map(|x| x as $T), upper_right.map(|x| x as $T);
                                     lower_left.map(|x| x as $T), 0];
                assert_eq!(stacked, expected.map(|x| x as $T));
            })+
        }};
    }

    check_builtin!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128, f32, f64);
}
/// Sanity check that stack! works with complex entries, both as written and after casting
/// every block to `Complex<f32>`.
#[test]
fn stack_test_complex() {
    use num_complex::Complex as C;
    type C32 = C<f32>;

    let upper_left = matrix![C::new(1.0, 1.0), C::new(2.0, 2.0); C::new(3.0, 3.0), C::new(4.0, 4.0)];
    let upper_right = vector![C::new(5.0, 5.0), C::new(6.0, 6.0)];
    let lower_left = matrix![C::new(7.0, 7.0), C::new(8.0, 8.0)];

    // Every expected entry has equal real and imaginary parts
    let expected = matrix![ 1, 2, 5;
                            3, 4, 6;
                            7, 8, 0 ]
    .map(|x| C::new(x as f64, x as f64));

    assert_eq!(stack![upper_left, upper_right; lower_left, 0], expected);
    assert_eq!(
        stack![upper_left.cast::<C32>(), upper_right.cast::<C32>(); lower_left.cast::<C32>(), 0],
        expected.cast::<C32>()
    );
}