From 5c843022c2691664ec20447f406c56c9d9776944 Mon Sep 17 00:00:00 2001 From: Andreas Longva Date: Thu, 29 Apr 2021 17:06:52 +0200 Subject: [PATCH] Implement dmatrix![] macro --- nalgebra-macros/src/lib.rs | 42 ++++++++++++++++++++++++++++++---- nalgebra-macros/tests/tests.rs | 39 ++++++++++++++++++++++++++++--- 2 files changed, 74 insertions(+), 7 deletions(-) diff --git a/nalgebra-macros/src/lib.rs b/nalgebra-macros/src/lib.rs index a18162ab..c8b9b421 100644 --- a/nalgebra-macros/src/lib.rs +++ b/nalgebra-macros/src/lib.rs @@ -4,7 +4,7 @@ use syn::{Expr}; use syn::parse::{Parse, ParseStream, Result, Error}; use syn::punctuated::{Punctuated}; use syn::{parse_macro_input, Token}; -use quote::{quote, TokenStreamExt}; +use quote::{quote, TokenStreamExt, ToTokens}; use proc_macro::TokenStream; use proc_macro2::{TokenStream as TokenStream2, Delimiter, TokenTree, Spacing}; @@ -25,8 +25,8 @@ impl Matrix { self.ncols } - /// Produces a stream of tokens representing this matrix as a column-major array. - fn col_major_array_tokens(&self) -> TokenStream2 { + /// Produces a stream of tokens representing this matrix as a column-major nested array. + fn to_col_major_nested_array_tokens(&self) -> TokenStream2 { let mut result = TokenStream2::new(); for j in 0 .. self.ncols() { let mut col = TokenStream2::new(); @@ -38,6 +38,19 @@ impl Matrix { } TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Bracket, result))) } + + /// Produces a stream of tokens representing this matrix as a column-major flat array + /// (suitable for representing e.g. a `DMatrix`). + fn to_col_major_flat_array_tokens(&self) -> TokenStream2 { + let mut data = TokenStream2::new(); + for j in 0 .. self.ncols() { + for i in 0 .. self.nrows() { + self.rows[i][j].to_tokens(&mut data); + data.append(Punct::new(',', Spacing::Alone)); + } + } + TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Bracket, data))) + } } type MatrixRowSyntax = Punctuated; @@ -87,7 +100,7 @@ pub fn matrix(stream: TokenStream) -> TokenStream { let row_dim = matrix.nrows(); let col_dim = matrix.ncols(); - let array_tokens = matrix.col_major_array_tokens(); + let array_tokens = matrix.to_col_major_nested_array_tokens(); // TODO: Use quote_spanned instead?? let output = quote! { @@ -95,5 +108,26 @@ pub fn matrix(stream: TokenStream) -> TokenStream { ::from_array_storage(nalgebra::ArrayStorage(#array_tokens)) }; + proc_macro::TokenStream::from(output) +} + +#[proc_macro] +pub fn dmatrix(stream: TokenStream) -> TokenStream { + let matrix = parse_macro_input!(stream as Matrix); + + let row_dim = matrix.nrows(); + let col_dim = matrix.ncols(); + + let array_tokens = matrix.to_col_major_flat_array_tokens(); + + // TODO: Use quote_spanned instead?? + let output = quote! { + nalgebra::DMatrix::<_> + ::from_vec_storage(nalgebra::VecStorage::new( + nalgebra::Dynamic::new(#row_dim), + nalgebra::Dynamic::new(#col_dim), + vec!#array_tokens)) + }; + proc_macro::TokenStream::from(output) } \ No newline at end of file diff --git a/nalgebra-macros/tests/tests.rs b/nalgebra-macros/tests/tests.rs index 0b594b31..8610d07e 100644 --- a/nalgebra-macros/tests/tests.rs +++ b/nalgebra-macros/tests/tests.rs @@ -1,5 +1,5 @@ -use nalgebra_macros::matrix; -use nalgebra::{SMatrix, Matrix3x2, Matrix1x2, Matrix1x3, Matrix1x4, Matrix2x1, Matrix2, Matrix2x3, Matrix2x4, Matrix3x1, Matrix3, Matrix3x4, Matrix4x1, Matrix4x2, Matrix4x3, Matrix4}; +use nalgebra_macros::{dmatrix, matrix}; +use nalgebra::{DMatrix, SMatrix, Matrix3x2, Matrix1x2, Matrix1x3, Matrix1x4, Matrix2x1, Matrix2, Matrix2x3, Matrix2x4, Matrix3x1, Matrix3, Matrix3x4, Matrix4x1, Matrix4x2, Matrix4x3, Matrix4}; #[test] fn matrix_small_dims_exhaustive() { @@ -40,4 +40,37 @@ fn matrix_const_fn() { const _: SMatrix = matrix![]; const _: SMatrix = matrix![1, 2]; const _: SMatrix = matrix![1, 2, 3; 4, 5, 6]; -} \ No newline at end of file +} + +#[test] +fn dmatrix_small_dims_exhaustive() { + // 0x0 + assert_eq!(dmatrix![], DMatrix::::zeros(0, 0)); + + // // 1xN + assert_eq!(dmatrix![1], SMatrix::::new(1)); + assert_eq!(dmatrix![1, 2], Matrix1x2::new(1, 2)); + assert_eq!(dmatrix![1, 2, 3], Matrix1x3::new(1, 2, 3)); + assert_eq!(dmatrix![1, 2, 3, 4], Matrix1x4::new(1, 2, 3, 4)); + + // 2xN + assert_eq!(dmatrix![1; 2], Matrix2x1::new(1, 2)); + assert_eq!(dmatrix![1, 2; 3, 4], Matrix2::new(1, 2, 3, 4)); + assert_eq!(dmatrix![1, 2, 3; 4, 5, 6], Matrix2x3::new(1, 2, 3, 4, 5, 6)); + assert_eq!(dmatrix![1, 2, 3, 4; 5, 6, 7, 8], Matrix2x4::new(1, 2, 3, 4, 5, 6, 7, 8)); + + // 3xN + assert_eq!(dmatrix![1; 2; 3], Matrix3x1::new(1, 2, 3)); + assert_eq!(dmatrix![1, 2; 3, 4; 5, 6], Matrix3x2::new(1, 2, 3, 4, 5, 6)); + assert_eq!(dmatrix![1, 2, 3; 4, 5, 6; 7, 8, 9], Matrix3::new(1, 2, 3, 4, 5, 6, 7, 8, 9)); + assert_eq!(dmatrix![1, 2, 3, 4; 5, 6, 7, 8; 9, 10, 11, 12], + Matrix3x4::new(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)); + + // 4xN + assert_eq!(dmatrix![1; 2; 3; 4], Matrix4x1::new(1, 2, 3, 4)); + assert_eq!(dmatrix![1, 2; 3, 4; 5, 6; 7, 8], Matrix4x2::new(1, 2, 3, 4, 5, 6, 7, 8)); + assert_eq!(dmatrix![1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12], + Matrix4x3::new(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)); + assert_eq!(dmatrix![1, 2, 3, 4; 5, 6, 7, 8; 9, 10, 11, 12; 13, 14, 15, 16], + Matrix4::new(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)); +}