diff --git a/nalgebra-macros/src/lib.rs b/nalgebra-macros/src/lib.rs index 154f4876..06f6a2cc 100644 --- a/nalgebra-macros/src/lib.rs +++ b/nalgebra-macros/src/lib.rs @@ -4,9 +4,12 @@ use syn::{Expr}; use syn::parse::{Parse, ParseStream, Result, Error}; use syn::punctuated::{Punctuated}; use syn::{parse_macro_input, Token}; -use quote::{quote, format_ident}; +use quote::{quote, TokenStreamExt}; use proc_macro::TokenStream; +use proc_macro2::{TokenStream as TokenStream2, Delimiter, TokenTree, Spacing}; +use proc_macro2::{Group, Punct}; + struct Matrix { // Represent the matrix as a row-major vector of vectors of expressions rows: Vec>, @@ -22,14 +25,18 @@ impl Matrix { self.ncols } - fn to_col_major_repr(&self) -> Vec { - let mut data = Vec::with_capacity(self.nrows() * self.ncols()); + /// Produces a stream of tokens representing this matrix as a column-major array. + fn col_major_array_tokens(&self) -> TokenStream2 { + let mut result = TokenStream2::new(); for j in 0 .. self.ncols() { - for i in 0 .. self.nrows() { - data.push(self.rows[i][j].clone()); - } + let mut col = TokenStream2::new(); + let col_iter = (0 .. self.nrows()) + .map(move |i| &self.rows[i][j]); + col.append_separated(col_iter, Punct::new(',', Spacing::Alone)); + result.append(Group::new(Delimiter::Bracket, col)); + result.append(Punct::new(',', Spacing::Alone)); } - data + TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Bracket, result))) } } @@ -79,13 +86,16 @@ pub fn matrix(stream: TokenStream) -> TokenStream { let row_dim = matrix.nrows(); let col_dim = matrix.ncols(); - let entries_col_major = matrix.to_col_major_repr(); + + let array_tokens = matrix.col_major_array_tokens(); // TODO: Use quote_spanned instead?? - // TODO: Construct directly from array? + // TODO: Avoid use of unsafe here let output = quote! { + unsafe { nalgebra::SMatrix::<_, #row_dim, #col_dim> - ::from_column_slice(&[#(#entries_col_major),*]) + ::from_data_statically_unchecked(nalgebra::ArrayStorage(#array_tokens)) + } }; proc_macro::TokenStream::from(output) diff --git a/nalgebra-macros/tests/tests.rs b/nalgebra-macros/tests/tests.rs index 5d5a91b0..1113a79d 100644 --- a/nalgebra-macros/tests/tests.rs +++ b/nalgebra-macros/tests/tests.rs @@ -1,5 +1,5 @@ use nalgebra_macros::matrix; -use nalgebra::{SMatrix, Matrix3x2, U0, U1, Matrix1x2, Matrix1x3, Matrix1x4, Matrix2x1, Matrix2, Matrix2x3, Matrix2x4, Matrix3x1, Matrix3, Matrix3x4, Matrix4x1, Matrix4x2, Matrix4x3, Matrix4}; +use nalgebra::{SMatrix, Matrix3x2, Matrix1x2, Matrix1x3, Matrix1x4, Matrix2x1, Matrix2, Matrix2x3, Matrix2x4, Matrix3x1, Matrix3, Matrix3x4, Matrix4x1, Matrix4x2, Matrix4x3, Matrix4}; #[test] fn matrix_small_dims_exhaustive() {