1
0
forked from M-Labs/nac3

rpc: support async funcoption

This commit is contained in:
mwojcik 2024-09-11 12:23:01 +08:00
parent 219af79017
commit 4c1e4b8fba
3 changed files with 59 additions and 34 deletions

View File

@ -12,7 +12,7 @@ use nac3core::{
}, },
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, GenCall}, toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, GenCall},
typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap}, typecheck::typedef::{iter_type_vars, FunSignature, FuncOption, FuncArg, Type, TypeEnum, VarMap},
}; };
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}; use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
@ -831,8 +831,6 @@ fn rpc_codegen_callback_fn<'ctx>(
let ptr_type = int8.ptr_type(AddressSpace::default()); let ptr_type = int8.ptr_type(AddressSpace::default());
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
// println!("obj: {:?}", obj);
println!("fun: {:?}", fun);
let service_id = int32.const_int(fun.1 .0 as u64, false); let service_id = int32.const_int(fun.1 .0 as u64, false);
// -- setup rpc tags // -- setup rpc tags
@ -934,7 +932,28 @@ fn rpc_codegen_callback_fn<'ctx>(
ctx.builder.build_store(arg_ptr, arg_slot).unwrap(); ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
} }
let is_async = fun.0.opts.iter().any(|opt| FuncOption::Async == *opt);
// call // call
if is_async {
let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| {
ctx.module.add_function(
"rpc_send_async",
ctx.ctx.void_type().fn_type(
&[
int32.into(),
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
ptr_type.ptr_type(AddressSpace::default()).into(),
],
false,
),
None,
)
});
ctx.builder
.build_call(rpc_send_async, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
.unwrap();
} else {
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| { let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
ctx.module.add_function( ctx.module.add_function(
"rpc_send", "rpc_send",
@ -952,10 +971,15 @@ fn rpc_codegen_callback_fn<'ctx>(
ctx.builder ctx.builder
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send") .build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
.unwrap(); .unwrap();
}
// reclaim stack space used by arguments // reclaim stack space used by arguments
call_stackrestore(ctx, stackptr); call_stackrestore(ctx, stackptr);
if is_async {
// async RPCs do not return any values
Ok(None)
} else {
let result = format_rpc_ret(generator, ctx, fun.0.ret); let result = format_rpc_ret(generator, ctx, fun.0.ret);
if !result.is_some_and(|res| res.get_type().is_pointer_type()) { if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
@ -964,6 +988,7 @@ fn rpc_codegen_callback_fn<'ctx>(
} }
Ok(result) Ok(result)
}
} }
pub fn attributes_writeback( pub fn attributes_writeback(

View File

@ -54,7 +54,7 @@ use nac3core::{
composer::{BuiltinFuncCreator, BuiltinFuncSpec, ComposerConfig, TopLevelComposer}, composer::{BuiltinFuncCreator, BuiltinFuncSpec, ComposerConfig, TopLevelComposer},
DefinitionId, GenCall, TopLevelDef, DefinitionId, GenCall, TopLevelDef,
}, },
typecheck::typedef::{FunSignature, FuncArg}, typecheck::typedef::{FunSignature, FuncOption, FuncArg},
typecheck::{type_inferencer::PrimitiveStore, typedef::Type}, typecheck::{type_inferencer::PrimitiveStore, typedef::Type},
}; };
@ -200,7 +200,7 @@ impl Nac3 {
|| id.to_string() == "rpc" || id.to_string() == "rpc"
} else if let ExprKind::Call { func, .. } = &decorator.node { } else if let ExprKind::Call { func, .. } = &decorator.node {
// decorators with flags (e.g. rpc async) have Call for the node; // decorators with flags (e.g. rpc async) have Call for the node;
// this is to remove the middle // this is to remove the middle part
if let ExprKind::Name { id, .. } = func.node { if let ExprKind::Name { id, .. } = func.node {
if id.to_string() == "rpc" { if id.to_string() == "rpc" {
println!("found rpc: {:?}", func); println!("found rpc: {:?}", func);
@ -513,7 +513,7 @@ impl Nac3 {
class_name, stmt.location class_name, stmt.location
))); )));
} }
rpc_ids.push((Some((class_obj.clone(), *name)), def_id, Some(FuncFlags::Async))); rpc_ids.push((Some((class_obj.clone(), *name)), def_id));
} }
} }
} }
@ -616,7 +616,7 @@ impl Nac3 {
{ {
let rpc_codegen = rpc_codegen_callback(); let rpc_codegen = rpc_codegen_callback();
let defs = top_level.definitions.read(); let defs = top_level.definitions.read();
for (class_data, id, flags) in &rpc_ids { for (class_data, id) in &rpc_ids {
let mut def = defs[id.0].write(); let mut def = defs[id.0].write();
match &mut *def { match &mut *def {
TopLevelDef::Function { codegen_callback, .. } => { TopLevelDef::Function { codegen_callback, .. } => {

View File

@ -123,8 +123,8 @@ impl FuncArg {
} }
} }
#[derive(Debug, Clone)] #[derive(PartialEq, Debug, Clone)]
pub enum FunOption { pub enum FuncOption {
Async, Async,
} }
@ -133,7 +133,7 @@ pub struct FunSignature {
pub args: Vec<FuncArg>, pub args: Vec<FuncArg>,
pub ret: Type, pub ret: Type,
pub vars: VarMap, pub vars: VarMap,
pub opts: Vec<FunOption>, pub opts: Vec<FuncOption>,
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]