forked from M-Labs/nac3
artiq: support async rpcs
Co-authored-by: mwojcik <mw@m-labs.hk> Co-committed-by: mwojcik <mw@m-labs.hk>
This commit is contained in:
parent
5e2e77a500
commit
f2c047ba57
@ -112,10 +112,15 @@ def extern(function):
|
|||||||
register_function(function)
|
register_function(function)
|
||||||
return function
|
return function
|
||||||
|
|
||||||
def rpc(function):
|
|
||||||
"""Decorates a function declaration defined by the core device runtime."""
|
def rpc(arg=None, flags={}):
|
||||||
register_function(function)
|
"""Decorates a function or method to be executed on the host interpreter."""
|
||||||
return function
|
if arg is None:
|
||||||
|
def inner_decorator(function):
|
||||||
|
return rpc(function, flags)
|
||||||
|
return inner_decorator
|
||||||
|
register_function(arg)
|
||||||
|
return arg
|
||||||
|
|
||||||
def kernel(function_or_method):
|
def kernel(function_or_method):
|
||||||
"""Decorates a function or method to be executed on the core device."""
|
"""Decorates a function or method to be executed on the core device."""
|
||||||
|
@ -824,6 +824,7 @@ fn rpc_codegen_callback_fn<'ctx>(
|
|||||||
fun: (&FunSignature, DefinitionId),
|
fun: (&FunSignature, DefinitionId),
|
||||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
|
is_async: bool,
|
||||||
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
||||||
let int8 = ctx.ctx.i8_type();
|
let int8 = ctx.ctx.i8_type();
|
||||||
let int32 = ctx.ctx.i32_type();
|
let int32 = ctx.ctx.i32_type();
|
||||||
@ -932,35 +933,64 @@ fn rpc_codegen_callback_fn<'ctx>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// call
|
// call
|
||||||
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
|
if is_async {
|
||||||
ctx.module.add_function(
|
let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| {
|
||||||
"rpc_send",
|
ctx.module.add_function(
|
||||||
ctx.ctx.void_type().fn_type(
|
"rpc_send_async",
|
||||||
&[
|
ctx.ctx.void_type().fn_type(
|
||||||
int32.into(),
|
&[
|
||||||
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
|
int32.into(),
|
||||||
ptr_type.ptr_type(AddressSpace::default()).into(),
|
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
|
||||||
],
|
ptr_type.ptr_type(AddressSpace::default()).into(),
|
||||||
false,
|
],
|
||||||
),
|
false,
|
||||||
None,
|
),
|
||||||
)
|
None,
|
||||||
});
|
)
|
||||||
ctx.builder
|
});
|
||||||
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
|
ctx.builder
|
||||||
.unwrap();
|
.build_call(
|
||||||
|
rpc_send_async,
|
||||||
|
&[service_id.into(), tag_ptr.into(), args_ptr.into()],
|
||||||
|
"rpc.send",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
} else {
|
||||||
|
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
|
||||||
|
ctx.module.add_function(
|
||||||
|
"rpc_send",
|
||||||
|
ctx.ctx.void_type().fn_type(
|
||||||
|
&[
|
||||||
|
int32.into(),
|
||||||
|
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
|
||||||
|
ptr_type.ptr_type(AddressSpace::default()).into(),
|
||||||
|
],
|
||||||
|
false,
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
ctx.builder
|
||||||
|
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
// reclaim stack space used by arguments
|
// reclaim stack space used by arguments
|
||||||
call_stackrestore(ctx, stackptr);
|
call_stackrestore(ctx, stackptr);
|
||||||
|
|
||||||
let result = format_rpc_ret(generator, ctx, fun.0.ret);
|
if is_async {
|
||||||
|
// async RPCs do not return any values
|
||||||
|
Ok(None)
|
||||||
|
} else {
|
||||||
|
let result = format_rpc_ret(generator, ctx, fun.0.ret);
|
||||||
|
|
||||||
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
|
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
|
||||||
// An RPC returning an NDArray would not touch here.
|
// An RPC returning an NDArray would not touch here.
|
||||||
call_stackrestore(ctx, stackptr);
|
call_stackrestore(ctx, stackptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn attributes_writeback(
|
pub fn attributes_writeback(
|
||||||
@ -1055,7 +1085,7 @@ pub fn attributes_writeback(
|
|||||||
let args: Vec<_> =
|
let args: Vec<_> =
|
||||||
values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
|
values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
|
||||||
if let Err(e) =
|
if let Err(e) =
|
||||||
rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator)
|
rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator, false)
|
||||||
{
|
{
|
||||||
return Ok(Err(e));
|
return Ok(Err(e));
|
||||||
}
|
}
|
||||||
@ -1065,9 +1095,9 @@ pub fn attributes_writeback(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn rpc_codegen_callback() -> Arc<GenCall> {
|
pub fn rpc_codegen_callback(is_async: bool) -> Arc<GenCall> {
|
||||||
Arc::new(GenCall::new(Box::new(|ctx, obj, fun, args, generator| {
|
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
|
||||||
rpc_codegen_callback_fn(ctx, obj, fun, args, generator)
|
rpc_codegen_callback_fn(ctx, obj, fun, args, generator, is_async)
|
||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,12 +34,12 @@ use nac3core::inkwell::{
|
|||||||
targets::*,
|
targets::*,
|
||||||
OptimizationLevel,
|
OptimizationLevel,
|
||||||
};
|
};
|
||||||
use nac3core::nac3parser::{
|
|
||||||
ast::{ExprKind, Stmt, StmtKind, StrRef},
|
|
||||||
parser::parse_program,
|
|
||||||
};
|
|
||||||
use nac3core::toplevel::builtins::get_exn_constructor;
|
use nac3core::toplevel::builtins::get_exn_constructor;
|
||||||
use nac3core::typecheck::typedef::{into_var_map, TypeEnum, Unifier, VarMap};
|
use nac3core::typecheck::typedef::{into_var_map, TypeEnum, Unifier, VarMap};
|
||||||
|
use nac3core::nac3parser::{
|
||||||
|
ast::{Constant, ExprKind, Located, Stmt, StmtKind, StrRef},
|
||||||
|
parser::parse_program,
|
||||||
|
};
|
||||||
use pyo3::create_exception;
|
use pyo3::create_exception;
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
|
use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
|
||||||
@ -194,10 +194,8 @@ impl Nac3 {
|
|||||||
body.retain(|stmt| {
|
body.retain(|stmt| {
|
||||||
if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node {
|
if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node {
|
||||||
decorator_list.iter().any(|decorator| {
|
decorator_list.iter().any(|decorator| {
|
||||||
if let ExprKind::Name { id, .. } = decorator.node {
|
if let Some(id) = decorator_id_string(decorator) {
|
||||||
id.to_string() == "kernel"
|
id == "kernel" || id == "portable" || id == "rpc"
|
||||||
|| id.to_string() == "portable"
|
|
||||||
|| id.to_string() == "rpc"
|
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
@ -210,9 +208,8 @@ impl Nac3 {
|
|||||||
}
|
}
|
||||||
StmtKind::FunctionDef { ref decorator_list, .. } => {
|
StmtKind::FunctionDef { ref decorator_list, .. } => {
|
||||||
decorator_list.iter().any(|decorator| {
|
decorator_list.iter().any(|decorator| {
|
||||||
if let ExprKind::Name { id, .. } = decorator.node {
|
if let Some(id) = decorator_id_string(decorator) {
|
||||||
let id = id.to_string();
|
id == "extern" || id == "kernel" || id == "portable" || id == "rpc"
|
||||||
id == "extern" || id == "portable" || id == "kernel" || id == "rpc"
|
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
@ -478,9 +475,25 @@ impl Nac3 {
|
|||||||
|
|
||||||
match &stmt.node {
|
match &stmt.node {
|
||||||
StmtKind::FunctionDef { decorator_list, .. } => {
|
StmtKind::FunctionDef { decorator_list, .. } => {
|
||||||
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
|
if decorator_list
|
||||||
store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string().as_str()).unwrap())).unwrap();
|
.iter()
|
||||||
rpc_ids.push((None, def_id));
|
.any(|decorator| decorator_id_string(decorator) == Some("rpc".to_string()))
|
||||||
|
{
|
||||||
|
store_fun
|
||||||
|
.call1(
|
||||||
|
py,
|
||||||
|
(
|
||||||
|
def_id.0.into_py(py),
|
||||||
|
module.getattr(py, name.to_string().as_str()).unwrap(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let is_async = decorator_list.iter().any(|decorator| {
|
||||||
|
decorator_get_flags(decorator)
|
||||||
|
.iter()
|
||||||
|
.any(|constant| *constant == Constant::Str("async".into()))
|
||||||
|
});
|
||||||
|
rpc_ids.push((None, def_id, is_async));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
StmtKind::ClassDef { name, body, .. } => {
|
StmtKind::ClassDef { name, body, .. } => {
|
||||||
@ -488,19 +501,26 @@ impl Nac3 {
|
|||||||
let class_obj = module.getattr(py, class_name.as_str()).unwrap();
|
let class_obj = module.getattr(py, class_name.as_str()).unwrap();
|
||||||
for stmt in body {
|
for stmt in body {
|
||||||
if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node {
|
if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node {
|
||||||
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
|
if decorator_list.iter().any(|decorator| {
|
||||||
|
decorator_id_string(decorator) == Some("rpc".to_string())
|
||||||
|
}) {
|
||||||
|
let is_async = decorator_list.iter().any(|decorator| {
|
||||||
|
decorator_get_flags(decorator)
|
||||||
|
.iter()
|
||||||
|
.any(|constant| *constant == Constant::Str("async".into()))
|
||||||
|
});
|
||||||
if name == &"__init__".into() {
|
if name == &"__init__".into() {
|
||||||
return Err(CompileError::new_err(format!(
|
return Err(CompileError::new_err(format!(
|
||||||
"compilation failed\n----------\nThe constructor of class {} should not be decorated with rpc decorator (at {})",
|
"compilation failed\n----------\nThe constructor of class {} should not be decorated with rpc decorator (at {})",
|
||||||
class_name, stmt.location
|
class_name, stmt.location
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
rpc_ids.push((Some((class_obj.clone(), *name)), def_id));
|
rpc_ids.push((Some((class_obj.clone(), *name)), def_id, is_async));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => ()
|
_ => (),
|
||||||
}
|
}
|
||||||
|
|
||||||
let id = *name_to_pyid.get(&name).unwrap();
|
let id = *name_to_pyid.get(&name).unwrap();
|
||||||
@ -596,13 +616,12 @@ impl Nac3 {
|
|||||||
let top_level = Arc::new(composer.make_top_level_context());
|
let top_level = Arc::new(composer.make_top_level_context());
|
||||||
|
|
||||||
{
|
{
|
||||||
let rpc_codegen = rpc_codegen_callback();
|
|
||||||
let defs = top_level.definitions.read();
|
let defs = top_level.definitions.read();
|
||||||
for (class_data, id) in &rpc_ids {
|
for (class_data, id, is_async) in &rpc_ids {
|
||||||
let mut def = defs[id.0].write();
|
let mut def = defs[id.0].write();
|
||||||
match &mut *def {
|
match &mut *def {
|
||||||
TopLevelDef::Function { codegen_callback, .. } => {
|
TopLevelDef::Function { codegen_callback, .. } => {
|
||||||
*codegen_callback = Some(rpc_codegen.clone());
|
*codegen_callback = Some(rpc_codegen_callback(*is_async));
|
||||||
}
|
}
|
||||||
TopLevelDef::Class { methods, .. } => {
|
TopLevelDef::Class { methods, .. } => {
|
||||||
let (class_def, method_name) = class_data.as_ref().unwrap();
|
let (class_def, method_name) = class_data.as_ref().unwrap();
|
||||||
@ -613,7 +632,7 @@ impl Nac3 {
|
|||||||
if let TopLevelDef::Function { codegen_callback, .. } =
|
if let TopLevelDef::Function { codegen_callback, .. } =
|
||||||
&mut *defs[id.0].write()
|
&mut *defs[id.0].write()
|
||||||
{
|
{
|
||||||
*codegen_callback = Some(rpc_codegen.clone());
|
*codegen_callback = Some(rpc_codegen_callback(*is_async));
|
||||||
store_fun
|
store_fun
|
||||||
.call1(
|
.call1(
|
||||||
py,
|
py,
|
||||||
@ -844,6 +863,41 @@ impl Nac3 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Retrieves the Name.id from a decorator, supports decorators with arguments.
|
||||||
|
fn decorator_id_string(decorator: &Located<ExprKind>) -> Option<String> {
|
||||||
|
if let ExprKind::Name { id, .. } = decorator.node {
|
||||||
|
// Bare decorator
|
||||||
|
return Some(id.to_string());
|
||||||
|
} else if let ExprKind::Call { func, .. } = &decorator.node {
|
||||||
|
// Decorators that are calls (e.g. "@rpc()") have Call for the node,
|
||||||
|
// need to extract the id from within.
|
||||||
|
if let ExprKind::Name { id, .. } = func.node {
|
||||||
|
return Some(id.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retrieves flags from a decorator, if any.
|
||||||
|
fn decorator_get_flags(decorator: &Located<ExprKind>) -> Vec<Constant> {
|
||||||
|
let mut flags = vec![];
|
||||||
|
if let ExprKind::Call { keywords, .. } = &decorator.node {
|
||||||
|
for keyword in keywords {
|
||||||
|
if keyword.node.arg != Some("flags".into()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if let ExprKind::Set { elts } = &keyword.node.value.node {
|
||||||
|
for elt in elts {
|
||||||
|
if let ExprKind::Constant { value, .. } = &elt.node {
|
||||||
|
flags.push(value.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
flags
|
||||||
|
}
|
||||||
|
|
||||||
fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> {
|
fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> {
|
||||||
let linker_args = vec![
|
let linker_args = vec![
|
||||||
"-shared".to_string(),
|
"-shared".to_string(),
|
||||||
|
@ -1894,7 +1894,8 @@ impl TopLevelComposer {
|
|||||||
} = &mut *function_def
|
} = &mut *function_def
|
||||||
{
|
{
|
||||||
let signature_ty_enum = unifier.get_ty(*signature);
|
let signature_ty_enum = unifier.get_ty(*signature);
|
||||||
let TypeEnum::TFunc(FunSignature { args, ret, vars }) = signature_ty_enum.as_ref()
|
let TypeEnum::TFunc(FunSignature { args, ret, vars, .. }) =
|
||||||
|
signature_ty_enum.as_ref()
|
||||||
else {
|
else {
|
||||||
unreachable!("must be typeenum::tfunc")
|
unreachable!("must be typeenum::tfunc")
|
||||||
};
|
};
|
||||||
@ -2057,6 +2058,16 @@ impl TopLevelComposer {
|
|||||||
instance_to_symbol.insert(String::new(), simple_name.to_string());
|
instance_to_symbol.insert(String::new(), simple_name.to_string());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
if !decorator_list.is_empty() {
|
||||||
|
if let ast::ExprKind::Call { func, .. } = &decorator_list[0].node {
|
||||||
|
if matches!(&func.node,
|
||||||
|
ast::ExprKind::Name{ id, .. } if id == &"rpc".into())
|
||||||
|
{
|
||||||
|
instance_to_symbol.insert(String::new(), simple_name.to_string());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let fun_body = body
|
let fun_body = body
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
Loading…
Reference in New Issue
Block a user