From a9b341dfc6610a092d570579b63ab75b72d81e20 Mon Sep 17 00:00:00 2001 From: mwojcik Date: Fri, 6 Sep 2024 16:33:58 +0800 Subject: [PATCH 1/2] artiq: support async rpcs --- nac3artiq/demo/min_artiq.py | 13 +++-- nac3artiq/src/codegen.rs | 83 ++++++++++++++++++++----------- nac3artiq/src/lib.rs | 79 ++++++++++++++++++++++------- nac3core/src/toplevel/composer.rs | 13 ++++- 4 files changed, 136 insertions(+), 52 deletions(-) diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index d471032a..3840a57a 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -112,10 +112,15 @@ def extern(function): register_function(function) return function -def rpc(function): - """Decorates a function declaration defined by the core device runtime.""" - register_function(function) - return function + +def rpc(arg=None, flags={}): + """Decorates a function or method to be executed on the host interpreter.""" + if arg is None: + def inner_decorator(function): + return rpc(function, flags) + return inner_decorator + register_function(arg) + return arg def kernel(function_or_method): """Decorates a function or method to be executed on the core device.""" diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 1646c1c6..4d471050 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -824,6 +824,7 @@ fn rpc_codegen_callback_fn<'ctx>( fun: (&FunSignature, DefinitionId), args: Vec<(Option, ValueEnum<'ctx>)>, generator: &mut dyn CodeGenerator, + is_async: bool, ) -> Result>, String> { let int8 = ctx.ctx.i8_type(); let int32 = ctx.ctx.i32_type(); @@ -831,6 +832,7 @@ fn rpc_codegen_callback_fn<'ctx>( let ptr_type = int8.ptr_type(AddressSpace::default()); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); + let service_id = int32.const_int(fun.1 .0 as u64, false); // -- setup rpc tags let mut tag = Vec::new(); @@ -932,35 +934,60 @@ fn rpc_codegen_callback_fn<'ctx>( } // call - let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| { - ctx.module.add_function( - "rpc_send", - ctx.ctx.void_type().fn_type( - &[ - int32.into(), - tag_ptr_type.ptr_type(AddressSpace::default()).into(), - ptr_type.ptr_type(AddressSpace::default()).into(), - ], - false, - ), - None, - ) - }); - ctx.builder - .build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send") - .unwrap(); + if is_async { + let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| { + ctx.module.add_function( + "rpc_send_async", + ctx.ctx.void_type().fn_type( + &[ + int32.into(), + tag_ptr_type.ptr_type(AddressSpace::default()).into(), + ptr_type.ptr_type(AddressSpace::default()).into(), + ], + false, + ), + None, + ) + }); + ctx.builder + .build_call(rpc_send_async, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send") + .unwrap(); + } else { + let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| { + ctx.module.add_function( + "rpc_send", + ctx.ctx.void_type().fn_type( + &[ + int32.into(), + tag_ptr_type.ptr_type(AddressSpace::default()).into(), + ptr_type.ptr_type(AddressSpace::default()).into(), + ], + false, + ), + None, + ) + }); + ctx.builder + .build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send") + .unwrap(); + } // reclaim stack space used by arguments call_stackrestore(ctx, stackptr); - let result = format_rpc_ret(generator, ctx, fun.0.ret); - - if !result.is_some_and(|res| res.get_type().is_pointer_type()) { - // An RPC returning an NDArray would not touch here. - call_stackrestore(ctx, stackptr); + if is_async { + // async RPCs do not return any values + Ok(None) + } else { + let result = format_rpc_ret(generator, ctx, fun.0.ret); + + if !result.is_some_and(|res| res.get_type().is_pointer_type()) { + // An RPC returning an NDArray would not touch here. + call_stackrestore(ctx, stackptr); + } + + Ok(result) } - - Ok(result) } pub fn attributes_writeback( @@ -1055,7 +1082,7 @@ pub fn attributes_writeback( let args: Vec<_> = values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect(); if let Err(e) = - rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator) + rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator, false) { return Ok(Err(e)); } @@ -1065,9 +1092,9 @@ pub fn attributes_writeback( Ok(()) } -pub fn rpc_codegen_callback() -> Arc { - Arc::new(GenCall::new(Box::new(|ctx, obj, fun, args, generator| { - rpc_codegen_callback_fn(ctx, obj, fun, args, generator) +pub fn rpc_codegen_callback(is_async: bool) -> Arc { + Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| { + rpc_codegen_callback_fn(ctx, obj, fun, args, generator, is_async) }))) } diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 9675efb9..b5e6dab3 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -34,8 +34,12 @@ use nac3core::inkwell::{ targets::*, OptimizationLevel, }; -use nac3core::nac3parser::{ - ast::{ExprKind, Stmt, StmtKind, StrRef}, +use itertools::Itertools; +use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineOptions}; +use nac3core::toplevel::builtins::get_exn_constructor; +use nac3core::typecheck::typedef::{into_var_map, TypeEnum, Unifier, VarMap}; +use nac3parser::{ + ast::{Constant, ExprKind, Located, Stmt, StmtKind, StrRef}, parser::parse_program, }; use nac3core::toplevel::builtins::get_exn_constructor; @@ -135,7 +139,7 @@ struct Nac3 { string_store: Arc>>, exception_ids: Arc>>, deferred_eval_store: DeferredEvaluationStore, - /// LLVM-related options for code generation. + /// LLVM-related options for code generzation. llvm_options: CodeGenLLVMOptions, } @@ -194,10 +198,8 @@ impl Nac3 { body.retain(|stmt| { if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node { decorator_list.iter().any(|decorator| { - if let ExprKind::Name { id, .. } = decorator.node { - id.to_string() == "kernel" - || id.to_string() == "portable" - || id.to_string() == "rpc" + if let Some(id) = decorator_id_string(decorator) { + id == "kernel" || id == "portable" || id == "rpc" } else { false } @@ -210,9 +212,8 @@ impl Nac3 { } StmtKind::FunctionDef { ref decorator_list, .. } => { decorator_list.iter().any(|decorator| { - if let ExprKind::Name { id, .. } = decorator.node { - let id = id.to_string(); - id == "extern" || id == "portable" || id == "kernel" || id == "rpc" + if let Some(id) = decorator_id_string(decorator) { + id == "extern" || id == "kernel" || id == "portable" || id == "rpc" } else { false } @@ -478,9 +479,13 @@ impl Nac3 { match &stmt.node { StmtKind::FunctionDef { decorator_list, .. } => { - if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) { + if decorator_list.iter().any(|decorator| decorator_id_string(decorator) == Some("rpc".to_string())) { store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string().as_str()).unwrap())).unwrap(); - rpc_ids.push((None, def_id)); + let is_async = decorator_list.iter().any( + |decorator| decorator_get_flags(decorator).iter().any( + |constant| *constant == Constant::Str("async".into()) + )); + rpc_ids.push((None, def_id, is_async)); } } StmtKind::ClassDef { name, body, .. } => { @@ -488,14 +493,18 @@ impl Nac3 { let class_obj = module.getattr(py, class_name.as_str()).unwrap(); for stmt in body { if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node { - if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) { + if decorator_list.iter().any(|decorator| decorator_id_string(decorator) == Some("rpc".to_string())) { + let is_async = decorator_list.iter().any( + |decorator| decorator_get_flags(decorator).iter().any( + |constant| *constant == Constant::Str("async".into()) + )); if name == &"__init__".into() { return Err(CompileError::new_err(format!( "compilation failed\n----------\nThe constructor of class {} should not be decorated with rpc decorator (at {})", class_name, stmt.location ))); } - rpc_ids.push((Some((class_obj.clone(), *name)), def_id)); + rpc_ids.push((Some((class_obj.clone(), *name)), def_id, is_async)); } } } @@ -596,13 +605,12 @@ impl Nac3 { let top_level = Arc::new(composer.make_top_level_context()); { - let rpc_codegen = rpc_codegen_callback(); let defs = top_level.definitions.read(); - for (class_data, id) in &rpc_ids { + for (class_data, id, is_async) in &rpc_ids { let mut def = defs[id.0].write(); match &mut *def { TopLevelDef::Function { codegen_callback, .. } => { - *codegen_callback = Some(rpc_codegen.clone()); + *codegen_callback = Some(rpc_codegen_callback(*is_async)); } TopLevelDef::Class { methods, .. } => { let (class_def, method_name) = class_data.as_ref().unwrap(); @@ -613,7 +621,7 @@ impl Nac3 { if let TopLevelDef::Function { codegen_callback, .. } = &mut *defs[id.0].write() { - *codegen_callback = Some(rpc_codegen.clone()); + *codegen_callback = Some(rpc_codegen_callback(*is_async)); store_fun .call1( py, @@ -844,6 +852,41 @@ impl Nac3 { } } +fn decorator_id_string(decorator: &Located) -> Option { + /// Retrieves the Name.id from a decorator, supports decorators with arguments. + if let ExprKind::Name { id, .. } = decorator.node { + // Bare decorator + return Some(id.to_string()); + } else if let ExprKind::Call { func, .. } = &decorator.node { + // Decorators that are calls (e.g. "@rpc()") have Call for the node, + // need to extract the id from within. + if let ExprKind::Name { id, .. } = func.node { + return Some(id.to_string()); + } + } + None +} + +fn decorator_get_flags(decorator: &Located) -> Vec { + /// Retrieves flags from a decorator, if any. + let mut flags = vec![]; + if let ExprKind::Call { keywords, .. } = &decorator.node { + for keyword in keywords.iter() { + if keyword.node.arg != Some("flags".into()) { + continue; + } + if let ExprKind::Set { elts } = &keyword.node.value.node { + for elt in elts { + if let ExprKind::Constant { value, .. } = &elt.node { + flags.push(value.clone()); + } + } + } + } + } + flags +} + fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> { let linker_args = vec![ "-shared".to_string(), diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 603a508e..98eeebe3 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1673,7 +1673,7 @@ impl TopLevelComposer { // they may be changed with our use of placeholders for (def, _) in definition_ast_list.iter().skip(self.builtin_num) { if let TopLevelDef::Function { signature, var_id, .. } = &mut *def.write() { - if let TypeEnum::TFunc(FunSignature { args, ret, vars }) = + if let TypeEnum::TFunc(FunSignature { args, ret, vars}) = unifier.get_ty(*signature).as_ref() { let new_var_ids = vars @@ -1894,7 +1894,7 @@ impl TopLevelComposer { } = &mut *function_def { let signature_ty_enum = unifier.get_ty(*signature); - let TypeEnum::TFunc(FunSignature { args, ret, vars }) = signature_ty_enum.as_ref() + let TypeEnum::TFunc(FunSignature { args, ret, vars, .. }) = signature_ty_enum.as_ref() else { unreachable!("must be typeenum::tfunc") }; @@ -2057,6 +2057,15 @@ impl TopLevelComposer { instance_to_symbol.insert(String::new(), simple_name.to_string()); continue; } + if !decorator_list.is_empty() { + if let ast::ExprKind::Call { func, .. } = &decorator_list[0].node { + if matches!(&func.node, + ast::ExprKind::Name{ id, .. } if id == &"rpc".into()) { + instance_to_symbol.insert(String::new(), simple_name.to_string()); + continue; + } + } + } let fun_body = body .into_iter() -- 2.44.2 From 5a5656e983e4e5398824a1e4072b53de0d56fb63 Mon Sep 17 00:00:00 2001 From: mwojcik Date: Thu, 12 Sep 2024 16:31:01 +0800 Subject: [PATCH 2/2] cargo fmt --- nac3artiq/src/codegen.rs | 11 ++++--- nac3artiq/src/lib.rs | 53 +++++++++++++++++++------------ nac3core/src/toplevel/composer.rs | 16 ++++++---- 3 files changed, 48 insertions(+), 32 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 4d471050..62dbda11 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -832,7 +832,6 @@ fn rpc_codegen_callback_fn<'ctx>( let ptr_type = int8.ptr_type(AddressSpace::default()); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); - let service_id = int32.const_int(fun.1 .0 as u64, false); // -- setup rpc tags let mut tag = Vec::new(); @@ -950,7 +949,11 @@ fn rpc_codegen_callback_fn<'ctx>( ) }); ctx.builder - .build_call(rpc_send_async, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send") + .build_call( + rpc_send_async, + &[service_id.into(), tag_ptr.into(), args_ptr.into()], + "rpc.send", + ) .unwrap(); } else { let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| { @@ -980,12 +983,12 @@ fn rpc_codegen_callback_fn<'ctx>( Ok(None) } else { let result = format_rpc_ret(generator, ctx, fun.0.ret); - + if !result.is_some_and(|res| res.get_type().is_pointer_type()) { // An RPC returning an NDArray would not touch here. call_stackrestore(ctx, stackptr); } - + Ok(result) } } diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index b5e6dab3..c9009ea3 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -34,16 +34,12 @@ use nac3core::inkwell::{ targets::*, OptimizationLevel, }; -use itertools::Itertools; -use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineOptions}; use nac3core::toplevel::builtins::get_exn_constructor; use nac3core::typecheck::typedef::{into_var_map, TypeEnum, Unifier, VarMap}; -use nac3parser::{ +use nac3core::nac3parser::{ ast::{Constant, ExprKind, Located, Stmt, StmtKind, StrRef}, parser::parse_program, }; -use nac3core::toplevel::builtins::get_exn_constructor; -use nac3core::typecheck::typedef::{into_var_map, TypeEnum, Unifier, VarMap}; use pyo3::create_exception; use pyo3::prelude::*; use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet}; @@ -139,7 +135,7 @@ struct Nac3 { string_store: Arc>>, exception_ids: Arc>>, deferred_eval_store: DeferredEvaluationStore, - /// LLVM-related options for code generzation. + /// LLVM-related options for code generation. llvm_options: CodeGenLLVMOptions, } @@ -479,12 +475,24 @@ impl Nac3 { match &stmt.node { StmtKind::FunctionDef { decorator_list, .. } => { - if decorator_list.iter().any(|decorator| decorator_id_string(decorator) == Some("rpc".to_string())) { - store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string().as_str()).unwrap())).unwrap(); - let is_async = decorator_list.iter().any( - |decorator| decorator_get_flags(decorator).iter().any( - |constant| *constant == Constant::Str("async".into()) - )); + if decorator_list + .iter() + .any(|decorator| decorator_id_string(decorator) == Some("rpc".to_string())) + { + store_fun + .call1( + py, + ( + def_id.0.into_py(py), + module.getattr(py, name.to_string().as_str()).unwrap(), + ), + ) + .unwrap(); + let is_async = decorator_list.iter().any(|decorator| { + decorator_get_flags(decorator) + .iter() + .any(|constant| *constant == Constant::Str("async".into())) + }); rpc_ids.push((None, def_id, is_async)); } } @@ -493,11 +501,14 @@ impl Nac3 { let class_obj = module.getattr(py, class_name.as_str()).unwrap(); for stmt in body { if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node { - if decorator_list.iter().any(|decorator| decorator_id_string(decorator) == Some("rpc".to_string())) { - let is_async = decorator_list.iter().any( - |decorator| decorator_get_flags(decorator).iter().any( - |constant| *constant == Constant::Str("async".into()) - )); + if decorator_list.iter().any(|decorator| { + decorator_id_string(decorator) == Some("rpc".to_string()) + }) { + let is_async = decorator_list.iter().any(|decorator| { + decorator_get_flags(decorator) + .iter() + .any(|constant| *constant == Constant::Str("async".into())) + }); if name == &"__init__".into() { return Err(CompileError::new_err(format!( "compilation failed\n----------\nThe constructor of class {} should not be decorated with rpc decorator (at {})", @@ -509,7 +520,7 @@ impl Nac3 { } } } - _ => () + _ => (), } let id = *name_to_pyid.get(&name).unwrap(); @@ -852,8 +863,8 @@ impl Nac3 { } } +/// Retrieves the Name.id from a decorator, supports decorators with arguments. fn decorator_id_string(decorator: &Located) -> Option { - /// Retrieves the Name.id from a decorator, supports decorators with arguments. if let ExprKind::Name { id, .. } = decorator.node { // Bare decorator return Some(id.to_string()); @@ -867,11 +878,11 @@ fn decorator_id_string(decorator: &Located) -> Option { None } +/// Retrieves flags from a decorator, if any. fn decorator_get_flags(decorator: &Located) -> Vec { - /// Retrieves flags from a decorator, if any. let mut flags = vec![]; if let ExprKind::Call { keywords, .. } = &decorator.node { - for keyword in keywords.iter() { + for keyword in keywords { if keyword.node.arg != Some("flags".into()) { continue; } diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 98eeebe3..125e137a 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1673,7 +1673,7 @@ impl TopLevelComposer { // they may be changed with our use of placeholders for (def, _) in definition_ast_list.iter().skip(self.builtin_num) { if let TopLevelDef::Function { signature, var_id, .. } = &mut *def.write() { - if let TypeEnum::TFunc(FunSignature { args, ret, vars}) = + if let TypeEnum::TFunc(FunSignature { args, ret, vars }) = unifier.get_ty(*signature).as_ref() { let new_var_ids = vars @@ -1894,7 +1894,8 @@ impl TopLevelComposer { } = &mut *function_def { let signature_ty_enum = unifier.get_ty(*signature); - let TypeEnum::TFunc(FunSignature { args, ret, vars, .. }) = signature_ty_enum.as_ref() + let TypeEnum::TFunc(FunSignature { args, ret, vars, .. }) = + signature_ty_enum.as_ref() else { unreachable!("must be typeenum::tfunc") }; @@ -2059,11 +2060,12 @@ impl TopLevelComposer { } if !decorator_list.is_empty() { if let ast::ExprKind::Call { func, .. } = &decorator_list[0].node { - if matches!(&func.node, - ast::ExprKind::Name{ id, .. } if id == &"rpc".into()) { - instance_to_symbol.insert(String::new(), simple_name.to_string()); - continue; - } + if matches!(&func.node, + ast::ExprKind::Name{ id, .. } if id == &"rpc".into()) + { + instance_to_symbol.insert(String::new(), simple_name.to_string()); + continue; + } } } -- 2.44.2