1
0
forked from M-Labs/nac3

nac3artiq: RPC support

This commit is contained in:
pca006132 2022-02-12 21:17:37 +08:00
parent e303248261
commit bf52e294ee
6 changed files with 577 additions and 81 deletions

View File

@ -0,0 +1,57 @@
class EmbeddingMap:
def __init__(self):
self.object_inverse_map = {}
self.object_map = {}
self.string_map = {}
self.string_reverse_map = {}
self.function_map = {}
# preallocate exception names
self.preallocate_runtime_exception_names(["RuntimeError",
"RTIOUnderflow",
"RTIOOverflow",
"RTIODestinationUnreachable",
"DMAError",
"I2CError",
"CacheError",
"SPIError",
"0:ZeroDivisionError",
"0:IndexError"])
def preallocate_runtime_exception_names(self, names):
for i, name in enumerate(names):
if ":" not in name:
name = "0:artiq.coredevice.exceptions." + name
exn_id = self.store_str(name)
assert exn_id == i
def store_function(self, key, fun):
self.function_map[key] = fun
return key
def store_object(self, obj):
obj_id = id(obj)
if obj_id in self.object_inverse_map:
return self.object_inverse_map[obj_id]
key = len(self.object_map)
self.object_map[key] = obj
self.object_inverse_map[obj_id] = key
return key
def store_str(self, s):
if s in self.string_reverse_map:
return self.string_reverse_map[s]
key = len(self.string_map)
self.string_map[key] = s
self.string_reverse_map[s] = key
return key
def retrieve_function(self, key):
return self.function_map[key]
def retrieve_object(self, key):
return self.object_map[key]
def retrieve_str(self, key):
return self.string_map[key]

View File

