diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index c22990a1..4db1258e 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -37,7 +37,7 @@ use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineO use nac3core::toplevel::builtins::get_exn_constructor; use nac3core::typecheck::typedef::{into_var_map, TypeEnum, Unifier, VarMap}; use nac3parser::{ - ast::{ExprKind, Stmt, StmtKind, StrRef}, + ast::{ExprKind, Located, Stmt, StmtKind, StrRef}, parser::parse_program, }; use pyo3::create_exception; @@ -135,7 +135,7 @@ struct Nac3 { string_store: Arc>>, exception_ids: Arc>>, deferred_eval_store: DeferredEvaluationStore, - /// LLVM-related options for code generation. + /// LLVM-related options for code generzation. llvm_options: CodeGenLLVMOptions, } @@ -191,30 +191,12 @@ impl Nac3 { }) .unwrap() }); - body.retain_mut(|stmt| { - if let StmtKind::FunctionDef { ref mut decorator_list, .. } = stmt.node { - decorator_list.iter_mut().any(|decorator| { - if let ExprKind::Name { id, .. } = decorator.node { - id.to_string() == "kernel" - || id.to_string() == "portable" - || id.to_string() == "rpc" - } else if let ExprKind::Call { func, .. } = &decorator.node { - // decorators with flags (e.g. rpc async) have Call for the node; - // this is to remove the middle part - if let ExprKind::Name { id, .. } = func.node { - if id.to_string() == "rpc" { - println!("found rpc: {:?}", func); - println!("decorator node: {:?}", decorator.node); - decorator.node = func.clone().node; - true + body.retain(|stmt| { + if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node { + decorator_list.iter().any(|decorator| { + if let Some(id) = decorator_id_string(decorator) { + id == "kernel" || id == "portable" || id == "rpc" } else { - false - } - } else { - false - } - } - else { false } }) @@ -226,9 +208,8 @@ impl Nac3 { } StmtKind::FunctionDef { ref decorator_list, .. } => { decorator_list.iter().any(|decorator| { - if let ExprKind::Name { id, .. } = decorator.node { - let id = id.to_string(); - id == "extern" || id == "portable" || id == "kernel" || id == "rpc" + if let Some(id) = decorator_id_string(decorator) { + id == "extern" || id == "kernel" || id == "portable" || id == "rpc" } else { false } @@ -862,6 +843,19 @@ impl Nac3 { } } +fn decorator_id_string(decorator: &Located) -> Option { + if let ExprKind::Name { id, .. } = decorator.node { + return Some(id.to_string()); + } else if let ExprKind::Call { func, .. } = &decorator.node { + // decorators with flags (e.g. rpc async) have Call for the node, + // extract the id from within + if let ExprKind::Name { id, .. } = func.node { + return Some(id.to_string()); + } + } + None +} + fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> { let linker_args = vec![ "-shared".to_string(),