forked from M-Labs/nalgebra
Initial impl using syn and quote
This commit is contained in:
parent
e97692255b
commit
ab95cf7020
@ -10,7 +10,10 @@ proc-macro = true
|
|||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
syn = "1.0"
|
# TODO: Determine minimal features that we need
|
||||||
|
syn = { version="1.0", features = ["full"] }
|
||||||
|
quote = "1.0"
|
||||||
|
proc-macro2 = "1.0"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
nalgebra = { version = "0.25.4", path = ".." }
|
nalgebra = { version = "0.25.4", path = ".." }
|
||||||
|
@ -1,99 +1,85 @@
|
|||||||
extern crate proc_macro;
|
extern crate proc_macro;
|
||||||
|
|
||||||
use proc_macro::{TokenStream, TokenTree, Literal, Ident, Punct, Spacing, Group, Delimiter};
|
use syn::{Expr};
|
||||||
use std::iter::FromIterator;
|
use syn::parse::{Parse, ParseStream, Result, Error};
|
||||||
|
use syn::punctuated::{Punctuated};
|
||||||
|
use syn::{parse_macro_input, Token};
|
||||||
|
use quote::{quote, format_ident};
|
||||||
|
use proc_macro::TokenStream;
|
||||||
|
|
||||||
struct MatrixEntries {
|
struct Matrix {
|
||||||
entries: Vec<Vec<TokenTree>>
|
// Represent the matrix as a row-major vector of vectors of expressions
|
||||||
|
rows: Vec<Vec<Expr>>,
|
||||||
|
ncols: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MatrixEntries {
|
impl Matrix {
|
||||||
fn new() -> Self {
|
fn nrows(&self) -> usize {
|
||||||
Self {
|
self.rows.len()
|
||||||
entries: Vec::new()
|
}
|
||||||
|
|
||||||
|
fn ncols(&self) -> usize {
|
||||||
|
self.ncols
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_col_major_repr(&self) -> Vec<Expr> {
|
||||||
|
let mut data = Vec::with_capacity(self.nrows() * self.ncols());
|
||||||
|
for j in 0 .. self.ncols() {
|
||||||
|
for i in 0 .. self.nrows() {
|
||||||
|
data.push(self.rows[i][j].clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
data
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn begin_new_row(&mut self) {
|
type MatrixRowSyntax = Punctuated<Expr, Token![,]>;
|
||||||
self.entries.push(Vec::new());
|
type MatrixSyntax = Punctuated<MatrixRowSyntax, Token![;]>;
|
||||||
}
|
|
||||||
|
|
||||||
fn push_entry(&mut self, entry: TokenTree) {
|
impl Parse for Matrix {
|
||||||
if self.entries.is_empty() {
|
fn parse(input: ParseStream) -> Result<Self> {
|
||||||
self.entries.push(Vec::new());
|
let span = input.span();
|
||||||
}
|
// TODO: Handle empty matrix case
|
||||||
let mut last_row = self.entries.last_mut().unwrap();
|
let ast = MatrixSyntax::parse_separated_nonempty_with(input,
|
||||||
last_row.push(entry);
|
|input| MatrixRowSyntax::parse_separated_nonempty(input))?;
|
||||||
}
|
let ncols = ast.first().map(|row| row.len())
|
||||||
|
|
||||||
fn build_stream(&self) -> TokenStream {
|
|
||||||
let num_rows = self.entries.len();
|
|
||||||
let num_cols = self.entries.first()
|
|
||||||
.map(|first_row| first_row.len())
|
|
||||||
.unwrap_or(0);
|
.unwrap_or(0);
|
||||||
|
|
||||||
// First check that dimensions are consistent
|
let mut rows = Vec::new();
|
||||||
for (i, row) in self.entries.iter().enumerate() {
|
|
||||||
if row.len() != num_cols {
|
for row in ast {
|
||||||
panic!("Unexpected number of columns in row {}: {}. Expected {}", i, row.len(), num_cols);
|
if row.len() != ncols {
|
||||||
|
// TODO: Is this the correct span?
|
||||||
|
// Currently it returns the span corresponding to the first element in the macro
|
||||||
|
// invocation, but it would be nice if it returned the span of the first element
|
||||||
|
// in the first row that has an unexpected number of columns
|
||||||
|
return Err(Error::new(span, "Unexpected number of columns. TODO"))
|
||||||
}
|
}
|
||||||
|
rows.push(row.into_iter().collect());
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut array_tokens = Vec::new();
|
Ok(Self {
|
||||||
|
rows,
|
||||||
// Collect entries in column major order
|
ncols
|
||||||
for i in 0 .. num_rows {
|
})
|
||||||
for j in 0 .. num_cols {
|
|
||||||
let entry = &self.entries[i][j];
|
|
||||||
array_tokens.push(entry.clone());
|
|
||||||
array_tokens.push(TokenTree::Punct(Punct::new(',', Spacing::Alone)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let row_dim = format!("U{}", num_rows);
|
|
||||||
let col_dim = format!("U{}", num_cols);
|
|
||||||
// let imports = format!("use nalgebra::\{Matrix, {}, {}\};", row_dim, col_dim);
|
|
||||||
// let constructor = format!("Matrix::<_, {}, {}>::from_slice", row_dim, col_dim);
|
|
||||||
// let array_group = Group::new(Delimiter::Bracket, TokenStream::from_iter(array_tokens.into_iter()));
|
|
||||||
|
|
||||||
let array_stream = TokenStream::from_iter(array_tokens);
|
|
||||||
|
|
||||||
// TODO: Build this up without parsing?
|
|
||||||
format!(r"{{
|
|
||||||
nalgebra::MatrixMN::<_, nalgebra::{row_dim}, nalgebra::{col_dim}>::from_column_slice(&[
|
|
||||||
{array_tokens}
|
|
||||||
])
|
|
||||||
}}", row_dim=row_dim, col_dim=col_dim, array_tokens=array_stream.to_string()).parse().unwrap()
|
|
||||||
|
|
||||||
|
|
||||||
// let mut outer_group = Group::new(Delimiter::Brace,
|
|
||||||
//
|
|
||||||
// );
|
|
||||||
|
|
||||||
|
|
||||||
// TODO: Outer group
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// todo!()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[proc_macro]
|
#[proc_macro]
|
||||||
pub fn matrix(stream: TokenStream) -> TokenStream {
|
pub fn matrix(stream: TokenStream) -> TokenStream {
|
||||||
let mut entries = MatrixEntries::new();
|
let matrix = parse_macro_input!(stream as Matrix);
|
||||||
for tree in stream {
|
|
||||||
match tree {
|
|
||||||
// TokenTree::Ident(ident) => entries.push_entry(tree),
|
|
||||||
// TokenTree::Literal(literal) => entries.push_entry(tree),
|
|
||||||
TokenTree::Punct(punct) if punct == ';' => entries.begin_new_row(),
|
|
||||||
TokenTree::Punct(punct) if punct == ',' => {},
|
|
||||||
// TokenTree::Punct(punct) => panic!("Unexpected punctuation: '{}'", punct),
|
|
||||||
// TokenTree::Group(_) => panic!("Unexpected token group"),
|
|
||||||
_ => entries.push_entry(tree)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
entries.build_stream()
|
let dim_ident = |dim| format_ident!("U{}", dim);
|
||||||
|
let row_dim = dim_ident(matrix.nrows());
|
||||||
|
let col_dim = dim_ident(matrix.ncols());
|
||||||
|
let entries_col_major = matrix.to_col_major_repr();
|
||||||
|
|
||||||
|
// TODO: Use quote_spanned instead??
|
||||||
|
// TODO: Construct directly from array?
|
||||||
|
let output = quote! {
|
||||||
|
nalgebra::MatrixMN::<_, nalgebra::dimension::#row_dim, nalgebra::dimension::#col_dim>
|
||||||
|
::from_column_slice(&[#(#entries_col_major),*])
|
||||||
|
};
|
||||||
|
|
||||||
|
proc_macro::TokenStream::from(output)
|
||||||
}
|
}
|
@ -2,5 +2,7 @@ use nalgebra_macros::matrix;
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn basic_usage() {
|
fn basic_usage() {
|
||||||
matrix![ 1, 3; 4, 5*3];
|
matrix![ 1, 3;
|
||||||
|
4, 5*3;
|
||||||
|
3, 3];
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user