@ -6,13 +6,14 @@ from typing import Generic, TypeVar
from math import floor, ceil
import nac3artiq
from embedding_map import EmbeddingMap
__all__ = [
"Kernel", "KernelInvariant", "virtual",
"round64", "floor64", "ceil64",
"extern", "kernel", "portable", "nac3",
"ms", "us", "ns",
"rpc", "ms", "us", "ns",
"print_int32", "print_int64",
"Core", "TTLOut",
"parallel", "sequential"
@ -65,6 +66,10 @@ def extern(function):
register_function(function)
return function
def rpc(function):
"""Decorates a function declaration defined by the core device runtime."""
register_function(function)
return function
def kernel(function_or_method):
"""Decorates a function or method to be executed on the core device."""
@ -146,6 +151,9 @@ class Core:
def run(self, method, *args, **kwargs):
global allow_registration
embedding = EmbeddingMap()
if allow_registration:
compiler.analyze(registered_functions, registered_classes)
allow_registration = False
@ -157,7 +165,7 @@ class Core:
obj = method
name = ""
compiler.compile_method_to_file(obj, name, args, "module.elf")
compiler.compile_method_to_file(obj, name, args, "module.elf", embedding)
@kernel
def reset(self):

View File

@ -1,16 +1,30 @@
use nac3core::{
codegen::{expr::gen_call, stmt::gen_with, CodeGenContext, CodeGenerator},
codegen::{
expr::gen_call,
stmt::{gen_block, gen_with},
CodeGenContext, CodeGenerator,
},
symbol_resolver::ValueEnum,
toplevel::DefinitionId,
toplevel::{DefinitionId, GenCall},
typecheck::typedef::{FunSignature, Type},
};
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
use inkwell::{context::Context, types::IntType, values::BasicValueEnum};
use inkwell::{
context::Context, module::Linkage, types::IntType, values::BasicValueEnum, AddressSpace,
};
use crate::timeline::TimeFns;
use std::{
collections::hash_map::DefaultHasher,
collections::HashMap,
convert::TryInto,
hash::{Hash, Hasher},
sync::Arc,
};
pub struct ArtiqCodeGenerator<'a> {
name: String,
size_t: u32,
@ -21,16 +35,13 @@ pub struct ArtiqCodeGenerator<'a> {
}
impl<'a> ArtiqCodeGenerator<'a> {
pub fn new(name: String, size_t: u32, timeline: &'a (dyn TimeFns + Sync)) -> ArtiqCodeGenerator<'a> {
pub fn new(
name: String,
size_t: u32,
timeline: &'a (dyn TimeFns + Sync),
) -> ArtiqCodeGenerator<'a> {
assert!(size_t == 32 || size_t == 64);
ArtiqCodeGenerator {
name,
size_t,
name_counter: 0,
start: None,
end: None,
timeline,
}
ArtiqCodeGenerator { name, size_t, name_counter: 0, start: None, end: None, timeline }
}
}
@ -86,7 +97,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
&mut self,
ctx: &mut CodeGenContext<'ctx, 'a>,
stmt: &Stmt<Option<Type>>,
) -> bool {
) {
if let StmtKind::With { items, body, .. } = &stmt.node {
if items.len() == 1 && items[0].optional_vars.is_none() {
let item = &items[0];
@ -108,9 +119,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
let old_start = self.start.take();
let old_end = self.end.take();
let now = if let Some(old_start) = &old_start {
self.gen_expr(ctx, old_start)
.unwrap()
.to_basic_value_enum(ctx, self)
self.gen_expr(ctx, old_start).unwrap().to_basic_value_enum(ctx, self)
} else {
self.timeline.emit_now_mu(ctx)
};
@ -126,10 +135,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
let start_expr = Located {
// location does not matter at this point
location: stmt.location,
node: ExprKind::Name {
id: start,
ctx: name_ctx.clone(),
},
node: ExprKind::Name { id: start, ctx: name_ctx.clone() },
custom: Some(ctx.primitives.int64),
};
let start = self.gen_store_target(ctx, &start_expr);
@ -140,40 +146,41 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
let end_expr = Located {
// location does not matter at this point
location: stmt.location,
node: ExprKind::Name {
id: end,
ctx: name_ctx.clone(),
},
node: ExprKind::Name { id: end, ctx: name_ctx.clone() },
custom: Some(ctx.primitives.int64),
};
let end = self.gen_store_target(ctx, &end_expr);
ctx.builder.build_store(end, now);
self.end = Some(end_expr);
self.name_counter += 1;
let mut exited = false;
for stmt in body.iter() {
if self.gen_stmt(ctx, stmt) {
exited = true;
break;
}
}
gen_block(self, ctx, body.iter());
let current = ctx.builder.get_insert_block().unwrap();
// if the current block is terminated, move before the terminator
// we want to set the timeline before reaching the terminator
// TODO: This may be unsound if there are multiple exit paths in the
// block... e.g.
// if ...:
// return
// Perhaps we can fix this by using actual with block?
let reset_position = if let Some(terminator) = current.get_terminator() {
ctx.builder.position_before(&terminator);
true
} else {
false
};
// set duration
let end_expr = self.end.take().unwrap();
let end_val = self
.gen_expr(ctx, &end_expr)
.unwrap()
.to_basic_value_enum(ctx, self);
let end_val =
self.gen_expr(ctx, &end_expr).unwrap().to_basic_value_enum(ctx, self);
// inside an sequential block
// inside a sequential block
if old_start.is_none() {
self.timeline.emit_at_mu(ctx, end_val);
}
// inside a parallel block, should update the outer max now_mu
if let Some(old_end) = &old_end {
let outer_end_val = self
.gen_expr(ctx, old_end)
.unwrap()
.to_basic_value_enum(ctx, self);
let outer_end_val =
self.gen_expr(ctx, old_end).unwrap().to_basic_value_enum(ctx, self);
let smax =
ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| {
let i64 = ctx.ctx.i64_type();
@ -194,24 +201,294 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
}
self.start = old_start;
self.end = old_end;
return exited;
if reset_position {
ctx.builder.position_at_end(current);
}
return;
} else if id == &"sequential".into() {
let start = self.start.take();
for stmt in body.iter() {
if self.gen_stmt(ctx, stmt) {
self.start = start;
return true;
self.gen_stmt(ctx, stmt);
if ctx.is_terminated() {
break;
}
}
self.start = start;
return false;
return
}
}
}
// not parallel/sequential
gen_with(self, ctx, stmt)
gen_with(self, ctx, stmt);
} else {
unreachable!()
}
}
}
fn gen_rpc_tag<'ctx, 'a>(ctx: &mut CodeGenContext<'ctx, 'a>, ty: Type, buffer: &mut Vec<u8>) {
use nac3core::typecheck::typedef::TypeEnum::*;
let int32 = ctx.primitives.int32;
let int64 = ctx.primitives.int64;
let float = ctx.primitives.float;
let bool = ctx.primitives.bool;
let str = ctx.primitives.str;
let none = ctx.primitives.none;
if ctx.unifier.unioned(ty, int32) {
buffer.push(b'i');
} else if ctx.unifier.unioned(ty, int64) {
buffer.push(b'I');
} else if ctx.unifier.unioned(ty, float) {
buffer.push(b'f');
} else if ctx.unifier.unioned(ty, bool) {
buffer.push(b'b');
} else if ctx.unifier.unioned(ty, str) {
buffer.push(b's');
} else if ctx.unifier.unioned(ty, none) {
buffer.push(b'n');
} else {
let ty = ctx.unifier.get_ty(ty);
match &*ty {
TTuple { ty } => {
buffer.push(b't');
buffer.push(ty.len() as u8);
for ty in ty {
gen_rpc_tag(ctx, *ty, buffer);
}
}
TList { ty } => {
buffer.push(b'l');
gen_rpc_tag(ctx, *ty, buffer);
}
// we should return an error, this will be fixed after improving error message
// as this requires returning an error during codegen
_ => unimplemented!(),
}
}
}
fn rpc_codegen_callback_fn<'ctx, 'a>(
ctx: &mut CodeGenContext<'ctx, 'a>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator,
) -> Option<BasicValueEnum<'ctx>> {
let ptr_type = ctx.ctx.i8_type().ptr_type(inkwell::AddressSpace::Generic);
let size_type = generator.get_size_type(ctx.ctx);
let int8 = ctx.ctx.i8_type();
let int32 = ctx.ctx.i32_type();
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
let service_id = int32.const_int(fun.1.0 as u64, false);
// -- setup rpc tags
let mut tag = Vec::new();
if obj.is_some() {
tag.push(b'O');
}
for arg in fun.0.args.iter() {
gen_rpc_tag(ctx, arg.ty, &mut tag);
}
tag.push(b':');
gen_rpc_tag(ctx, fun.0.ret, &mut tag);
let mut hasher = DefaultHasher::new();
tag.hash(&mut hasher);
let hash = format!("{}", hasher.finish());
let tag_ptr = ctx
.module
.get_global(hash.as_str())
.unwrap_or_else(|| {
let tag_arr_ptr = ctx.module.add_global(
int8.array_type(tag.len() as u32),
None,
format!("tagptr{}", fun.1 .0).as_str(),
);
tag_arr_ptr.set_initializer(&int8.const_array(
&tag.iter().map(|v| int8.const_int(*v as u64, false)).collect::<Vec<_>>(),
));
tag_arr_ptr.set_linkage(Linkage::Private);
let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash);
tag_ptr.set_linkage(Linkage::Private);
tag_ptr.set_initializer(&ctx.ctx.const_struct(
&[
tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(),
size_type.const_int(tag.len() as u64, false).into(),
],
false,
));
tag_ptr
})
.as_pointer_value();
let arg_length = args.len() + if obj.is_some() { 1 } else { 0 };
let stacksave = ctx.module.get_function("llvm.stacksave").unwrap_or_else(|| {
ctx.module.add_function("llvm.stacksave", ptr_type.fn_type(&[], false), None)
});
let stackrestore = ctx.module.get_function("llvm.stackrestore").unwrap_or_else(|| {
ctx.module.add_function(
"llvm.stackrestore",
ctx.ctx.void_type().fn_type(&[ptr_type.into()], false),
None,
)
});
let stackptr = ctx.builder.build_call(stacksave, &[], "rpc.stack");
let args_ptr = ctx.builder.build_array_alloca(
ptr_type,
ctx.ctx.i32_type().const_int(arg_length as u64, false),
"argptr",
);
// -- rpc args handling
let mut keys = fun.0.args.clone();
let mut mapping = HashMap::new();
for (key, value) in args.into_iter() {
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
}
// default value handling
for k in keys.into_iter() {
mapping.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap()).into());
}
// reorder the parameters
let mut real_params = fun
.0
.args
.iter()
.map(|arg| mapping.remove(&arg.name).unwrap().to_basic_value_enum(ctx, generator))
.collect::<Vec<_>>();
if let Some(obj) = obj {
if let ValueEnum::Static(obj) = obj.1 {
real_params.insert(0, obj.get_const_obj(ctx, generator));
} else {
// should be an error here...
panic!("only host object is allowed");
}
}
for (i, arg) in real_params.iter().enumerate() {
let arg_slot = if arg.is_pointer_value() {
arg.into_pointer_value()
} else {
let arg_slot = ctx.builder.build_alloca(arg.get_type(), &format!("rpc.arg{}", i));
ctx.builder.build_store(arg_slot, *arg);
arg_slot
};
let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg");
let arg_ptr = unsafe {
ctx.builder.build_gep(
args_ptr,
&[int32.const_int(i as u64, false)],
&format!("rpc.arg{}", i),
)
};
ctx.builder.build_store(arg_ptr, arg_slot);
}
// call
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
ctx.module.add_function(
"rpc_send",
ctx.ctx.void_type().fn_type(
&[
int32.into(),
tag_ptr_type.ptr_type(AddressSpace::Generic).into(),
ptr_type.ptr_type(AddressSpace::Generic).into(),
],
false,
),
None,
)
});
ctx.builder.build_call(
rpc_send,
&[service_id.into(), tag_ptr.into(), args_ptr.into()],
"rpc.send",
);
// reclaim stack space used by arguments
ctx.builder.build_call(
stackrestore,
&[stackptr.try_as_basic_value().unwrap_left().into()],
"rpc.stackrestore",
);
// -- receive value:
// T result = {
// void *ret_ptr = alloca(sizeof(T));
// void *ptr = ret_ptr;
// loop: int size = rpc_recv(ptr);
// // Non-zero: Provide `size` bytes of extra storage for variable-length data.
// if(size) { ptr = alloca(size); goto loop; }
// else *(T*)ret_ptr
// }
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
ctx.module.add_function("rpc_recv", int32.fn_type(&[ptr_type.into()], false), None)
});
if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) {
ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv");
return None
}
let prehead_bb = ctx.builder.get_insert_block().unwrap();
let current_function = prehead_bb.get_parent().unwrap();
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
let mut ret_ty = ctx.get_llvm_type(generator, fun.0.ret);
let need_load = !ret_ty.is_pointer_type();
if ret_ty.is_pointer_type() {
ret_ty = ret_ty.into_pointer_type().get_element_type().try_into().unwrap();
}
let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot");
let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr");
ctx.builder.build_unconditional_branch(head_bb);
ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr");
phi.add_incoming(&[(&slotgen, prehead_bb)]);
let alloc_size = ctx
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
.unwrap()
.into_int_value();
let is_done = ctx.builder.build_int_compare(
inkwell::IntPredicate::EQ,
int32.const_zero(),
alloc_size,
"rpc.done",
);
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb);
ctx.builder.position_at_end(alloc_bb);
let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc");
let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr");
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
ctx.builder.build_unconditional_branch(head_bb);
ctx.builder.position_at_end(tail_bb);
if need_load {
let result = ctx.builder.build_load(slot, "rpc.result");
ctx.builder.build_call(
stackrestore,
&[stackptr.try_as_basic_value().unwrap_left().into()],
"rpc.stackrestore",
);
Some(result)
} else {
Some(slot.into())
}
}
pub fn rpc_codegen_callback() -> Arc<GenCall> {
Arc::new(GenCall::new(Box::new(|ctx, obj, fun, args, generator| {
rpc_codegen_callback_fn(ctx, obj, fun, args, generator)
})))
}

