From eb228faa2b0ee115d3e444d69227a61d01f799c4 Mon Sep 17 00:00:00 2001 From: Andreas Borgen Longva Date: Sun, 23 Jun 2024 11:29:28 +0200 Subject: [PATCH] Improved stack! implementation, tests (#1375) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add macro for concatenating matrices * Replace DimUnify with DimEq::representative * Add some simple cat macro output generation tests * Fix formatting in cat macro code * Add random prefix to cat macro output * Add simple quote_spanned for cat macro * Use `generic_view_mut` in cat macro * Fix clippy lints in cat macro * Clean up documentation for cat macro * Remove identity literal from cat macro * Allow references in input to cat macro * Rename cat macro to stack * Add more stack macro tests * Add comment to explain reason for prefix in stack! macro * Refactor matrix!, stack! macros into separate modules * Take all blocks by reference in stack! macro * Make empty stack![] invocation well-defined * Fix stack! macro incorrect reference to data * More extensive tests for stack! macro * Move nalgebra-macros tests to nalgebra tests By testing matrix!, stack! macros etc. in nalgebra, we ensure that these macros are used in the same way that users will be using them. * Fix stack! code generation tests * Add back nalgebra as dev-dependency of nalgebra-macros * Fix accidental wrong matrix! macro references in docs * Rewrite stack! documentation for clarity * Formatting * Skip formatting of macro, rustfmt messes it up * Rewrite stack! impl for improved clarity, Span behavior This improves error messages upon dimension mismatch, among other things. I've also tried to make the implementation easier to understand, adding some comments to help the reader understand the individual steps. * Use SameNumberOfRows/Columns instead of DimEq in stack! macro This gives more accurate compiler errors if matrix dimensions are mismatched. * Check that stack! panics at runtime for basic dimension mismatch * Add suggested edge cases from initial PR to tests * stack! impl: use fixed prefix everywhere This ensures that the expected generated code in tests is the actual generated code when used in the wild. * nalgebra-macros: Remove clippy pedantic, fix clippy complaints pedantic seems to be mostly intent on wasting the programmer's time * Add stack! sanity tests for built-ins and Complex * Fix formatting in test * Improve readability of format_ident! calls in stack! impl * fix trybuild tests * chore: run tests with a specific rust version * More trybuild fixes --------- Co-authored-by: Birk Tjelmeland Co-authored-by: Sébastien Crozet --- .github/workflows/nalgebra-ci-build.yml | 73 +-- Cargo.toml | 6 +- nalgebra-macros/Cargo.toml | 3 +- nalgebra-macros/src/lib.rs | 286 +++++------- nalgebra-macros/src/matrix_vector_impl.rs | 201 +++++++++ nalgebra-macros/src/stack_impl.rs | 302 +++++++++++++ src/base/constraint.rs | 17 + src/lib.rs | 2 +- tests/lib.rs | 3 + .../tests/tests.rs => tests/macros/matrix.rs | 17 +- tests/macros/mod.rs | 19 + tests/macros/stack.rs | 427 ++++++++++++++++++ .../trybuild/dmatrix_mismatched_dimensions.rs | 2 +- .../dmatrix_mismatched_dimensions.stderr | 2 +- .../trybuild/matrix_mismatched_dimensions.rs | 2 +- .../matrix_mismatched_dimensions.stderr | 2 +- tests/macros/trybuild/stack_empty_col.rs | 6 + tests/macros/trybuild/stack_empty_col.stderr | 7 + tests/macros/trybuild/stack_empty_row.rs | 6 + tests/macros/trybuild/stack_empty_row.stderr | 7 + .../stack_incompatible_block_dimensions.rs | 13 + ...stack_incompatible_block_dimensions.stderr | 37 ++ .../stack_incompatible_block_dimensions2.rs | 14 + ...tack_incompatible_block_dimensions2.stderr | 37 ++ 24 files changed, 1263 insertions(+), 228 deletions(-) create mode 100644 nalgebra-macros/src/matrix_vector_impl.rs create mode 100644 nalgebra-macros/src/stack_impl.rs rename nalgebra-macros/tests/tests.rs => tests/macros/matrix.rs (96%) create mode 100644 tests/macros/mod.rs create mode 100644 tests/macros/stack.rs rename {nalgebra-macros/tests => tests/macros}/trybuild/dmatrix_mismatched_dimensions.rs (64%) rename {nalgebra-macros/tests => tests/macros}/trybuild/dmatrix_mismatched_dimensions.stderr (64%) rename {nalgebra-macros/tests => tests/macros}/trybuild/matrix_mismatched_dimensions.rs (65%) rename {nalgebra-macros/tests => tests/macros}/trybuild/matrix_mismatched_dimensions.stderr (65%) create mode 100644 tests/macros/trybuild/stack_empty_col.rs create mode 100644 tests/macros/trybuild/stack_empty_col.stderr create mode 100644 tests/macros/trybuild/stack_empty_row.rs create mode 100644 tests/macros/trybuild/stack_empty_row.stderr create mode 100644 tests/macros/trybuild/stack_incompatible_block_dimensions.rs create mode 100644 tests/macros/trybuild/stack_incompatible_block_dimensions.stderr create mode 100644 tests/macros/trybuild/stack_incompatible_block_dimensions2.rs create mode 100644 tests/macros/trybuild/stack_incompatible_block_dimensions2.stderr diff --git a/.github/workflows/nalgebra-ci-build.yml b/.github/workflows/nalgebra-ci-build.yml index 97b42184..ff6bff00 100644 --- a/.github/workflows/nalgebra-ci-build.yml +++ b/.github/workflows/nalgebra-ci-build.yml @@ -13,37 +13,37 @@ jobs: check-fmt: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - name: Check formatting - run: cargo fmt -- --check + - uses: actions/checkout@v2 + - name: Check formatting + run: cargo fmt -- --check clippy: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - name: Install clippy - run: rustup component add clippy - - name: Run clippy - run: cargo clippy + - uses: actions/checkout@v2 + - name: Install clippy + run: rustup component add clippy + - name: Run clippy + run: cargo clippy build-nalgebra: runs-on: ubuntu-latest -# env: -# RUSTFLAGS: -D warnings + # env: + # RUSTFLAGS: -D warnings steps: - - uses: actions/checkout@v2 - - name: Build --no-default-feature - run: cargo build --no-default-features; - - name: Build (default features) - run: cargo build; - - name: Build --features serde-serialize - run: cargo build --features serde-serialize - - name: Build nalgebra-lapack - run: cd nalgebra-lapack; cargo build; - - name: Build nalgebra-sparse --no-default-features - run: cd nalgebra-sparse; cargo build --no-default-features; - - name: Build nalgebra-sparse (default features) - run: cd nalgebra-sparse; cargo build; - - name: Build nalgebra-sparse --all-features - run: cd nalgebra-sparse; cargo build --all-features; + - uses: actions/checkout@v2 + - name: Build --no-default-feature + run: cargo build --no-default-features; + - name: Build (default features) + run: cargo build; + - name: Build --features serde-serialize + run: cargo build --features serde-serialize + - name: Build nalgebra-lapack + run: cd nalgebra-lapack; cargo build; + - name: Build nalgebra-sparse --no-default-features + run: cd nalgebra-sparse; cargo build --no-default-features; + - name: Build nalgebra-sparse (default features) + run: cd nalgebra-sparse; cargo build; + - name: Build nalgebra-sparse --all-features + run: cd nalgebra-sparse; cargo build --all-features; # Run this on it’s own job because it alone takes a lot of time. # So it’s best to let it run in parallel to the other jobs. build-nalgebra-all-features: @@ -54,9 +54,18 @@ jobs: - run: cargo build -p nalgebra-glm --all-features; test-nalgebra: runs-on: ubuntu-latest -# env: -# RUSTFLAGS: -D warnings + # env: + # RUSTFLAGS: -D warnings steps: + # Tests are run with a specific version of the compiler to avoid + # trybuild errors when a new compiler version is out. This can be + # bumped as needed after running the tests with TRYBUILD=overwrite + # to re-generate the error reference. + - name: Select rustc version + uses: actions-rs/toolchain@v1 + with: + toolchain: 1.79.0 + override: true - uses: actions/checkout@v2 - name: test run: cargo test --features arbitrary,rand,serde-serialize,sparse,debug,io,compare,libm,proptest-support,slow-tests,rkyv-safe-deser,rayon; @@ -85,8 +94,8 @@ jobs: run: cargo test -p nalgebra-macros build-wasm: runs-on: ubuntu-latest -# env: -# RUSTFLAGS: -D warnings + # env: + # RUSTFLAGS: -D warnings steps: - uses: actions/checkout@v2 - run: rustup target add wasm32-unknown-unknown @@ -121,6 +130,6 @@ jobs: docs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - name: Generate documentation - run: cargo doc + - uses: actions/checkout@v2 + - name: Generate documentation + run: cargo doc diff --git a/Cargo.toml b/Cargo.toml index a4548c26..3cda80a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,6 @@ libm = ["simba/libm"] libm-force = ["simba/libm_force"] macros = ["nalgebra-macros"] - # Conversion convert-mint = ["mint"] convert-bytemuck = ["bytemuck"] @@ -122,6 +121,11 @@ nalgebra = { path = ".", features = ["debug", "compare", "rand", "macros"] } matrixcompare = "0.3.0" itertools = "0.13" +# For macro testing +trybuild = "1.0.90" + +cool_asserts = "2.0.3" + [workspace] members = ["nalgebra-lapack", "nalgebra-glm", "nalgebra-sparse", "nalgebra-macros"] resolver = "2" diff --git a/nalgebra-macros/Cargo.toml b/nalgebra-macros/Cargo.toml index 6d08ed49..52cb7405 100644 --- a/nalgebra-macros/Cargo.toml +++ b/nalgebra-macros/Cargo.toml @@ -21,5 +21,4 @@ quote = "1.0" proc-macro2 = "1.0" [dev-dependencies] -nalgebra = { version = "0.32.0", path = ".." } -trybuild = "1.0.42" +nalgebra = { version = "0.32.1", path = ".." } diff --git a/nalgebra-macros/src/lib.rs b/nalgebra-macros/src/lib.rs index 827d6080..e209ef68 100644 --- a/nalgebra-macros/src/lib.rs +++ b/nalgebra-macros/src/lib.rs @@ -12,102 +12,19 @@ future_incompatible, missing_copy_implementations, missing_debug_implementations, - clippy::all, - clippy::pedantic + clippy::all )] +mod matrix_vector_impl; +mod stack_impl; + +use matrix_vector_impl::{Matrix, Vector}; + +use crate::matrix_vector_impl::{dmatrix_impl, dvector_impl, matrix_impl, vector_impl}; use proc_macro::TokenStream; -use quote::{quote, ToTokens, TokenStreamExt}; -use syn::parse::{Error, Parse, ParseStream, Result}; -use syn::punctuated::Punctuated; -use syn::Expr; -use syn::{parse_macro_input, Token}; - -use proc_macro2::{Delimiter, Spacing, TokenStream as TokenStream2, TokenTree}; -use proc_macro2::{Group, Punct}; - -struct Matrix { - // Represent the matrix as a row-major vector of vectors of expressions - rows: Vec>, - ncols: usize, -} - -impl Matrix { - fn nrows(&self) -> usize { - self.rows.len() - } - - fn ncols(&self) -> usize { - self.ncols - } - - /// Produces a stream of tokens representing this matrix as a column-major nested array. - fn to_col_major_nested_array_tokens(&self) -> TokenStream2 { - let mut result = TokenStream2::new(); - for j in 0..self.ncols() { - let mut col = TokenStream2::new(); - let col_iter = (0..self.nrows()).map(move |i| &self.rows[i][j]); - col.append_separated(col_iter, Punct::new(',', Spacing::Alone)); - result.append(Group::new(Delimiter::Bracket, col)); - result.append(Punct::new(',', Spacing::Alone)); - } - TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Bracket, result))) - } - - /// Produces a stream of tokens representing this matrix as a column-major flat array - /// (suitable for representing e.g. a `DMatrix`). - fn to_col_major_flat_array_tokens(&self) -> TokenStream2 { - let mut data = TokenStream2::new(); - for j in 0..self.ncols() { - for i in 0..self.nrows() { - self.rows[i][j].to_tokens(&mut data); - data.append(Punct::new(',', Spacing::Alone)); - } - } - TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Bracket, data))) - } -} - -type MatrixRowSyntax = Punctuated; - -impl Parse for Matrix { - fn parse(input: ParseStream<'_>) -> Result { - let mut rows = Vec::new(); - let mut ncols = None; - - while !input.is_empty() { - let row_span = input.span(); - let row = MatrixRowSyntax::parse_separated_nonempty(input)?; - - if let Some(ncols) = ncols { - if row.len() != ncols { - let row_idx = rows.len(); - let error_msg = format!( - "Unexpected number of entries in row {}. Expected {}, found {} entries.", - row_idx, - ncols, - row.len() - ); - return Err(Error::new(row_span, error_msg)); - } - } else { - ncols = Some(row.len()); - } - rows.push(row.into_iter().collect()); - - // We've just read a row, so if there are more tokens, there must be a semi-colon, - // otherwise the input is malformed - if !input.is_empty() { - input.parse::()?; - } - } - - Ok(Self { - rows, - ncols: ncols.unwrap_or(0), - }) - } -} +use quote::quote; +use stack_impl::stack_impl; +use syn::parse_macro_input; /// Construct a fixed-size matrix directly from data. /// @@ -145,20 +62,7 @@ impl Parse for Matrix { /// ``` #[proc_macro] pub fn matrix(stream: TokenStream) -> TokenStream { - let matrix = parse_macro_input!(stream as Matrix); - - let row_dim = matrix.nrows(); - let col_dim = matrix.ncols(); - - let array_tokens = matrix.to_col_major_nested_array_tokens(); - - // TODO: Use quote_spanned instead?? - let output = quote! { - nalgebra::SMatrix::<_, #row_dim, #col_dim> - ::from_array_storage(nalgebra::ArrayStorage(#array_tokens)) - }; - - proc_macro::TokenStream::from(output) + matrix_impl(stream) } /// Construct a dynamic matrix directly from data. @@ -180,55 +84,7 @@ pub fn matrix(stream: TokenStream) -> TokenStream { /// ``` #[proc_macro] pub fn dmatrix(stream: TokenStream) -> TokenStream { - let matrix = parse_macro_input!(stream as Matrix); - - let row_dim = matrix.nrows(); - let col_dim = matrix.ncols(); - - let array_tokens = matrix.to_col_major_flat_array_tokens(); - - // TODO: Use quote_spanned instead?? - let output = quote! { - nalgebra::DMatrix::<_> - ::from_vec_storage(nalgebra::VecStorage::new( - nalgebra::Dyn(#row_dim), - nalgebra::Dyn(#col_dim), - vec!#array_tokens)) - }; - - proc_macro::TokenStream::from(output) -} - -struct Vector { - elements: Vec, -} - -impl Vector { - fn to_array_tokens(&self) -> TokenStream2 { - let mut data = TokenStream2::new(); - data.append_separated(&self.elements, Punct::new(',', Spacing::Alone)); - TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Bracket, data))) - } - - fn len(&self) -> usize { - self.elements.len() - } -} - -impl Parse for Vector { - fn parse(input: ParseStream<'_>) -> Result { - // The syntax of a vector is just the syntax of a single matrix row - if input.is_empty() { - Ok(Self { - elements: Vec::new(), - }) - } else { - let elements = MatrixRowSyntax::parse_terminated(input)? - .into_iter() - .collect(); - Ok(Self { elements }) - } - } + dmatrix_impl(stream) } /// Construct a fixed-size column vector directly from data. @@ -252,14 +108,7 @@ impl Parse for Vector { /// ``` #[proc_macro] pub fn vector(stream: TokenStream) -> TokenStream { - let vector = parse_macro_input!(stream as Vector); - let len = vector.len(); - let array_tokens = vector.to_array_tokens(); - let output = quote! { - nalgebra::SVector::<_, #len> - ::from_array_storage(nalgebra::ArrayStorage([#array_tokens])) - }; - proc_macro::TokenStream::from(output) + vector_impl(stream) } /// Construct a dynamic column vector directly from data. @@ -279,17 +128,7 @@ pub fn vector(stream: TokenStream) -> TokenStream { /// ``` #[proc_macro] pub fn dvector(stream: TokenStream) -> TokenStream { - let vector = parse_macro_input!(stream as Vector); - let len = vector.len(); - let array_tokens = vector.to_array_tokens(); - let output = quote! { - nalgebra::DVector::<_> - ::from_vec_storage(nalgebra::VecStorage::new( - nalgebra::Dyn(#len), - nalgebra::Const::<1>, - vec!#array_tokens)) - }; - proc_macro::TokenStream::from(output) + dvector_impl(stream) } /// Construct a fixed-size point directly from data. @@ -321,3 +160,100 @@ pub fn point(stream: TokenStream) -> TokenStream { }; proc_macro::TokenStream::from(output) } + +/// Construct a new matrix by stacking matrices in a block matrix. +/// +/// **Note: Requires the `macros` feature to be enabled (enabled by default)**. +/// +/// This macro facilitates the construction of +/// [block matrices](https://en.wikipedia.org/wiki/Block_matrix) +/// by stacking blocks (matrices) using the same MATLAB-like syntax as the [`matrix!`] and +/// [`dmatrix!`] macros: +/// +/// ```rust +/// # use nalgebra::stack; +/// # +/// # fn main() { +/// # let [a, b, c, d] = std::array::from_fn(|_| nalgebra::Matrix1::new(0)); +/// // a, b, c and d are matrices +/// let block_matrix = stack![ a, b; +/// c, d ]; +/// # } +/// ``` +/// +/// The resulting matrix is stack-allocated if the dimension of each block row and column +/// can be determined at compile-time, otherwise it is heap-allocated. +/// This is the case if, for every row, there is at least one matrix with a fixed number of rows, +/// and, for every column, there is at least one matrix with a fixed number of columns. +/// +/// [`stack!`] also supports special syntax to indicate zero blocks in a matrix: +/// +/// ```rust +/// # use nalgebra::stack; +/// # +/// # fn main() { +/// # let [a, b, c, d] = std::array::from_fn(|_| nalgebra::Matrix1::new(0)); +/// // a and d are matrices +/// let block_matrix = stack![ a, 0; +/// 0, d ]; +/// # } +/// ``` +/// Here, the `0` literal indicates a zero matrix of implicitly defined size. +/// In order to infer the size of the zero blocks, there must be at least one matrix +/// in every row and column of the matrix. +/// In other words, no row or column can consist entirely of implicit zero blocks. +/// +/// # Panics +/// +/// Panics if dimensions are inconsistent and it cannot be determined at compile-time. +/// +/// # Examples +/// +/// ``` +/// use nalgebra::{matrix, SMatrix, stack}; +/// +/// let a = matrix![1, 2; +/// 3, 4]; +/// let b = matrix![5, 6; +/// 7, 8]; +/// let c = matrix![9, 10]; +/// +/// let block_matrix = stack![ a, b; +/// c, 0 ]; +/// +/// assert_eq!(block_matrix, matrix![1, 2, 5, 6; +/// 3, 4, 7, 8; +/// 9, 10, 0, 0]); +/// +/// // Verify that the resulting block matrix is stack-allocated +/// let _: SMatrix<_, 3, 4> = block_matrix; +/// ``` +/// +/// The example above shows how stacking stack-allocated matrices results in a stack-allocated +/// block matrix. If all row and column dimensions can not be determined at compile-time, +/// the result is instead a dynamically allocated matrix: +/// +/// ``` +/// use nalgebra::{dmatrix, DMatrix, Dyn, matrix, OMatrix, SMatrix, stack, U3}; +/// +/// # let a = matrix![1, 2; 3, 4]; let c = matrix![9, 10]; +/// // a and c as before, but b is a dynamic matrix this time +/// let b = dmatrix![5, 6; +/// 7, 8]; +/// +/// // In this case, the number of rows can be statically inferred to be 3 (U3), +/// // but the number of columns cannot, hence it is dynamic +/// let block_matrix: OMatrix<_, U3, Dyn> = stack![ a, b; +/// c, 0 ]; +/// +/// // If necessary, a fully dynamic matrix (DMatrix) can be obtained by reshaping +/// let dyn_block_matrix: DMatrix<_> = block_matrix.reshape_generic(Dyn(3), Dyn(4)); +/// ``` +/// Note that explicitly annotating the types of `block_matrix` and `dyn_block_matrix` is +/// only made for illustrative purposes, and is not generally necessary. +/// +#[proc_macro] +pub fn stack(stream: TokenStream) -> TokenStream { + let matrix = parse_macro_input!(stream as Matrix); + proc_macro::TokenStream::from(stack_impl(matrix).unwrap_or_else(syn::Error::into_compile_error)) +} diff --git a/nalgebra-macros/src/matrix_vector_impl.rs b/nalgebra-macros/src/matrix_vector_impl.rs new file mode 100644 index 00000000..d4357e88 --- /dev/null +++ b/nalgebra-macros/src/matrix_vector_impl.rs @@ -0,0 +1,201 @@ +use proc_macro::TokenStream; +use quote::{quote, ToTokens, TokenStreamExt}; +use std::ops::Index; +use syn::parse::{Error, Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::spanned::Spanned; +use syn::Expr; +use syn::{parse_macro_input, Token}; + +use proc_macro2::{Delimiter, Spacing, TokenStream as TokenStream2, TokenTree}; +use proc_macro2::{Group, Punct}; + +/// A matrix of expressions +pub struct Matrix { + // Represent the matrix data in row-major format + data: Vec, + nrows: usize, + ncols: usize, +} + +impl Index<(usize, usize)> for Matrix { + type Output = Expr; + + fn index(&self, (row, col): (usize, usize)) -> &Self::Output { + let linear_idx = self.ncols * row + col; + &self.data[linear_idx] + } +} + +impl Matrix { + pub fn nrows(&self) -> usize { + self.nrows + } + + pub fn ncols(&self) -> usize { + self.ncols + } + + /// Produces a stream of tokens representing this matrix as a column-major nested array. + pub fn to_col_major_nested_array_tokens(&self) -> TokenStream2 { + let mut result = TokenStream2::new(); + for j in 0..self.ncols() { + let mut col = TokenStream2::new(); + let col_iter = (0..self.nrows()).map(|i| &self[(i, j)]); + col.append_separated(col_iter, Punct::new(',', Spacing::Alone)); + result.append(Group::new(Delimiter::Bracket, col)); + result.append(Punct::new(',', Spacing::Alone)); + } + TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Bracket, result))) + } + + /// Produces a stream of tokens representing this matrix as a column-major flat array + /// (suitable for representing e.g. a `DMatrix`). + pub fn to_col_major_flat_array_tokens(&self) -> TokenStream2 { + let mut data = TokenStream2::new(); + for j in 0..self.ncols() { + for i in 0..self.nrows() { + self[(i, j)].to_tokens(&mut data); + data.append(Punct::new(',', Spacing::Alone)); + } + } + TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Bracket, data))) + } +} + +type MatrixRowSyntax = Punctuated; + +impl Parse for Matrix { + fn parse(input: ParseStream<'_>) -> syn::Result { + let mut data = Vec::new(); + let mut ncols = None; + let mut nrows = 0; + + while !input.is_empty() { + let row = MatrixRowSyntax::parse_separated_nonempty(input)?; + let row_span = row.span(); + + if let Some(ncols) = ncols { + if row.len() != ncols { + let error_msg = format!( + "Unexpected number of entries in row {}. Expected {}, found {} entries.", + nrows, + ncols, + row.len() + ); + return Err(Error::new(row_span, error_msg)); + } + } else { + ncols = Some(row.len()); + } + data.extend(row.into_iter()); + nrows += 1; + + // We've just read a row, so if there are more tokens, there must be a semi-colon, + // otherwise the input is malformed + if !input.is_empty() { + input.parse::()?; + } + } + + Ok(Self { + data, + nrows, + ncols: ncols.unwrap_or(0), + }) + } +} + +pub struct Vector { + elements: Vec, +} + +impl Vector { + pub fn to_array_tokens(&self) -> TokenStream2 { + let mut data = TokenStream2::new(); + data.append_separated(&self.elements, Punct::new(',', Spacing::Alone)); + TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Bracket, data))) + } + + pub fn len(&self) -> usize { + self.elements.len() + } +} + +impl Parse for Vector { + fn parse(input: ParseStream<'_>) -> syn::Result { + // The syntax of a vector is just the syntax of a single matrix row + if input.is_empty() { + Ok(Self { + elements: Vec::new(), + }) + } else { + let elements = MatrixRowSyntax::parse_terminated(input)? + .into_iter() + .collect(); + Ok(Self { elements }) + } + } +} + +pub fn matrix_impl(stream: TokenStream) -> TokenStream { + let matrix = parse_macro_input!(stream as Matrix); + + let row_dim = matrix.nrows(); + let col_dim = matrix.ncols(); + + let array_tokens = matrix.to_col_major_nested_array_tokens(); + + // TODO: Use quote_spanned instead?? + let output = quote! { + nalgebra::SMatrix::<_, #row_dim, #col_dim> + ::from_array_storage(nalgebra::ArrayStorage(#array_tokens)) + }; + + proc_macro::TokenStream::from(output) +} + +pub fn dmatrix_impl(stream: TokenStream) -> TokenStream { + let matrix = parse_macro_input!(stream as Matrix); + + let row_dim = matrix.nrows(); + let col_dim = matrix.ncols(); + + let array_tokens = matrix.to_col_major_flat_array_tokens(); + + // TODO: Use quote_spanned instead?? + let output = quote! { + nalgebra::DMatrix::<_> + ::from_vec_storage(nalgebra::VecStorage::new( + nalgebra::Dyn(#row_dim), + nalgebra::Dyn(#col_dim), + vec!#array_tokens)) + }; + + proc_macro::TokenStream::from(output) +} + +pub fn vector_impl(stream: TokenStream) -> TokenStream { + let vector = parse_macro_input!(stream as Vector); + let len = vector.len(); + let array_tokens = vector.to_array_tokens(); + let output = quote! { + nalgebra::SVector::<_, #len> + ::from_array_storage(nalgebra::ArrayStorage([#array_tokens])) + }; + proc_macro::TokenStream::from(output) +} + +pub fn dvector_impl(stream: TokenStream) -> TokenStream { + let vector = parse_macro_input!(stream as Vector); + let len = vector.len(); + let array_tokens = vector.to_array_tokens(); + let output = quote! { + nalgebra::DVector::<_> + ::from_vec_storage(nalgebra::VecStorage::new( + nalgebra::Dyn(#len), + nalgebra::Const::<1>, + vec!#array_tokens)) + }; + proc_macro::TokenStream::from(output) +} diff --git a/nalgebra-macros/src/stack_impl.rs b/nalgebra-macros/src/stack_impl.rs new file mode 100644 index 00000000..c2c9a397 --- /dev/null +++ b/nalgebra-macros/src/stack_impl.rs @@ -0,0 +1,302 @@ +use crate::Matrix; +use proc_macro2::{Span, TokenStream as TokenStream2}; +use quote::{format_ident, quote, quote_spanned}; +use syn::spanned::Spanned; +use syn::{Error, Expr, Lit}; + +#[allow(clippy::too_many_lines)] +pub fn stack_impl(matrix: Matrix) -> syn::Result { + // The prefix is used to construct variable names + // that are extremely unlikely to collide with variable names used in e.g. expressions + // by the user. Although we could use a long, pseudo-random string, this makes the generated + // code very painful to parse, so we settle for something more semantic that is still + // very unlikely to collide + let prefix = "___na"; + let n_block_rows = matrix.nrows(); + let n_block_cols = matrix.ncols(); + + let mut output = quote! {}; + + // First assign data and shape for each matrix entry to variables + // (this is important so that we, for example, don't evaluate an expression more than once) + for i in 0..n_block_rows { + for j in 0..n_block_cols { + let expr = &matrix[(i, j)]; + if !is_literal_zero(expr) { + let ident_block = format_ident!("{prefix}_stack_{i}_{j}_block"); + let ident_shape = format_ident!("{prefix}_stack_{i}_{j}_shape"); + output.extend(std::iter::once(quote_spanned! {expr.span()=> + let ref #ident_block = #expr; + let #ident_shape = #ident_block.shape_generic(); + })); + } + } + } + + // Determine the number of rows (dimension) in each block row, + // and write out variables that define block row dimensions and offsets into the + // output matrix + for i in 0..n_block_rows { + // The dimension of the block row is the result of trying to unify the row shape of + // all blocks in the block row + let dim = (0 ..n_block_cols) + .filter_map(|j| { + let expr = &matrix[(i, j)]; + if !is_literal_zero(expr) { + let mut ident_shape = format_ident!("{prefix}_stack_{i}_{j}_shape"); + ident_shape.set_span(ident_shape.span().located_at(expr.span())); + Some(quote_spanned!{expr.span()=> #ident_shape.0 }) + } else { + None + } + }).reduce(|a, b| { + let expect_msg = format!("All blocks in block row {i} must have the same number of rows"); + quote_spanned!{b.span()=> + >::representative(#a, #b) + .expect(#expect_msg) + } + }).ok_or(Error::new(Span::call_site(), format!("Block row {i} cannot consist entirely of implicit zero blocks.")))?; + + let dim_ident = format_ident!("{prefix}_stack_row_{i}_dim"); + let offset_ident = format_ident!("{prefix}_stack_row_{i}_offset"); + + let offset = if i == 0 { + quote! { 0 } + } else { + let prev_offset_ident = format_ident!("{prefix}_stack_row_{}_offset", i - 1); + let prev_dim_ident = format_ident!("{prefix}_stack_row_{}_dim", i - 1); + quote! { #prev_offset_ident + <_ as nalgebra::Dim>::value(&#prev_dim_ident) } + }; + + output.extend(std::iter::once(quote! { + let #dim_ident = #dim; + let #offset_ident = #offset; + })); + } + + // Do the same thing for the block columns + for j in 0..n_block_cols { + let dim = (0 ..n_block_rows) + .filter_map(|i| { + let expr = &matrix[(i, j)]; + if !is_literal_zero(expr) { + let mut ident_shape = format_ident!("{prefix}_stack_{i}_{j}_shape"); + ident_shape.set_span(ident_shape.span().located_at(expr.span())); + Some(quote_spanned!{expr.span()=> #ident_shape.1 }) + } else { + None + } + }).reduce(|a, b| { + let expect_msg = format!("All blocks in block column {j} must have the same number of columns"); + quote_spanned!{b.span()=> + >::representative(#a, #b) + .expect(#expect_msg) + } + }).ok_or(Error::new(Span::call_site(), format!("Block column {j} cannot consist entirely of implicit zero blocks.")))?; + + let dim_ident = format_ident!("{prefix}_stack_col_{j}_dim"); + let offset_ident = format_ident!("{prefix}_stack_col_{j}_offset"); + + let offset = if j == 0 { + quote! { 0 } + } else { + let prev_offset_ident = format_ident!("{prefix}_stack_col_{}_offset", j - 1); + let prev_dim_ident = format_ident!("{prefix}_stack_col_{}_dim", j - 1); + quote! { #prev_offset_ident + <_ as nalgebra::Dim>::value(&#prev_dim_ident) } + }; + + output.extend(std::iter::once(quote! { + let #dim_ident = #dim; + let #offset_ident = #offset; + })); + } + + // Determine number of rows and cols in output matrix, + // by adding together dimensions of all block rows/cols + let num_rows = (0..n_block_rows) + .map(|i| { + let ident = format_ident!("{prefix}_stack_row_{i}_dim"); + quote! { #ident } + }) + .reduce(|a, b| { + quote! { + <_ as nalgebra::DimAdd<_>>::add(#a, #b) + } + }) + .unwrap_or(quote! { nalgebra::dimension::U0 }); + + let num_cols = (0..n_block_cols) + .map(|j| { + let ident = format_ident!("{prefix}_stack_col_{j}_dim"); + quote! { #ident } + }) + .reduce(|a, b| { + quote! { + <_ as nalgebra::DimAdd<_>>::add(#a, #b) + } + }) + .unwrap_or(quote! { nalgebra::dimension::U0 }); + + // It should be possible to use `uninitialized_generic` here instead + // however that would mean that the macro needs to generate unsafe code + // which does not seem like a great idea. + output.extend(std::iter::once(quote! { + let mut matrix = nalgebra::Matrix::zeros_generic(#num_rows, #num_cols); + })); + + for i in 0..n_block_rows { + for j in 0..n_block_cols { + let row_dim = format_ident!("{prefix}_stack_row_{i}_dim"); + let col_dim = format_ident!("{prefix}_stack_col_{j}_dim"); + let row_offset = format_ident!("{prefix}_stack_row_{i}_offset"); + let col_offset = format_ident!("{prefix}_stack_col_{j}_offset"); + let expr = &matrix[(i, j)]; + if !is_literal_zero(expr) { + let expr_ident = format_ident!("{prefix}_stack_{i}_{j}_block"); + output.extend(std::iter::once(quote! { + let start = (#row_offset, #col_offset); + let shape = (#row_dim, #col_dim); + let input_view = #expr_ident.generic_view((0, 0), shape); + let mut output_view = matrix.generic_view_mut(start, shape); + output_view.copy_from(&input_view); + })); + } + } + } + + Ok(quote! { + { + #output + matrix + } + }) +} + +fn is_literal_zero(expr: &Expr) -> bool { + matches!(expr, + Expr::Lit(syn::ExprLit { lit: Lit::Int(integer_literal), .. }) + if integer_literal.base10_digits() == "0") +} + +#[cfg(test)] +mod tests { + use crate::stack_impl::stack_impl; + use crate::Matrix; + use quote::quote; + + #[test] + fn stack_simple_generation() { + let input: Matrix = syn::parse_quote![ + a, 0; + 0, b; + ]; + + let result = stack_impl(input).unwrap(); + + let expected = quote! {{ + let ref ___na_stack_0_0_block = a; + let ___na_stack_0_0_shape = ___na_stack_0_0_block.shape_generic(); + let ref ___na_stack_1_1_block = b; + let ___na_stack_1_1_shape = ___na_stack_1_1_block.shape_generic(); + let ___na_stack_row_0_dim = ___na_stack_0_0_shape.0; + let ___na_stack_row_0_offset = 0; + let ___na_stack_row_1_dim = ___na_stack_1_1_shape.0; + let ___na_stack_row_1_offset = ___na_stack_row_0_offset + <_ as nalgebra::Dim>::value(&___na_stack_row_0_dim); + let ___na_stack_col_0_dim = ___na_stack_0_0_shape.1; + let ___na_stack_col_0_offset = 0; + let ___na_stack_col_1_dim = ___na_stack_1_1_shape.1; + let ___na_stack_col_1_offset = ___na_stack_col_0_offset + <_ as nalgebra::Dim>::value(&___na_stack_col_0_dim); + let mut matrix = nalgebra::Matrix::zeros_generic( + <_ as nalgebra::DimAdd<_>>::add(___na_stack_row_0_dim, ___na_stack_row_1_dim), + <_ as nalgebra::DimAdd<_>>::add(___na_stack_col_0_dim, ___na_stack_col_1_dim) + ); + let start = (___na_stack_row_0_offset, ___na_stack_col_0_offset); + let shape = (___na_stack_row_0_dim, ___na_stack_col_0_dim); + let input_view = ___na_stack_0_0_block.generic_view((0,0), shape); + let mut output_view = matrix.generic_view_mut(start, shape); + output_view.copy_from(&input_view); + let start = (___na_stack_row_1_offset, ___na_stack_col_1_offset); + let shape = (___na_stack_row_1_dim, ___na_stack_col_1_dim); + let input_view = ___na_stack_1_1_block.generic_view((0,0), shape); + let mut output_view = matrix.generic_view_mut(start, shape); + output_view.copy_from(&input_view); + matrix + }}; + + assert_eq!(format!("{result}"), format!("{}", expected)); + } + + #[test] + fn stack_complex_generation() { + let input: Matrix = syn::parse_quote![ + a, 0, b; + 0, c, d; + e, 0, 0; + ]; + + let result = stack_impl(input).unwrap(); + + let expected = quote! {{ + let ref ___na_stack_0_0_block = a; + let ___na_stack_0_0_shape = ___na_stack_0_0_block.shape_generic(); + let ref ___na_stack_0_2_block = b; + let ___na_stack_0_2_shape = ___na_stack_0_2_block.shape_generic(); + let ref ___na_stack_1_1_block = c; + let ___na_stack_1_1_shape = ___na_stack_1_1_block.shape_generic(); + let ref ___na_stack_1_2_block = d; + let ___na_stack_1_2_shape = ___na_stack_1_2_block.shape_generic(); + let ref ___na_stack_2_0_block = e; + let ___na_stack_2_0_shape = ___na_stack_2_0_block.shape_generic(); + let ___na_stack_row_0_dim = < nalgebra :: constraint :: ShapeConstraint as nalgebra :: constraint :: SameNumberOfRows < _ , _ >> :: representative (___na_stack_0_0_shape . 0 , ___na_stack_0_2_shape . 0) . expect ("All blocks in block row 0 must have the same number of rows") ; + let ___na_stack_row_0_offset = 0; + let ___na_stack_row_1_dim = < nalgebra :: constraint :: ShapeConstraint as nalgebra :: constraint :: SameNumberOfRows < _ , _ >> :: representative (___na_stack_1_1_shape . 0 , ___na_stack_1_2_shape . 0) . expect ("All blocks in block row 1 must have the same number of rows") ; + let ___na_stack_row_1_offset = ___na_stack_row_0_offset + <_ as nalgebra::Dim>::value(&___na_stack_row_0_dim); + let ___na_stack_row_2_dim = ___na_stack_2_0_shape.0; + let ___na_stack_row_2_offset = ___na_stack_row_1_offset + <_ as nalgebra::Dim>::value(&___na_stack_row_1_dim); + let ___na_stack_col_0_dim = < nalgebra :: constraint :: ShapeConstraint as nalgebra :: constraint :: SameNumberOfColumns < _ , _ >> :: representative (___na_stack_0_0_shape . 1 , ___na_stack_2_0_shape . 1) . expect ("All blocks in block column 0 must have the same number of columns") ; + let ___na_stack_col_0_offset = 0; + let ___na_stack_col_1_dim = ___na_stack_1_1_shape.1; + let ___na_stack_col_1_offset = ___na_stack_col_0_offset + <_ as nalgebra::Dim>::value(&___na_stack_col_0_dim); + let ___na_stack_col_2_dim = < nalgebra :: constraint :: ShapeConstraint as nalgebra :: constraint :: SameNumberOfColumns < _ , _ >> :: representative (___na_stack_0_2_shape . 1 , ___na_stack_1_2_shape . 1) . expect ("All blocks in block column 2 must have the same number of columns") ; + let ___na_stack_col_2_offset = ___na_stack_col_1_offset + <_ as nalgebra::Dim>::value(&___na_stack_col_1_dim); + let mut matrix = nalgebra::Matrix::zeros_generic( + <_ as nalgebra::DimAdd<_>>::add( + <_ as nalgebra::DimAdd<_>>::add(___na_stack_row_0_dim, ___na_stack_row_1_dim), + ___na_stack_row_2_dim + ), + <_ as nalgebra::DimAdd<_>>::add( + <_ as nalgebra::DimAdd<_>>::add(___na_stack_col_0_dim, ___na_stack_col_1_dim), + ___na_stack_col_2_dim + ) + ); + let start = (___na_stack_row_0_offset, ___na_stack_col_0_offset); + let shape = (___na_stack_row_0_dim, ___na_stack_col_0_dim); + let input_view = ___na_stack_0_0_block.generic_view((0,0), shape); + let mut output_view = matrix.generic_view_mut(start, shape); + output_view.copy_from(&input_view); + let start = (___na_stack_row_0_offset, ___na_stack_col_2_offset); + let shape = (___na_stack_row_0_dim, ___na_stack_col_2_dim); + let input_view = ___na_stack_0_2_block.generic_view((0,0), shape); + let mut output_view = matrix.generic_view_mut(start, shape); + output_view.copy_from(&input_view); + let start = (___na_stack_row_1_offset, ___na_stack_col_1_offset); + let shape = (___na_stack_row_1_dim, ___na_stack_col_1_dim); + let input_view = ___na_stack_1_1_block.generic_view((0,0), shape); + let mut output_view = matrix.generic_view_mut(start, shape); + output_view.copy_from(&input_view); + let start = (___na_stack_row_1_offset, ___na_stack_col_2_offset); + let shape = (___na_stack_row_1_dim, ___na_stack_col_2_dim); + let input_view = ___na_stack_1_2_block.generic_view((0,0), shape); + let mut output_view = matrix.generic_view_mut(start, shape); + output_view.copy_from(&input_view); + let start = (___na_stack_row_2_offset, ___na_stack_col_0_offset); + let shape = (___na_stack_row_2_dim, ___na_stack_col_0_dim); + let input_view = ___na_stack_2_0_block.generic_view((0,0), shape); + let mut output_view = matrix.generic_view_mut(start, shape); + output_view.copy_from(&input_view); + matrix + }}; + + assert_eq!(format!("{result}"), format!("{}", expected)); + } +} diff --git a/src/base/constraint.rs b/src/base/constraint.rs index 1960cb54..55a47684 100644 --- a/src/base/constraint.rs +++ b/src/base/constraint.rs @@ -19,6 +19,16 @@ pub trait DimEq { /// This is either equal to `D1` or `D2`, always choosing the one (if any) which is a type-level /// constant. type Representative: Dim; + + /// This constructs a value of type `Representative` with the + /// correct value + fn representative(d1: D1, d2: D2) -> Option { + if d1.value() != d2.value() { + None + } else { + Some(Self::Representative::from_usize(d1.value())) + } + } } impl DimEq for ShapeConstraint { @@ -41,6 +51,13 @@ macro_rules! equality_trait_decl( /// This is either equal to `D1` or `D2`, always choosing the one (if any) which is a type-level /// constant. type Representative: Dim; + + /// Returns a representative dimension instance if the two are equal, + /// otherwise `None`. + fn representative(d1: D1, d2: D2) -> Option<>::Representative> { + >::representative(d1, d2) + .map(|common_dim| >::Representative::from_usize(common_dim.value())) + } } impl $Trait for ShapeConstraint { diff --git a/src/lib.rs b/src/lib.rs index dac09827..4cfa6331 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -155,7 +155,7 @@ pub use crate::sparse::*; pub use base as core; #[cfg(feature = "macros")] -pub use nalgebra_macros::{dmatrix, dvector, matrix, point, vector}; +pub use nalgebra_macros::{dmatrix, dvector, matrix, point, stack, vector}; use simba::scalar::SupersetOf; use std::cmp::{self, Ordering, PartialOrd}; diff --git a/tests/lib.rs b/tests/lib.rs index 546aa8a7..3816013b 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -28,6 +28,9 @@ mod linalg; #[cfg(feature = "proptest-support")] mod proptest; +#[cfg(feature = "macros")] +mod macros; + //#[cfg(all(feature = "debug", feature = "compare", feature = "rand"))] //#[cfg(feature = "sparse")] //mod sparse; diff --git a/nalgebra-macros/tests/tests.rs b/tests/macros/matrix.rs similarity index 96% rename from nalgebra-macros/tests/tests.rs rename to tests/macros/matrix.rs index ed6353d0..096186fc 100644 --- a/nalgebra-macros/tests/tests.rs +++ b/tests/macros/matrix.rs @@ -1,3 +1,4 @@ +use crate::macros::assert_eq_and_type; use nalgebra::{ DMatrix, DVector, Matrix1x2, Matrix1x3, Matrix1x4, Matrix2, Matrix2x1, Matrix2x3, Matrix2x4, Matrix3, Matrix3x1, Matrix3x2, Matrix3x4, Matrix4, Matrix4x1, Matrix4x2, Matrix4x3, Point, @@ -6,16 +7,6 @@ use nalgebra::{ }; use nalgebra_macros::{dmatrix, dvector, matrix, point, vector}; -fn check_statically_same_type(_: &T, _: &T) {} - -/// Wrapper for `assert_eq` that also asserts that the types are the same -macro_rules! assert_eq_and_type { - ($left:expr, $right:expr $(,)?) => { - check_statically_same_type(&$left, &$right); - assert_eq!($left, $right); - }; -} - // Skip rustfmt because it just makes the test bloated without making it more readable #[rustfmt::skip] #[test] @@ -169,7 +160,7 @@ fn matrix_trybuild_tests() { let t = trybuild::TestCases::new(); // Verify error message when we give a matrix with mismatched dimensions - t.compile_fail("tests/trybuild/matrix_mismatched_dimensions.rs"); + t.compile_fail("tests/macros/trybuild/matrix_mismatched_dimensions.rs"); } #[test] @@ -177,7 +168,7 @@ fn dmatrix_trybuild_tests() { let t = trybuild::TestCases::new(); // Verify error message when we give a matrix with mismatched dimensions - t.compile_fail("tests/trybuild/dmatrix_mismatched_dimensions.rs"); + t.compile_fail("tests/macros/trybuild/dmatrix_mismatched_dimensions.rs"); } #[test] @@ -288,7 +279,7 @@ fn dmatrix_arbitrary_expressions() { let a = dmatrix![1 + 2 , 2 * 3; 4 * f(5 + 6), 7 - 8 * 9]; let a_expected = DMatrix::from_row_slice(2, 2, &[1 + 2 , 2 * 3, - 4 * f(5 + 6), 7 - 8 * 9]); + 4 * f(5 + 6), 7 - 8 * 9]); assert_eq_and_type!(a, a_expected); } diff --git a/tests/macros/mod.rs b/tests/macros/mod.rs new file mode 100644 index 00000000..50f6a75a --- /dev/null +++ b/tests/macros/mod.rs @@ -0,0 +1,19 @@ +mod matrix; +mod stack; + +/// Wrapper for `assert_eq` that also asserts that the types are the same +// For some reason, rustfmt totally messes up the formatting of this macro. +// For now we skip, but once https://github.com/rust-lang/rustfmt/issues/6131 +// is fixed, we can perhaps remove the skip attribute +#[rustfmt::skip] +macro_rules! assert_eq_and_type { + ($left:expr, $right:expr $(,)?) => { + { + fn check_statically_same_type(_: &T, _: &T) {} + check_statically_same_type(&$left, &$right); + } + assert_eq!($left, $right); + }; +} + +pub(crate) use assert_eq_and_type; diff --git a/tests/macros/stack.rs b/tests/macros/stack.rs new file mode 100644 index 00000000..7ba3af82 --- /dev/null +++ b/tests/macros/stack.rs @@ -0,0 +1,427 @@ +use crate::macros::assert_eq_and_type; +use cool_asserts::assert_panics; +use na::VecStorage; +use nalgebra::dimension::U1; +use nalgebra::{dmatrix, matrix, stack}; +use nalgebra::{ + DMatrix, DMatrixView, Dyn, Matrix, Matrix2, Matrix4, OMatrix, SMatrix, SMatrixView, + SMatrixViewMut, Scalar, U2, +}; +use nalgebra_macros::vector; +use num_traits::Zero; + +/// Simple implementation that stacks dynamic matrices. +/// +/// Used for verifying results of the stack! macro. `None` entries are considered to represent +/// a zero block. +fn stack_dyn(blocks: DMatrix>>) -> DMatrix { + let row_counts: Vec = blocks + .row_iter() + .map(|block_row| { + block_row + .iter() + .map(|block_or_implicit_zero| { + block_or_implicit_zero.as_ref().map(|block| block.nrows()) + }) + .reduce(|nrows1, nrows2| match (nrows1, nrows2) { + (Some(_), None) => nrows1, + (None, Some(_)) => nrows2, + (None, None) => None, + (Some(nrows1), Some(nrows2)) if nrows1 == nrows2 => Some(nrows1), + _ => panic!("Number of rows must be consistent in each block row"), + }) + .unwrap_or(Some(0)) + .expect("Each block row must have at least one entry which is not a zero literal") + }) + .collect(); + let col_counts: Vec = blocks + .column_iter() + .map(|block_col| { + block_col + .iter() + .map(|block_or_implicit_zero| { + block_or_implicit_zero.as_ref().map(|block| block.ncols()) + }) + .reduce(|ncols1, ncols2| match (ncols1, ncols2) { + (Some(_), None) => ncols1, + (None, Some(_)) => ncols2, + (None, None) => None, + (Some(ncols1), Some(ncols2)) if ncols1 == ncols2 => Some(ncols1), + _ => panic!("Number of columns must be consistent in each block column"), + }) + .unwrap_or(Some(0)) + .expect( + "Each block column must have at least one entry which is not a zero literal", + ) + }) + .collect(); + + let nrows_total = row_counts.iter().sum(); + let ncols_total = col_counts.iter().sum(); + let mut output = DMatrix::zeros(nrows_total, ncols_total); + + let mut col_offset = 0; + for j in 0..blocks.ncols() { + let mut row_offset = 0; + for i in 0..blocks.nrows() { + if let Some(input_ij) = &blocks[(i, j)] { + let (block_nrows, block_ncols) = input_ij.shape(); + output + .view_mut((row_offset, col_offset), (block_nrows, block_ncols)) + .copy_from(&input_ij); + } + row_offset += row_counts[i]; + } + col_offset += col_counts[j]; + } + + output +} + +macro_rules! stack_dyn_convert_to_dmatrix_option { + (0) => { + None + }; + ($entry:expr) => { + Some($entry.as_view::().clone_owned()) + }; +} + +/// Helper macro that compares the result of stack! with a simplified implementation that +/// works only with heap-allocated data. +/// +/// This implementation is essentially radically different to the implementation in stack!, +/// so if they both match, then it's a good sign that the stack! impl is correct. +macro_rules! verify_stack { + ($matrix_type:ty ; [$($($entry:expr),*);*]) => { + { + // Our input has the same syntax as the stack! macro (and matrix! macro, for that matter) + let stack_result: $matrix_type = stack![$($($entry),*);*]; + // Use the dmatrix! macro to nest matrices into each other + let dyn_result = stack_dyn( + dmatrix![$($(stack_dyn_convert_to_dmatrix_option!($entry)),*);*] + ); + // println!("{}", stack_result); + // println!("{}", dyn_result); + assert_eq!(stack_result, dyn_result); + } + } +} + +#[test] +fn stack_simple() { + let m = stack![ + Matrix2::::identity(), 0; + 0, &Matrix2::identity(); + ]; + + assert_eq_and_type!(m, Matrix4::identity()); +} + +#[test] +fn stack_diag() { + let m = stack![ + 0, matrix![1, 2; 3, 4;]; + matrix![5, 6; 7, 8;], 0; + ]; + + let res = matrix![ + 0, 0, 1, 2; + 0, 0, 3, 4; + 5, 6, 0, 0; + 7, 8, 0, 0; + ]; + + assert_eq_and_type!(m, res); +} + +#[test] +fn stack_dynamic() { + let m = stack![ + matrix![ 1, 2; 3, 4; ], 0; + 0, dmatrix![7, 8, 9; 10, 11, 12; ]; + ]; + + let res = dmatrix![ + 1, 2, 0, 0, 0; + 3, 4, 0, 0, 0; + 0, 0, 7, 8, 9; + 0, 0, 10, 11, 12; + ]; + + assert_eq_and_type!(m, res); +} + +#[test] +fn stack_nested() { + let m = stack![ + stack![ matrix![1, 2; 3, 4;]; matrix![5, 6;]], + stack![ matrix![7;9;10;], matrix![11; 12; 13;] ]; + ]; + + let res = matrix![ + 1, 2, 7, 11; + 3, 4, 9, 12; + 5, 6, 10, 13; + ]; + + assert_eq_and_type!(m, res); +} + +#[test] +fn stack_single() { + let a = matrix![1, 2; 3, 4]; + let b = stack![a]; + + assert_eq_and_type!(a, b); +} + +#[test] +fn stack_single_row() { + let a = matrix![1, 2; 3, 4]; + let m = stack![a, a]; + + let res = matrix![ + 1, 2, 1, 2; + 3, 4, 3, 4; + ]; + + assert_eq_and_type!(m, res); +} + +#[test] +fn stack_single_col() { + let a = matrix![1, 2; 3, 4]; + let m = stack![a; a]; + + let res = matrix![ + 1, 2; + 3, 4; + 1, 2; + 3, 4; + ]; + + assert_eq_and_type!(m, res); +} + +#[test] +#[rustfmt::skip] +fn stack_expr() { + let a = matrix![1, 2; 3, 4]; + let b = matrix![5, 6; 7, 8]; + let m = stack![a + b; 2i32 * b - a]; + + let res = matrix![ + 6, 8; + 10, 12; + 9, 10; + 11, 12; + ]; + + assert_eq_and_type!(m, res); +} + +#[test] +fn stack_edge_cases() { + { + // Empty stack should return zero matrix with specified type + let _: SMatrix = stack![]; + let _: SMatrix = stack![]; + } + + { + // Case suggested by @tpdickso: https://github.com/dimforge/nalgebra/pull/1080#discussion_r1435871752 + let a = matrix![1, 2; + 3, 4]; + let b = DMatrix::from_data(VecStorage::new(Dyn(2), Dyn(0), vec![])); + assert_eq!( + stack![a, 0; + 0, b], + matrix![1, 2; + 3, 4; + 0, 0; + 0, 0] + ); + } +} + +#[rustfmt::skip] +#[test] +fn stack_many_tests() { + // s prefix means static, d prefix means dynamic + // Static matrices + let s_0x0: SMatrix = matrix![]; + let s_0x1: SMatrix = Matrix::default(); + let s_1x0: SMatrix = Matrix::default(); + let s_1x1: SMatrix = matrix![1]; + let s_2x2: SMatrix = matrix![6, 7; 8, 9]; + let s_2x3: SMatrix = matrix![16, 17, 18; 19, 20, 21]; + let s_3x3: SMatrix = matrix![28, 29, 30; 31, 32, 33; 34, 35, 36]; + + // Dynamic matrices + let d_0x0: DMatrix = dmatrix![]; + let d_1x2: DMatrix = dmatrix![9, 10]; + let d_2x2: DMatrix = dmatrix![5, 6; 7, 8]; + let d_4x4: DMatrix = dmatrix![10, 11, 12, 13; 14, 15, 16, 17; 18, 19, 20, 21; 22, 23, 24, 25]; + + // Check for weirdness with matrices that have zero row/cols + verify_stack!(SMatrix<_, 0, 0>; [s_0x0]); + verify_stack!(SMatrix<_, 0, 1>; [s_0x1]); + verify_stack!(SMatrix<_, 1, 0>; [s_1x0]); + verify_stack!(SMatrix<_, 0, 0>; [s_0x0; s_0x0]); + verify_stack!(SMatrix<_, 0, 0>; [s_0x0, s_0x0; s_0x0, s_0x0]); + verify_stack!(SMatrix<_, 0, 2>; [s_0x1, s_0x1]); + verify_stack!(SMatrix<_, 2, 0>; [s_1x0; s_1x0]); + verify_stack!(SMatrix<_, 1, 0>; [s_1x0, s_1x0]); + verify_stack!(DMatrix<_>; [d_0x0]); + + // Horizontal stacking + verify_stack!(SMatrix<_, 1, 2>; [s_1x1, s_1x1]); + verify_stack!(SMatrix<_, 2, 4>; [s_2x2, s_2x2]); + verify_stack!(DMatrix<_>; [d_1x2, d_1x2]); + + // Vertical stacking + verify_stack!(SMatrix<_, 2, 1>; [s_1x1; s_1x1]); + verify_stack!(SMatrix<_, 4, 2>; [s_2x2; s_2x2]); + verify_stack!(DMatrix<_>; [d_2x2; d_2x2]); + + // Mix static and dynamic matrices + verify_stack!(OMatrix<_, U2, Dyn>; [s_2x2, d_2x2]); + verify_stack!(OMatrix<_, Dyn, U2>; [s_2x2; d_1x2]); + + // Stack more than two matrices + verify_stack!(SMatrix<_, 1, 3>; [s_1x1, s_1x1, s_1x1]); + verify_stack!(DMatrix<_>; [d_1x2, d_1x2, d_1x2]); + + // Slightly larger dims + verify_stack!(SMatrix<_, 3, 6>; [s_3x3, s_3x3]); + verify_stack!(DMatrix<_>; [d_4x4; d_4x4]); + verify_stack!(SMatrix<_, 4, 7>; [s_2x2, s_2x3, d_2x2; + d_2x2, s_2x3, s_2x2]); + + // Mix of references and owned + verify_stack!(OMatrix<_, Dyn, U2>; [&s_2x2; &d_1x2]); + verify_stack!(SMatrix<_, 4, 7>; [ s_2x2, &s_2x3, d_2x2; + &d_2x2, s_2x3, &s_2x2]); + + // Views + let s_2x2_v: SMatrixView<_, 2, 2> = s_2x2.as_view(); + let s_2x3_v: SMatrixView<_, 2, 3> = s_2x3.as_view(); + let d_2x2_v: DMatrixView<_> = d_2x2.as_view(); + let mut s_2x2_vm = s_2x2.clone(); + let s_2x2_vm: SMatrixViewMut<_, 2, 2> = s_2x2_vm.as_view_mut(); + let mut s_2x3_vm = s_2x3.clone(); + let s_2x3_vm: SMatrixViewMut<_, 2, 3> = s_2x3_vm.as_view_mut(); + verify_stack!(SMatrix<_, 4, 7>; [ s_2x2_vm, &s_2x3_vm, d_2x2_v; + &d_2x2_v, s_2x3_v, &s_2x2_v]); + + // Expressions + let matrix_fn = |matrix: &DMatrix<_>| matrix.map(|x_ij| x_ij * 3); + verify_stack!(SMatrix<_, 2, 5>; [ 2 * s_2x2 - 3 * &d_2x2, s_2x3 + 2 * s_2x3]); + verify_stack!(DMatrix<_>; [ 2 * matrix_fn(&d_2x2) ]); + verify_stack!(SMatrix<_, 2, 5>; [ (|matrix| 4 * matrix)(s_2x2), s_2x3 ]); +} + +#[test] +fn stack_trybuild_tests() { + let t = trybuild::TestCases::new(); + + // Verify error message when a row or column only contains a zero entry + t.compile_fail("tests/macros/trybuild/stack_empty_row.rs"); + t.compile_fail("tests/macros/trybuild/stack_empty_col.rs"); + t.compile_fail("tests/macros/trybuild/stack_incompatible_block_dimensions.rs"); + t.compile_fail("tests/macros/trybuild/stack_incompatible_block_dimensions2.rs"); +} + +#[test] +fn stack_mismatched_dimensions_runtime_panics() { + // s prefix denotes static, d dynamic + let s_2x2 = matrix![1, 2; 3, 4]; + let d_2x3 = dmatrix![5, 6, 7; 8, 9, 10]; + let d_1x2 = dmatrix![11, 12]; + let d_1x3 = dmatrix![13, 14, 15]; + + assert_panics!( + stack![s_2x2, d_1x2], + includes("All blocks in block row 0 must have the same number of rows") + ); + + assert_panics!( + stack![s_2x2; d_2x3], + includes("All blocks in block column 0 must have the same number of columns") + ); + + assert_panics!( + stack![s_2x2, s_2x2; d_1x2, d_2x3], + includes("All blocks in block row 1 must have the same number of rows") + ); + + assert_panics!( + stack![s_2x2, s_2x2; d_1x2, d_1x3], + includes("All blocks in block column 1 must have the same number of columns") + ); + + assert_panics!( + { + // Edge case suggested by @tpdickso: https://github.com/dimforge/nalgebra/pull/1080#discussion_r1435871752 + let d_3x0 = DMatrix::from_data(VecStorage::new(Dyn(3), Dyn(0), Vec::::new())); + stack![s_2x2, d_3x0] + }, + includes("All blocks in block row 0 must have the same number of rows") + ); +} + +#[test] +fn stack_test_builtin_types() { + // Other than T: Zero, there's nothing type-specific in the logic for stack! + // These tests are just sanity tests, to make sure it works with the common built-in types + let a = matrix![1, 2; 3, 4]; + let b = vector![5, 6]; + let c = matrix![7, 8]; + + let expected = matrix![ 1, 2, 5; + 3, 4, 6; + 7, 8, 0 ]; + + macro_rules! check_builtin { + ($T:ty) => {{ + // Cannot use .cast::<$T> because we cannot convert between unsigned and signed + let stacked = stack![a.map(|a_ij| a_ij as $T), b.map(|b_ij| b_ij as $T); + c.map(|c_ij| c_ij as $T), 0]; + assert_eq!(stacked, expected.map(|e_ij| e_ij as $T)); + }} + } + + check_builtin!(i8); + check_builtin!(i16); + check_builtin!(i32); + check_builtin!(i64); + check_builtin!(i128); + check_builtin!(u8); + check_builtin!(u16); + check_builtin!(u32); + check_builtin!(u64); + check_builtin!(u128); + check_builtin!(f32); + check_builtin!(f64); +} + +#[test] +fn stack_test_complex() { + use num_complex::Complex as C; + type C32 = C; + let a = matrix![C::new(1.0, 1.0), C::new(2.0, 2.0); C::new(3.0, 3.0), C::new(4.0, 4.0)]; + let b = vector![C::new(5.0, 5.0), C::new(6.0, 6.0)]; + let c = matrix![C::new(7.0, 7.0), C::new(8.0, 8.0)]; + + let expected = matrix![ 1, 2, 5; + 3, 4, 6; + 7, 8, 0 ] + .map(|x| C::new(x as f64, x as f64)); + + assert_eq!(stack![a, b; c, 0], expected); + assert_eq!( + stack![a.cast::(), b.cast::(); c.cast::(), 0], + expected.cast::() + ); +} diff --git a/nalgebra-macros/tests/trybuild/dmatrix_mismatched_dimensions.rs b/tests/macros/trybuild/dmatrix_mismatched_dimensions.rs similarity index 64% rename from nalgebra-macros/tests/trybuild/dmatrix_mismatched_dimensions.rs rename to tests/macros/trybuild/dmatrix_mismatched_dimensions.rs index 786b5849..89b5a2be 100644 --- a/nalgebra-macros/tests/trybuild/dmatrix_mismatched_dimensions.rs +++ b/tests/macros/trybuild/dmatrix_mismatched_dimensions.rs @@ -1,4 +1,4 @@ -use nalgebra_macros::dmatrix; +use nalgebra::dmatrix; fn main() { dmatrix![1, 2, 3; diff --git a/nalgebra-macros/tests/trybuild/dmatrix_mismatched_dimensions.stderr b/tests/macros/trybuild/dmatrix_mismatched_dimensions.stderr similarity index 64% rename from nalgebra-macros/tests/trybuild/dmatrix_mismatched_dimensions.stderr rename to tests/macros/trybuild/dmatrix_mismatched_dimensions.stderr index eaedc650..adcd7f00 100644 --- a/nalgebra-macros/tests/trybuild/dmatrix_mismatched_dimensions.stderr +++ b/tests/macros/trybuild/dmatrix_mismatched_dimensions.stderr @@ -1,5 +1,5 @@ error: Unexpected number of entries in row 1. Expected 3, found 2 entries. - --> $DIR/dmatrix_mismatched_dimensions.rs:5:13 + --> tests/macros/trybuild/dmatrix_mismatched_dimensions.rs:5:13 | 5 | 4, 5]; | ^ diff --git a/nalgebra-macros/tests/trybuild/matrix_mismatched_dimensions.rs b/tests/macros/trybuild/matrix_mismatched_dimensions.rs similarity index 65% rename from nalgebra-macros/tests/trybuild/matrix_mismatched_dimensions.rs rename to tests/macros/trybuild/matrix_mismatched_dimensions.rs index c5eb87b7..1ce845dc 100644 --- a/nalgebra-macros/tests/trybuild/matrix_mismatched_dimensions.rs +++ b/tests/macros/trybuild/matrix_mismatched_dimensions.rs @@ -1,4 +1,4 @@ -use nalgebra_macros::matrix; +use nalgebra::matrix; fn main() { matrix![1, 2, 3; diff --git a/nalgebra-macros/tests/trybuild/matrix_mismatched_dimensions.stderr b/tests/macros/trybuild/matrix_mismatched_dimensions.stderr similarity index 65% rename from nalgebra-macros/tests/trybuild/matrix_mismatched_dimensions.stderr rename to tests/macros/trybuild/matrix_mismatched_dimensions.stderr index c83e8d0c..87f33f99 100644 --- a/nalgebra-macros/tests/trybuild/matrix_mismatched_dimensions.stderr +++ b/tests/macros/trybuild/matrix_mismatched_dimensions.stderr @@ -1,5 +1,5 @@ error: Unexpected number of entries in row 1. Expected 3, found 2 entries. - --> $DIR/matrix_mismatched_dimensions.rs:5:13 + --> tests/macros/trybuild/matrix_mismatched_dimensions.rs:5:13 | 5 | 4, 5]; | ^ diff --git a/tests/macros/trybuild/stack_empty_col.rs b/tests/macros/trybuild/stack_empty_col.rs new file mode 100644 index 00000000..e743463d --- /dev/null +++ b/tests/macros/trybuild/stack_empty_col.rs @@ -0,0 +1,6 @@ +use nalgebra::{matrix, stack}; + +fn main() { + let m = matrix![1, 2; 3, 4]; + stack![0, m]; +} diff --git a/tests/macros/trybuild/stack_empty_col.stderr b/tests/macros/trybuild/stack_empty_col.stderr new file mode 100644 index 00000000..2ba6de94 --- /dev/null +++ b/tests/macros/trybuild/stack_empty_col.stderr @@ -0,0 +1,7 @@ +error: Block column 0 cannot consist entirely of implicit zero blocks. + --> tests/macros/trybuild/stack_empty_col.rs:5:5 + | +5 | stack![0, m]; + | ^^^^^^^^^^^^ + | + = note: this error originates in the macro `stack` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/tests/macros/trybuild/stack_empty_row.rs b/tests/macros/trybuild/stack_empty_row.rs new file mode 100644 index 00000000..264b2727 --- /dev/null +++ b/tests/macros/trybuild/stack_empty_row.rs @@ -0,0 +1,6 @@ +use nalgebra::{matrix, stack}; + +fn main() { + let m = matrix![1, 2; 3, 4]; + stack![0; m]; +} diff --git a/tests/macros/trybuild/stack_empty_row.stderr b/tests/macros/trybuild/stack_empty_row.stderr new file mode 100644 index 00000000..6a834a9d --- /dev/null +++ b/tests/macros/trybuild/stack_empty_row.stderr @@ -0,0 +1,7 @@ +error: Block row 0 cannot consist entirely of implicit zero blocks. + --> tests/macros/trybuild/stack_empty_row.rs:5:5 + | +5 | stack![0; m]; + | ^^^^^^^^^^^^ + | + = note: this error originates in the macro `stack` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/tests/macros/trybuild/stack_incompatible_block_dimensions.rs b/tests/macros/trybuild/stack_incompatible_block_dimensions.rs new file mode 100644 index 00000000..8a26807c --- /dev/null +++ b/tests/macros/trybuild/stack_incompatible_block_dimensions.rs @@ -0,0 +1,13 @@ +use nalgebra::{matrix, stack}; + +fn main() { + // Use multi-letter names for checking that the reported span comes out correctly + let a11 = matrix![1, 2; + 3, 4]; + let a12 = matrix![5, 6; + 7, 8]; + let a21 = matrix![9, 10, 11]; + let a22 = matrix![12, 13]; + stack![a11, a12; + a21, a22]; +} \ No newline at end of file diff --git a/tests/macros/trybuild/stack_incompatible_block_dimensions.stderr b/tests/macros/trybuild/stack_incompatible_block_dimensions.stderr new file mode 100644 index 00000000..ffef61a8 --- /dev/null +++ b/tests/macros/trybuild/stack_incompatible_block_dimensions.stderr @@ -0,0 +1,37 @@ +error[E0277]: the trait bound `ShapeConstraint: SameNumberOfColumns, Const<3>>` is not satisfied + --> tests/macros/trybuild/stack_incompatible_block_dimensions.rs:12:12 + | +12 | a21, a22]; + | ^^^ the trait `SameNumberOfColumns, Const<3>>` is not implemented for `ShapeConstraint` + | + = help: the following other types implement trait `SameNumberOfColumns`: + > + > + > + = note: this error originates in the macro `stack` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0282]: type annotations needed + --> tests/macros/trybuild/stack_incompatible_block_dimensions.rs:11:5 + | +11 | / stack![a11, a12; +12 | | a21, a22]; + | |____________________^ cannot infer type + | + = note: this error originates in the macro `stack` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0599]: no method named `generic_view_mut` found for struct `Matrix<_, Const<3>, _, _>` in the current scope + --> tests/macros/trybuild/stack_incompatible_block_dimensions.rs:11:5 + | +11 | stack![a11, a12; + | _____^ +12 | | a21, a22]; + | |____________________^ method not found in `Matrix<_, Const<3>, _, _>` + | + ::: src/base/matrix_view.rs + | + | generic_slice_mut => generic_view_mut, + | ---------------- the method is available for `Matrix<_, Const<3>, _, _>` here + | + = note: the method was found for + - `Matrix` + = note: this error originates in the macro `stack` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/tests/macros/trybuild/stack_incompatible_block_dimensions2.rs b/tests/macros/trybuild/stack_incompatible_block_dimensions2.rs new file mode 100644 index 00000000..2d5dafdc --- /dev/null +++ b/tests/macros/trybuild/stack_incompatible_block_dimensions2.rs @@ -0,0 +1,14 @@ +use nalgebra::{matrix, stack}; + +fn main() { + // Use multi-letter names for checking that the reported span comes out correctly + let a11 = matrix![1, 2; + 3, 4]; + let a12 = matrix![5, 6; + 7, 8]; + let a21 = matrix![9, 10]; + let a22 = matrix![11, 12; + 13, 14]; + stack![a11, a12; + a21, a22]; +} \ No newline at end of file diff --git a/tests/macros/trybuild/stack_incompatible_block_dimensions2.stderr b/tests/macros/trybuild/stack_incompatible_block_dimensions2.stderr new file mode 100644 index 00000000..2d52e26b --- /dev/null +++ b/tests/macros/trybuild/stack_incompatible_block_dimensions2.stderr @@ -0,0 +1,37 @@ +error[E0277]: the trait bound `ShapeConstraint: SameNumberOfRows, Const<2>>` is not satisfied + --> tests/macros/trybuild/stack_incompatible_block_dimensions2.rs:13:17 + | +13 | a21, a22]; + | ^^^ the trait `SameNumberOfRows, Const<2>>` is not implemented for `ShapeConstraint` + | + = help: the following other types implement trait `SameNumberOfRows`: + > + > + > + = note: this error originates in the macro `stack` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0282]: type annotations needed + --> tests/macros/trybuild/stack_incompatible_block_dimensions2.rs:12:5 + | +12 | / stack![a11, a12; +13 | | a21, a22]; + | |____________________^ cannot infer type + | + = note: this error originates in the macro `stack` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0599]: no method named `generic_view_mut` found for struct `Matrix<_, _, Const<4>, _>` in the current scope + --> tests/macros/trybuild/stack_incompatible_block_dimensions2.rs:12:5 + | +12 | stack![a11, a12; + | _____^ +13 | | a21, a22]; + | |____________________^ method not found in `Matrix<_, _, Const<4>, _>` + | + ::: src/base/matrix_view.rs + | + | generic_slice_mut => generic_view_mut, + | ---------------- the method is available for `Matrix<_, _, Const<4>, _>` here + | + = note: the method was found for + - `Matrix` + = note: this error originates in the macro `stack` (in Nightly builds, run with -Z macro-backtrace for more info)