nalgebra/nalgebra-macros/src/lib.rs

133 lines
4.1 KiB
Rust
Raw Normal View History

extern crate proc_macro;
2021-04-11 17:08:41 +08:00
use syn::{Expr};
use syn::parse::{Parse, ParseStream, Result, Error};
use syn::punctuated::{Punctuated};
use syn::{parse_macro_input, Token};
2021-04-29 23:06:52 +08:00
use quote::{quote, TokenStreamExt, ToTokens};
2021-04-11 17:08:41 +08:00
use proc_macro::TokenStream;
use proc_macro2::{TokenStream as TokenStream2, Delimiter, TokenTree, Spacing};
use proc_macro2::{Group, Punct};
2021-04-11 17:08:41 +08:00
/// A matrix literal parsed from macro input, e.g. `1, 2; 3, 4`
/// (entries separated by `,`, rows separated by `;`).
struct Matrix {
    /// Represent the matrix as a row-major vector of vectors of expressions:
    /// `rows[i][j]` is the entry in row `i`, column `j`.
    rows: Vec<Vec<Expr>>,
    /// Number of entries in each row. The parser rejects rows whose length differs
    /// from the first row's; this is 0 when the input contained no rows at all.
    ncols: usize,
}
2021-04-11 17:08:41 +08:00
impl Matrix {
fn nrows(&self) -> usize {
self.rows.len()
}
2021-04-11 17:08:41 +08:00
fn ncols(&self) -> usize {
self.ncols
}
2021-04-29 23:06:52 +08:00
/// Produces a stream of tokens representing this matrix as a column-major nested array.
fn to_col_major_nested_array_tokens(&self) -> TokenStream2 {
let mut result = TokenStream2::new();
2021-04-11 17:08:41 +08:00
for j in 0 .. self.ncols() {
let mut col = TokenStream2::new();
let col_iter = (0 .. self.nrows())
.map(move |i| &self.rows[i][j]);
col.append_separated(col_iter, Punct::new(',', Spacing::Alone));
result.append(Group::new(Delimiter::Bracket, col));
result.append(Punct::new(',', Spacing::Alone));
}
TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Bracket, result)))
}
2021-04-29 23:06:52 +08:00
/// Produces a stream of tokens representing this matrix as a column-major flat array
/// (suitable for representing e.g. a `DMatrix`).
fn to_col_major_flat_array_tokens(&self) -> TokenStream2 {
let mut data = TokenStream2::new();
for j in 0 .. self.ncols() {
for i in 0 .. self.nrows() {
self.rows[i][j].to_tokens(&mut data);
data.append(Punct::new(',', Spacing::Alone));
}
}
TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Bracket, data)))
}
2021-04-11 17:08:41 +08:00
}
2021-04-11 17:08:41 +08:00
/// Syntax of a single matrix row: expressions separated by commas.
type MatrixRowSyntax = Punctuated<Expr, Token![,]>;
2021-04-11 17:08:41 +08:00
impl Parse for Matrix {
fn parse(input: ParseStream) -> Result<Self> {
let mut rows = Vec::new();
2021-04-11 23:29:33 +08:00
let mut ncols = None;
while !input.is_empty() {
let row_span = input.span();
let row = MatrixRowSyntax::parse_separated_nonempty(input)?;
2021-04-11 23:29:33 +08:00
if let Some(ncols) = ncols {
if row.len() != ncols {
let row_idx = rows.len();
let error_msg = format!(
"Unexpected number of entries in row {}. Expected {}, found {} entries.",
row_idx,
ncols,
row.len());
return Err(Error::new(row_span, error_msg));
}
} else {
ncols = Some(row.len());
}
2021-04-11 17:08:41 +08:00
rows.push(row.into_iter().collect());
2021-04-11 23:29:33 +08:00
// We've just read a row, so if there are more tokens, there must be a semi-colon,
// otherwise the input is malformed
if !input.is_empty() {
input.parse::<Token![;]>()?;
}
}
2021-04-11 17:08:41 +08:00
Ok(Self {
rows,
2021-04-11 23:29:33 +08:00
ncols: ncols.unwrap_or(0)
2021-04-11 17:08:41 +08:00
})
}
}
#[proc_macro]
pub fn matrix(stream: TokenStream) -> TokenStream {
2021-04-11 17:08:41 +08:00
let matrix = parse_macro_input!(stream as Matrix);
let row_dim = matrix.nrows();
let col_dim = matrix.ncols();
2021-04-29 23:06:52 +08:00
let array_tokens = matrix.to_col_major_nested_array_tokens();
2021-04-11 17:08:41 +08:00
// TODO: Use quote_spanned instead??
let output = quote! {
nalgebra::SMatrix::<_, #row_dim, #col_dim>
::from_array_storage(nalgebra::ArrayStorage(#array_tokens))
2021-04-11 17:08:41 +08:00
};
2021-04-29 23:06:52 +08:00
proc_macro::TokenStream::from(output)
}
/// Constructs a dynamically-sized `nalgebra::DMatrix` from a matrix literal such as
/// `dmatrix![1, 2; 3, 4]`.
///
/// Accepts the same syntax as `matrix!`: `,` between entries, `;` between rows, and every
/// row of equal length. The entries are collected column-major into a `vec!` that backs an
/// `nalgebra::VecStorage` with the parsed row/column counts.
#[proc_macro]
pub fn dmatrix(stream: TokenStream) -> TokenStream {
    let parsed = parse_macro_input!(stream as Matrix);
    let nrows = parsed.nrows();
    let ncols = parsed.ncols();
    // `[e0, e1, ...,]` in column-major order; prefixed with `vec!` below.
    let flat_entries = parsed.to_col_major_flat_array_tokens();

    // TODO: Use quote_spanned instead??
    let expanded = quote! {
        nalgebra::DMatrix::<_>
            ::from_vec_storage(nalgebra::VecStorage::new(
                nalgebra::Dynamic::new(#nrows),
                nalgebra::Dynamic::new(#ncols),
                vec!#flat_entries))
    };

    expanded.into()
}