View File

@ -12,7 +12,7 @@ use inkwell::{
};
use nac3core::typecheck::typedef::{Unifier, TypeEnum};
use nac3parser::{
ast::{self, Stmt, StrRef},
ast::{self, ExprKind, Stmt, StmtKind, StrRef},
parser::{self, parse_program},
};
use pyo3::prelude::*;
@ -24,7 +24,10 @@ use nac3core::{
codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry},
codegen::irrt::load_irrt,
symbol_resolver::SymbolResolver,
toplevel::{composer::{TopLevelComposer, ComposerConfig}, DefinitionId, GenCall, TopLevelDef},
toplevel::{
composer::{ComposerConfig, TopLevelComposer},
DefinitionId, GenCall, TopLevelDef,
},
typecheck::typedef::{FunSignature, FuncArg},
typecheck::{type_inferencer::PrimitiveStore, typedef::Type},
};
@ -32,7 +35,7 @@ use nac3core::{
use tempfile::{self, TempDir};
use crate::{
codegen::ArtiqCodeGenerator,
codegen::{rpc_codegen_callback, ArtiqCodeGenerator},
symbol_resolver::{InnerResolver, PythonHelper, Resolver},
};
@ -61,6 +64,7 @@ pub struct PrimitivePythonId {
tuple: u64,
typevar: u64,
none: u64,
exception: u64,
generic_alias: (u64, u64),
virtual_id: u64,
}
@ -81,6 +85,7 @@ struct Nac3 {
primitive_ids: PrimitivePythonId,
working_directory: TempDir,
top_levels: Vec<TopLevelComponent>,
string_store: Arc<RwLock<HashMap<String, i32>>>,
}
impl Nac3 {
@ -127,9 +132,13 @@ impl Nac3 {
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
match &base.node {
ast::ExprKind::Name { id, .. } => {
let base_obj = module.getattr(py, id.to_string())?;
let base_id = id_fn.call1((base_obj,))?.extract()?;
Ok(registered_class_ids.contains(&base_id))
if *id == "Exception".into() {
Ok(true)
} else {
let base_obj = module.getattr(py, id.to_string())?;
let base_id = id_fn.call1((base_obj,))?.extract()?;
Ok(registered_class_ids.contains(&base_id))
}
}
_ => Ok(true),
}
@ -143,7 +152,9 @@ impl Nac3 {
{
decorator_list.iter().any(|decorator| {
if let ast::ExprKind::Name { id, .. } = decorator.node {
id.to_string() == "kernel" || id.to_string() == "portable"
id.to_string() == "kernel"
|| id.to_string() == "portable"
|| id.to_string() == "rpc"
} else {
false
}
@ -159,7 +170,7 @@ impl Nac3 {
} => decorator_list.iter().any(|decorator| {
if let ast::ExprKind::Name { id, .. } = decorator.node {
let id = id.to_string();
id == "extern" || id == "portable" || id == "kernel"
id == "extern" || id == "portable" || id == "kernel" || id == "rpc"
} else {
false
}
@ -188,7 +199,7 @@ impl Nac3 {
Ok(ty) => ty,
Err(e) => return Some(format!("type error inside object launching kernel: {}", e))
};
let fun_ty = if method_name.is_empty() {
base_ty
} else if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(base_ty) {
@ -201,7 +212,7 @@ impl Nac3 {
} else {
return Some("cannot launch kernel by calling a non-callable".into())
};
if let TypeEnum::TFunc(sig) = &*unifier.get_ty(fun_ty) {
let FunSignature { args, .. } = &*sig.borrow();
if arg_names.len() > args.len() {
@ -269,7 +280,7 @@ impl Nac3 {
ret: primitive.int64,
vars: HashMap::new(),
},
Arc::new(GenCall::new(Box::new(move |ctx, _, _, _| {
Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| {
Some(time_fns.emit_now_mu(ctx))
}))),
),
@ -284,8 +295,9 @@ impl Nac3 {
ret: primitive.none,
vars: HashMap::new(),
},
Arc::new(GenCall::new(Box::new(move |ctx, _, _, args| {
time_fns.emit_at_mu(ctx, args[0].1);
Arc::new(GenCall::new(Box::new(move |ctx, _, _, args, generator| {
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator);
time_fns.emit_at_mu(ctx, arg);
None
}))),
),
@ -300,16 +312,20 @@ impl Nac3 {
ret: primitive.none,
vars: HashMap::new(),
},
Arc::new(GenCall::new(Box::new(move |ctx, _, _, args| {
time_fns.emit_delay_mu(ctx, args[0].1);
Arc::new(GenCall::new(Box::new(move |ctx, _, _, args, generator| {
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator);
time_fns.emit_delay_mu(ctx, arg);
None
}))),
),
];
let (_, builtins_def, builtins_ty) = TopLevelComposer::new(builtins.clone(), ComposerConfig {
kernel_ann: Some("Kernel"),
kernel_invariant_ann: "KernelInvariant"
});
let (_, builtins_def, builtins_ty) = TopLevelComposer::new(
builtins.clone(),
ComposerConfig {
kernel_ann: Some("Kernel"),
kernel_invariant_ann: "KernelInvariant",
},
);
let builtins_mod = PyModule::import(py, "builtins").unwrap();
let id_fn = builtins_mod.getattr("id").unwrap();
@ -385,6 +401,11 @@ impl Nac3 {
.unwrap()
.extract()
.unwrap(),
exception: id_fn
.call1((builtins_mod.getattr("tuple").unwrap(),))
.unwrap()
.extract()
.unwrap(),
};
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
@ -405,6 +426,7 @@ impl Nac3 {
top_levels: Default::default(),
pyid_to_def: Default::default(),
working_directory,
string_store: Default::default()
})
}
@ -441,6 +463,7 @@ impl Nac3 {
method_name: &str,
args: Vec<&PyAny>,
filename: &str,
embedding_map: &PyAny,
py: Python,
) -> PyResult<()> {
let (mut composer, _, _) = TopLevelComposer::new(self.builtins.clone(), ComposerConfig {
@ -451,17 +474,26 @@ impl Nac3 {
let builtins = PyModule::import(py, "builtins")?;
let typings = PyModule::import(py, "typing")?;
let id_fn = builtins.getattr("id")?;
let store_obj = embedding_map.getattr("store_object").unwrap().to_object(py);
let store_str = embedding_map.getattr("store_str").unwrap().to_object(py);
let store_fun = embedding_map
.getattr("store_function")
.unwrap()
.to_object(py);
let helper = PythonHelper {
id_fn: builtins.getattr("id").unwrap().to_object(py),
len_fn: builtins.getattr("len").unwrap().to_object(py),
type_fn: builtins.getattr("type").unwrap().to_object(py),
origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py),
args_ty_fn: typings.getattr("get_args").unwrap().to_object(py),
store_obj,
store_str
};
let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new();
let pyid_to_type = Arc::new(RwLock::new(HashMap::<u64, Type>::new()));
let global_value_ids = Arc::new(RwLock::new(HashSet::<u64>::new()));
let mut rpc_ids = vec![];
for (stmt, path, module) in self.top_levels.iter() {
let py_module: &PyAny = module.extract(py)?;
let module_id: u64 = id_fn.call1((py_module,))?.extract()?;
@ -492,6 +524,7 @@ impl Nac3 {
id_to_primitive: Default::default(),
field_to_val: Default::default(),
helper,
string_store: self.string_store.clone(),
})))
as Arc<dyn SymbolResolver + Send + Sync>;
let name_to_pyid = Rc::new(name_to_pyid);
@ -502,7 +535,30 @@ impl Nac3 {
let (name, def_id, ty) = composer
.register_top_level(stmt.clone(), Some(resolver.clone()), path.clone())
.map_err(|e| exceptions::PyRuntimeError::new_err(format!("nac3 compilation failure: {}", e)))?;
.map_err(|e| {
exceptions::PyRuntimeError::new_err(format!("nac3 compilation failure: {}", e))
})?;
match &stmt.node {
StmtKind::FunctionDef { decorator_list, .. } => {
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string()).unwrap())).unwrap();
rpc_ids.push((None, def_id));
}
}
StmtKind::ClassDef { name, body, .. } => {
let class_obj = module.getattr(py, name.to_string()).unwrap();
for stmt in body.iter() {
if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node {
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
rpc_ids.push((Some((class_obj.clone(), *name)), def_id));
}
}
}
}
_ => ()
}
let id = *name_to_pyid.get(&name).unwrap();
self.pyid_to_def.write().insert(id, def_id);
{
@ -552,6 +608,7 @@ impl Nac3 {
name_to_pyid,
module: module.to_object(py),
helper,
string_store: self.string_store.clone(),
}))) as Arc<dyn SymbolResolver + Send + Sync>;
let (_, def_id, _) = composer
.register_top_level(
@ -595,6 +652,45 @@ impl Nac3 {
}
}
let top_level = Arc::new(composer.make_top_level_context());
{
let rpc_codegen = rpc_codegen_callback();
let defs = top_level.definitions.read();
for (class_data, id) in rpc_ids.iter() {
let mut def = defs[id.0].write();
match &mut *def {
TopLevelDef::Function {
codegen_callback, ..
} => {
*codegen_callback = Some(rpc_codegen.clone());
}
TopLevelDef::Class { methods, .. } => {
let (class_def, method_name) = class_data.as_ref().unwrap();
for (name, _, id) in methods.iter() {
if name != method_name {
continue;
}
if let TopLevelDef::Function {
codegen_callback, ..
} = &mut *defs[id.0].write()
{
*codegen_callback = Some(rpc_codegen.clone());
store_fun
.call1(
py,
(
id.0.into_py(py),
class_def.getattr(py, name.to_string()).unwrap(),
),
)
.unwrap();
}
}
}
}
}
}
let instance = {
let defs = top_level.definitions.read();
let mut definition = defs[def_id.0].write();
@ -634,15 +730,17 @@ impl Nac3 {
let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer);
})));
let size_t = if self.isa == Isa::Host {
64
} else {
32
};
let size_t = if self.isa == Isa::Host { 64 } else { 32 };
let thread_names: Vec<String> = (0..4).map(|_| "main".to_string()).collect();
let threads: Vec<_> = thread_names
.iter()
.map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), size_t, self.time_fns)))
.map(|s| {
Box::new(ArtiqCodeGenerator::new(
s.to_string(),
size_t,
self.time_fns,
))
})
.collect();
py.allow_threads(|| {
@ -759,11 +857,12 @@ impl Nac3 {
obj: &PyAny,
method_name: &str,
args: Vec<&PyAny>,
embedding_map: &PyAny,
py: Python,
) -> PyResult<PyObject> {
let filename_path = self.working_directory.path().join("module.elf");
let filename = filename_path.to_str().unwrap();
self.compile_method_to_file(obj, method_name, args, filename, py)?;
self.compile_method_to_file(obj, method_name, args, filename, embedding_map, py)?;
Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into())
}
}

View File

@ -42,6 +42,7 @@ pub struct InnerResolver {
pub pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>,
pub primitive_ids: PrimitivePythonId,
pub helper: PythonHelper,
pub string_store: Arc<RwLock<HashMap<String, i32>>>,
// module specific
pub name_to_pyid: HashMap<StrRef, u64>,
pub module: PyObject,
@ -56,11 +57,14 @@ pub struct PythonHelper {
pub id_fn: PyObject,
pub origin_ty_fn: PyObject,
pub args_ty_fn: PyObject,
pub store_obj: PyObject,
pub store_str: PyObject,
}
struct PythonValue {
id: u64,
value: PyObject,
store_obj: PyObject,
resolver: Arc<InnerResolver>,
}
@ -69,6 +73,36 @@ impl StaticValue for PythonValue {
self.id
}
fn get_const_obj<'ctx, 'a>(
&self,
ctx: &mut CodeGenContext<'ctx, 'a>,
_: &mut dyn CodeGenerator,
) -> BasicValueEnum<'ctx> {
ctx.module
.get_global(self.id.to_string().as_str())
.map(|val| val.as_pointer_value().into())
.unwrap_or_else(|| {
Python::with_gil(|py| -> PyResult<BasicValueEnum<'ctx>> {
let id: u32 = self.store_obj.call1(py, (self.value.clone(),))?.extract(py)?;
let struct_type = ctx.ctx.struct_type(&[ctx.ctx.i32_type().into()], false);
let global =
ctx.module
.add_global(struct_type, None, format!("{}_const", self.id).as_str());
global.set_constant(true);
global.set_initializer(&ctx.ctx.const_struct(
&[ctx.ctx.i32_type().const_int(id as u64, false).into()],
false,
));
let global2 =
ctx.module
.add_global(struct_type.ptr_type(AddressSpace::Generic), None, format!("{}_const2", self.id).as_str());
global2.set_initializer(&global.as_pointer_value());
Ok(global2.as_pointer_value().into())
})
.unwrap()
})
}
fn to_basic_value_enum<'ctx, 'a>(
&self,
ctx: &mut CodeGenContext<'ctx, 'a>,
@ -140,6 +174,7 @@ impl StaticValue for PythonValue {
ValueEnum::Static(Arc::new(PythonValue {
id,
value: obj,
store_obj: self.store_obj.clone(),
resolver: self.resolver.clone(),
}))
})
@ -208,7 +243,9 @@ impl InnerResolver {
Ok(Ok((primitives.bool, true)))
} else if ty_id == self.primitive_ids.float {
Ok(Ok((primitives.float, true)))
} else if ty_id == self.primitive_ids.list {
} else if ty_id == self.primitive_ids.exception {
Ok(Ok((primitives.exception, true)))
}else if ty_id == self.primitive_ids.list {
// do not handle type var param and concrete check here
let var = unifier.get_fresh_var().0;
let list = unifier.add_ty(TypeEnum::TList { ty: var });
@ -755,9 +792,7 @@ impl InnerResolver {
.get_llvm_type(generator, ty)
.into_pointer_type()
.get_element_type()
.into_struct_type()
.as_basic_type_enum();
.into_struct_type();
{
if self.global_value_ids.read().contains(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
@ -783,7 +818,7 @@ impl InnerResolver {
.collect();
let values = values?;
if let Some(values) = values {
let val = ctx.ctx.const_struct(&values, false);
let val = ty.const_named_struct(&values);
let global = ctx
.module
.add_global(ty, Some(AddressSpace::Generic), &id_str);
@ -948,6 +983,7 @@ impl SymbolResolver for Resolver {
ValueEnum::Static(Arc::new(PythonValue {
id,
value: v,
store_obj: self.0.helper.store_obj.clone(),
resolver: self.0.clone(),
}))
})
@ -971,4 +1007,17 @@ impl SymbolResolver for Resolver {
result
})
}
fn get_string_id(&self, s: &str) -> i32 {
let mut string_store = self.0.string_store.write();
if let Some(id) = string_store.get(s) {
*id
} else {
let id = Python::with_gil(|py| -> PyResult<i32> {
self.0.helper.store_str.call1(py, (s, ))?.extract(py)
}).unwrap();
string_store.insert(s.into(), id);
id
}
}
}

View File

@ -32,10 +32,16 @@ pub enum SymbolValue {
pub trait StaticValue {
fn get_unique_identifier(&self) -> u64;
fn get_const_obj<'ctx, 'a>(
&self,
ctx: &mut CodeGenContext<'ctx, 'a>,
generator: &mut dyn CodeGenerator,
) -> BasicValueEnum<'ctx>;
fn to_basic_value_enum<'ctx, 'a>(
&self,
ctx: &mut CodeGenContext<'ctx, 'a>,
generator: &mut dyn CodeGenerator
generator: &mut dyn CodeGenerator,
) -> BasicValueEnum<'ctx>;
fn get_field<'ctx, 'a>(