1
0
forked from M-Labs/nac3

nac3artiq: RPC support

This commit is contained in:
pca006132 2022-02-12 21:17:37 +08:00
parent e303248261
commit bf52e294ee
6 changed files with 577 additions and 81 deletions

View File

@ -0,0 +1,57 @@
class EmbeddingMap:
def __init__(self):
self.object_inverse_map = {}
self.object_map = {}
self.string_map = {}
self.string_reverse_map = {}
self.function_map = {}
# preallocate exception names
self.preallocate_runtime_exception_names(["RuntimeError",
"RTIOUnderflow",
"RTIOOverflow",
"RTIODestinationUnreachable",
"DMAError",
"I2CError",
"CacheError",
"SPIError",
"0:ZeroDivisionError",
"0:IndexError"])
def preallocate_runtime_exception_names(self, names):
for i, name in enumerate(names):
if ":" not in name:
name = "0:artiq.coredevice.exceptions." + name
exn_id = self.store_str(name)
assert exn_id == i
def store_function(self, key, fun):
self.function_map[key] = fun
return key
def store_object(self, obj):
obj_id = id(obj)
if obj_id in self.object_inverse_map:
return self.object_inverse_map[obj_id]
key = len(self.object_map)
self.object_map[key] = obj
self.object_inverse_map[obj_id] = key
return key
def store_str(self, s):
if s in self.string_reverse_map:
return self.string_reverse_map[s]
key = len(self.string_map)
self.string_map[key] = s
self.string_reverse_map[s] = key
return key
def retrieve_function(self, key):
return self.function_map[key]
def retrieve_object(self, key):
return self.object_map[key]
def retrieve_str(self, key):
return self.string_map[key]

View File

@ -6,13 +6,14 @@ from typing import Generic, TypeVar
from math import floor, ceil from math import floor, ceil
import nac3artiq import nac3artiq
from embedding_map import EmbeddingMap
__all__ = [ __all__ = [
"Kernel", "KernelInvariant", "virtual", "Kernel", "KernelInvariant", "virtual",
"round64", "floor64", "ceil64", "round64", "floor64", "ceil64",
"extern", "kernel", "portable", "nac3", "extern", "kernel", "portable", "nac3",
"ms", "us", "ns", "rpc", "ms", "us", "ns",
"print_int32", "print_int64", "print_int32", "print_int64",
"Core", "TTLOut", "Core", "TTLOut",
"parallel", "sequential" "parallel", "sequential"
@ -65,6 +66,10 @@ def extern(function):
register_function(function) register_function(function)
return function return function
def rpc(function):
"""Decorates a function declaration defined by the core device runtime."""
register_function(function)
return function
def kernel(function_or_method): def kernel(function_or_method):
"""Decorates a function or method to be executed on the core device.""" """Decorates a function or method to be executed on the core device."""
@ -146,6 +151,9 @@ class Core:
def run(self, method, *args, **kwargs): def run(self, method, *args, **kwargs):
global allow_registration global allow_registration
embedding = EmbeddingMap()
if allow_registration: if allow_registration:
compiler.analyze(registered_functions, registered_classes) compiler.analyze(registered_functions, registered_classes)
allow_registration = False allow_registration = False
@ -157,7 +165,7 @@ class Core:
obj = method obj = method
name = "" name = ""
compiler.compile_method_to_file(obj, name, args, "module.elf") compiler.compile_method_to_file(obj, name, args, "module.elf", embedding)
@kernel @kernel
def reset(self): def reset(self):

View File

@ -1,16 +1,30 @@
use nac3core::{ use nac3core::{
codegen::{expr::gen_call, stmt::gen_with, CodeGenContext, CodeGenerator}, codegen::{
expr::gen_call,
stmt::{gen_block, gen_with},
CodeGenContext, CodeGenerator,
},
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::DefinitionId, toplevel::{DefinitionId, GenCall},
typecheck::typedef::{FunSignature, Type}, typecheck::typedef::{FunSignature, Type},
}; };
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}; use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
use inkwell::{context::Context, types::IntType, values::BasicValueEnum}; use inkwell::{
context::Context, module::Linkage, types::IntType, values::BasicValueEnum, AddressSpace,
};
use crate::timeline::TimeFns; use crate::timeline::TimeFns;
use std::{
collections::hash_map::DefaultHasher,
collections::HashMap,
convert::TryInto,
hash::{Hash, Hasher},
sync::Arc,
};
pub struct ArtiqCodeGenerator<'a> { pub struct ArtiqCodeGenerator<'a> {
name: String, name: String,
size_t: u32, size_t: u32,
@ -21,16 +35,13 @@ pub struct ArtiqCodeGenerator<'a> {
} }
impl<'a> ArtiqCodeGenerator<'a> { impl<'a> ArtiqCodeGenerator<'a> {
pub fn new(name: String, size_t: u32, timeline: &'a (dyn TimeFns + Sync)) -> ArtiqCodeGenerator<'a> { pub fn new(
name: String,
size_t: u32,
timeline: &'a (dyn TimeFns + Sync),
) -> ArtiqCodeGenerator<'a> {
assert!(size_t == 32 || size_t == 64); assert!(size_t == 32 || size_t == 64);
ArtiqCodeGenerator { ArtiqCodeGenerator { name, size_t, name_counter: 0, start: None, end: None, timeline }
name,
size_t,
name_counter: 0,
start: None,
end: None,
timeline,
}
} }
} }
@ -86,7 +97,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
&mut self, &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
) -> bool { ) {
if let StmtKind::With { items, body, .. } = &stmt.node { if let StmtKind::With { items, body, .. } = &stmt.node {
if items.len() == 1 && items[0].optional_vars.is_none() { if items.len() == 1 && items[0].optional_vars.is_none() {
let item = &items[0]; let item = &items[0];
@ -108,9 +119,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
let old_start = self.start.take(); let old_start = self.start.take();
let old_end = self.end.take(); let old_end = self.end.take();
let now = if let Some(old_start) = &old_start { let now = if let Some(old_start) = &old_start {
self.gen_expr(ctx, old_start) self.gen_expr(ctx, old_start).unwrap().to_basic_value_enum(ctx, self)
.unwrap()
.to_basic_value_enum(ctx, self)
} else { } else {
self.timeline.emit_now_mu(ctx) self.timeline.emit_now_mu(ctx)
}; };
@ -126,10 +135,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
let start_expr = Located { let start_expr = Located {
// location does not matter at this point // location does not matter at this point
location: stmt.location, location: stmt.location,
node: ExprKind::Name { node: ExprKind::Name { id: start, ctx: name_ctx.clone() },
id: start,
ctx: name_ctx.clone(),
},
custom: Some(ctx.primitives.int64), custom: Some(ctx.primitives.int64),
}; };
let start = self.gen_store_target(ctx, &start_expr); let start = self.gen_store_target(ctx, &start_expr);
@ -140,40 +146,41 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
let end_expr = Located { let end_expr = Located {
// location does not matter at this point // location does not matter at this point
location: stmt.location, location: stmt.location,
node: ExprKind::Name { node: ExprKind::Name { id: end, ctx: name_ctx.clone() },
id: end,
ctx: name_ctx.clone(),
},
custom: Some(ctx.primitives.int64), custom: Some(ctx.primitives.int64),
}; };
let end = self.gen_store_target(ctx, &end_expr); let end = self.gen_store_target(ctx, &end_expr);
ctx.builder.build_store(end, now); ctx.builder.build_store(end, now);
self.end = Some(end_expr); self.end = Some(end_expr);
self.name_counter += 1; self.name_counter += 1;
let mut exited = false; gen_block(self, ctx, body.iter());
for stmt in body.iter() { let current = ctx.builder.get_insert_block().unwrap();
if self.gen_stmt(ctx, stmt) { // if the current block is terminated, move before the terminator
exited = true; // we want to set the timeline before reaching the terminator
break; // TODO: This may be unsound if there are multiple exit paths in the
} // block... e.g.
} // if ...:
// return
// Perhaps we can fix this by using actual with block?
let reset_position = if let Some(terminator) = current.get_terminator() {
ctx.builder.position_before(&terminator);
true
} else {
false
};
// set duration // set duration
let end_expr = self.end.take().unwrap(); let end_expr = self.end.take().unwrap();
let end_val = self let end_val =
.gen_expr(ctx, &end_expr) self.gen_expr(ctx, &end_expr).unwrap().to_basic_value_enum(ctx, self);
.unwrap()
.to_basic_value_enum(ctx, self);
// inside an sequential block // inside a sequential block
if old_start.is_none() { if old_start.is_none() {
self.timeline.emit_at_mu(ctx, end_val); self.timeline.emit_at_mu(ctx, end_val);
} }
// inside a parallel block, should update the outer max now_mu // inside a parallel block, should update the outer max now_mu
if let Some(old_end) = &old_end { if let Some(old_end) = &old_end {
let outer_end_val = self let outer_end_val =
.gen_expr(ctx, old_end) self.gen_expr(ctx, old_end).unwrap().to_basic_value_enum(ctx, self);
.unwrap()
.to_basic_value_enum(ctx, self);
let smax = let smax =
ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| { ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| {
let i64 = ctx.ctx.i64_type(); let i64 = ctx.ctx.i64_type();
@ -194,24 +201,294 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
} }
self.start = old_start; self.start = old_start;
self.end = old_end; self.end = old_end;
return exited; if reset_position {
ctx.builder.position_at_end(current);
}
return;
} else if id == &"sequential".into() { } else if id == &"sequential".into() {
let start = self.start.take(); let start = self.start.take();
for stmt in body.iter() { for stmt in body.iter() {
if self.gen_stmt(ctx, stmt) { self.gen_stmt(ctx, stmt);
self.start = start; if ctx.is_terminated() {
return true; break;
} }
} }
self.start = start; self.start = start;
return false; return
} }
} }
} }
// not parallel/sequential // not parallel/sequential
gen_with(self, ctx, stmt) gen_with(self, ctx, stmt);
} else { } else {
unreachable!() unreachable!()
} }
} }
} }
fn gen_rpc_tag<'ctx, 'a>(ctx: &mut CodeGenContext<'ctx, 'a>, ty: Type, buffer: &mut Vec<u8>) {
use nac3core::typecheck::typedef::TypeEnum::*;
let int32 = ctx.primitives.int32;
let int64 = ctx.primitives.int64;
let float = ctx.primitives.float;
let bool = ctx.primitives.bool;
let str = ctx.primitives.str;
let none = ctx.primitives.none;
if ctx.unifier.unioned(ty, int32) {
buffer.push(b'i');
} else if ctx.unifier.unioned(ty, int64) {
buffer.push(b'I');
} else if ctx.unifier.unioned(ty, float) {
buffer.push(b'f');
} else if ctx.unifier.unioned(ty, bool) {
buffer.push(b'b');
} else if ctx.unifier.unioned(ty, str) {
buffer.push(b's');
} else if ctx.unifier.unioned(ty, none) {
buffer.push(b'n');
} else {
let ty = ctx.unifier.get_ty(ty);
match &*ty {
TTuple { ty } => {
buffer.push(b't');
buffer.push(ty.len() as u8);
for ty in ty {
gen_rpc_tag(ctx, *ty, buffer);
}
}
TList { ty } => {
buffer.push(b'l');
gen_rpc_tag(ctx, *ty, buffer);
}
// we should return an error, this will be fixed after improving error message
// as this requires returning an error during codegen
_ => unimplemented!(),
}
}
}
fn rpc_codegen_callback_fn<'ctx, 'a>(
ctx: &mut CodeGenContext<'ctx, 'a>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator,
) -> Option<BasicValueEnum<'ctx>> {
let ptr_type = ctx.ctx.i8_type().ptr_type(inkwell::AddressSpace::Generic);
let size_type = generator.get_size_type(ctx.ctx);
let int8 = ctx.ctx.i8_type();
let int32 = ctx.ctx.i32_type();
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
let service_id = int32.const_int(fun.1.0 as u64, false);
// -- setup rpc tags
let mut tag = Vec::new();
if obj.is_some() {
tag.push(b'O');
}
for arg in fun.0.args.iter() {
gen_rpc_tag(ctx, arg.ty, &mut tag);
}
tag.push(b':');
gen_rpc_tag(ctx, fun.0.ret, &mut tag);
let mut hasher = DefaultHasher::new();
tag.hash(&mut hasher);
let hash = format!("{}", hasher.finish());
let tag_ptr = ctx
.module
.get_global(hash.as_str())
.unwrap_or_else(|| {
let tag_arr_ptr = ctx.module.add_global(
int8.array_type(tag.len() as u32),
None,
format!("tagptr{}", fun.1 .0).as_str(),
);
tag_arr_ptr.set_initializer(&int8.const_array(
&tag.iter().map(|v| int8.const_int(*v as u64, false)).collect::<Vec<_>>(),
));
tag_arr_ptr.set_linkage(Linkage::Private);
let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash);
tag_ptr.set_linkage(Linkage::Private);
tag_ptr.set_initializer(&ctx.ctx.const_struct(
&[
tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(),
size_type.const_int(tag.len() as u64, false).into(),
],
false,
));
tag_ptr
})
.as_pointer_value();
let arg_length = args.len() + if obj.is_some() { 1 } else { 0 };
let stacksave = ctx.module.get_function("llvm.stacksave").unwrap_or_else(|| {
ctx.module.add_function("llvm.stacksave", ptr_type.fn_type(&[], false), None)
});
let stackrestore = ctx.module.get_function("llvm.stackrestore").unwrap_or_else(|| {
ctx.module.add_function(
"llvm.stackrestore",
ctx.ctx.void_type().fn_type(&[ptr_type.into()], false),
None,
)
});
let stackptr = ctx.builder.build_call(stacksave, &[], "rpc.stack");
let args_ptr = ctx.builder.build_array_alloca(
ptr_type,
ctx.ctx.i32_type().const_int(arg_length as u64, false),
"argptr",
);
// -- rpc args handling
let mut keys = fun.0.args.clone();
let mut mapping = HashMap::new();
for (key, value) in args.into_iter() {
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
}
// default value handling
for k in keys.into_iter() {
mapping.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap()).into());
}
// reorder the parameters
let mut real_params = fun
.0
.args
.iter()
.map(|arg| mapping.remove(&arg.name).unwrap().to_basic_value_enum(ctx, generator))
.collect::<Vec<_>>();
if let Some(obj) = obj {
if let ValueEnum::Static(obj) = obj.1 {
real_params.insert(0, obj.get_const_obj(ctx, generator));
} else {
// should be an error here...
panic!("only host object is allowed");
}
}
for (i, arg) in real_params.iter().enumerate() {
let arg_slot = if arg.is_pointer_value() {
arg.into_pointer_value()
} else {
let arg_slot = ctx.builder.build_alloca(arg.get_type(), &format!("rpc.arg{}", i));
ctx.builder.build_store(arg_slot, *arg);
arg_slot
};
let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg");
let arg_ptr = unsafe {
ctx.builder.build_gep(
args_ptr,
&[int32.const_int(i as u64, false)],
&format!("rpc.arg{}", i),
)
};
ctx.builder.build_store(arg_ptr, arg_slot);
}
// call
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
ctx.module.add_function(
"rpc_send",
ctx.ctx.void_type().fn_type(
&[
int32.into(),
tag_ptr_type.ptr_type(AddressSpace::Generic).into(),
ptr_type.ptr_type(AddressSpace::Generic).into(),
],
false,
),
None,
)
});
ctx.builder.build_call(
rpc_send,
&[service_id.into(), tag_ptr.into(), args_ptr.into()],
"rpc.send",
);
// reclaim stack space used by arguments
ctx.builder.build_call(
stackrestore,
&[stackptr.try_as_basic_value().unwrap_left().into()],
"rpc.stackrestore",
);
// -- receive value:
// T result = {
// void *ret_ptr = alloca(sizeof(T));
// void *ptr = ret_ptr;
// loop: int size = rpc_recv(ptr);
// // Non-zero: Provide `size` bytes of extra storage for variable-length data.
// if(size) { ptr = alloca(size); goto loop; }
// else *(T*)ret_ptr
// }
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
ctx.module.add_function("rpc_recv", int32.fn_type(&[ptr_type.into()], false), None)
});
if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) {
ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv");
return None
}
let prehead_bb = ctx.builder.get_insert_block().unwrap();
let current_function = prehead_bb.get_parent().unwrap();
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
let mut ret_ty = ctx.get_llvm_type(generator, fun.0.ret);
let need_load = !ret_ty.is_pointer_type();
if ret_ty.is_pointer_type() {
ret_ty = ret_ty.into_pointer_type().get_element_type().try_into().unwrap();
}
let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot");
let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr");
ctx.builder.build_unconditional_branch(head_bb);
ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr");
phi.add_incoming(&[(&slotgen, prehead_bb)]);
let alloc_size = ctx
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
.unwrap()
.into_int_value();
let is_done = ctx.builder.build_int_compare(
inkwell::IntPredicate::EQ,
int32.const_zero(),
alloc_size,
"rpc.done",
);
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb);
ctx.builder.position_at_end(alloc_bb);
let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc");
let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr");
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
ctx.builder.build_unconditional_branch(head_bb);
ctx.builder.position_at_end(tail_bb);
if need_load {
let result = ctx.builder.build_load(slot, "rpc.result");
ctx.builder.build_call(
stackrestore,
&[stackptr.try_as_basic_value().unwrap_left().into()],
"rpc.stackrestore",
);
Some(result)
} else {
Some(slot.into())
}
}
pub fn rpc_codegen_callback() -> Arc<GenCall> {
Arc::new(GenCall::new(Box::new(|ctx, obj, fun, args, generator| {
rpc_codegen_callback_fn(ctx, obj, fun, args, generator)
})))
}

View File

@ -12,7 +12,7 @@ use inkwell::{
}; };
use nac3core::typecheck::typedef::{Unifier, TypeEnum}; use nac3core::typecheck::typedef::{Unifier, TypeEnum};
use nac3parser::{ use nac3parser::{
ast::{self, Stmt, StrRef}, ast::{self, ExprKind, Stmt, StmtKind, StrRef},
parser::{self, parse_program}, parser::{self, parse_program},
}; };
use pyo3::prelude::*; use pyo3::prelude::*;
@ -24,7 +24,10 @@ use nac3core::{
codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry}, codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry},
codegen::irrt::load_irrt, codegen::irrt::load_irrt,
symbol_resolver::SymbolResolver, symbol_resolver::SymbolResolver,
toplevel::{composer::{TopLevelComposer, ComposerConfig}, DefinitionId, GenCall, TopLevelDef}, toplevel::{
composer::{ComposerConfig, TopLevelComposer},
DefinitionId, GenCall, TopLevelDef,
},
typecheck::typedef::{FunSignature, FuncArg}, typecheck::typedef::{FunSignature, FuncArg},
typecheck::{type_inferencer::PrimitiveStore, typedef::Type}, typecheck::{type_inferencer::PrimitiveStore, typedef::Type},
}; };
@ -32,7 +35,7 @@ use nac3core::{
use tempfile::{self, TempDir}; use tempfile::{self, TempDir};
use crate::{ use crate::{
codegen::ArtiqCodeGenerator, codegen::{rpc_codegen_callback, ArtiqCodeGenerator},
symbol_resolver::{InnerResolver, PythonHelper, Resolver}, symbol_resolver::{InnerResolver, PythonHelper, Resolver},
}; };
@ -61,6 +64,7 @@ pub struct PrimitivePythonId {
tuple: u64, tuple: u64,
typevar: u64, typevar: u64,
none: u64, none: u64,
exception: u64,
generic_alias: (u64, u64), generic_alias: (u64, u64),
virtual_id: u64, virtual_id: u64,
} }
@ -81,6 +85,7 @@ struct Nac3 {
primitive_ids: PrimitivePythonId, primitive_ids: PrimitivePythonId,
working_directory: TempDir, working_directory: TempDir,
top_levels: Vec<TopLevelComponent>, top_levels: Vec<TopLevelComponent>,
string_store: Arc<RwLock<HashMap<String, i32>>>,
} }
impl Nac3 { impl Nac3 {
@ -127,10 +132,14 @@ impl Nac3 {
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?; let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
match &base.node { match &base.node {
ast::ExprKind::Name { id, .. } => { ast::ExprKind::Name { id, .. } => {
if *id == "Exception".into() {
Ok(true)
} else {
let base_obj = module.getattr(py, id.to_string())?; let base_obj = module.getattr(py, id.to_string())?;
let base_id = id_fn.call1((base_obj,))?.extract()?; let base_id = id_fn.call1((base_obj,))?.extract()?;
Ok(registered_class_ids.contains(&base_id)) Ok(registered_class_ids.contains(&base_id))
} }
}
_ => Ok(true), _ => Ok(true),
} }
}) })
@ -143,7 +152,9 @@ impl Nac3 {
{ {
decorator_list.iter().any(|decorator| { decorator_list.iter().any(|decorator| {
if let ast::ExprKind::Name { id, .. } = decorator.node { if let ast::ExprKind::Name { id, .. } = decorator.node {
id.to_string() == "kernel" || id.to_string() == "portable" id.to_string() == "kernel"
|| id.to_string() == "portable"
|| id.to_string() == "rpc"
} else { } else {
false false
} }
@ -159,7 +170,7 @@ impl Nac3 {
} => decorator_list.iter().any(|decorator| { } => decorator_list.iter().any(|decorator| {
if let ast::ExprKind::Name { id, .. } = decorator.node { if let ast::ExprKind::Name { id, .. } = decorator.node {
let id = id.to_string(); let id = id.to_string();
id == "extern" || id == "portable" || id == "kernel" id == "extern" || id == "portable" || id == "kernel" || id == "rpc"
} else { } else {
false false
} }
@ -269,7 +280,7 @@ impl Nac3 {
ret: primitive.int64, ret: primitive.int64,
vars: HashMap::new(), vars: HashMap::new(),
}, },
Arc::new(GenCall::new(Box::new(move |ctx, _, _, _| { Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| {
Some(time_fns.emit_now_mu(ctx)) Some(time_fns.emit_now_mu(ctx))
}))), }))),
), ),
@ -284,8 +295,9 @@ impl Nac3 {
ret: primitive.none, ret: primitive.none,
vars: HashMap::new(), vars: HashMap::new(),
}, },
Arc::new(GenCall::new(Box::new(move |ctx, _, _, args| { Arc::new(GenCall::new(Box::new(move |ctx, _, _, args, generator| {
time_fns.emit_at_mu(ctx, args[0].1); let arg = args[0].1.clone().to_basic_value_enum(ctx, generator);
time_fns.emit_at_mu(ctx, arg);
None None
}))), }))),
), ),
@ -300,16 +312,20 @@ impl Nac3 {
ret: primitive.none, ret: primitive.none,
vars: HashMap::new(), vars: HashMap::new(),
}, },
Arc::new(GenCall::new(Box::new(move |ctx, _, _, args| { Arc::new(GenCall::new(Box::new(move |ctx, _, _, args, generator| {
time_fns.emit_delay_mu(ctx, args[0].1); let arg = args[0].1.clone().to_basic_value_enum(ctx, generator);
time_fns.emit_delay_mu(ctx, arg);
None None
}))), }))),
), ),
]; ];
let (_, builtins_def, builtins_ty) = TopLevelComposer::new(builtins.clone(), ComposerConfig { let (_, builtins_def, builtins_ty) = TopLevelComposer::new(
builtins.clone(),
ComposerConfig {
kernel_ann: Some("Kernel"), kernel_ann: Some("Kernel"),
kernel_invariant_ann: "KernelInvariant" kernel_invariant_ann: "KernelInvariant",
}); },
);
let builtins_mod = PyModule::import(py, "builtins").unwrap(); let builtins_mod = PyModule::import(py, "builtins").unwrap();
let id_fn = builtins_mod.getattr("id").unwrap(); let id_fn = builtins_mod.getattr("id").unwrap();
@ -385,6 +401,11 @@ impl Nac3 {
.unwrap() .unwrap()
.extract() .extract()
.unwrap(), .unwrap(),
exception: id_fn
.call1((builtins_mod.getattr("tuple").unwrap(),))
.unwrap()
.extract()
.unwrap(),
}; };
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap(); let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
@ -405,6 +426,7 @@ impl Nac3 {
top_levels: Default::default(), top_levels: Default::default(),
pyid_to_def: Default::default(), pyid_to_def: Default::default(),
working_directory, working_directory,
string_store: Default::default()
}) })
} }
@ -441,6 +463,7 @@ impl Nac3 {
method_name: &str, method_name: &str,
args: Vec<&PyAny>, args: Vec<&PyAny>,
filename: &str, filename: &str,
embedding_map: &PyAny,
py: Python, py: Python,
) -> PyResult<()> { ) -> PyResult<()> {
let (mut composer, _, _) = TopLevelComposer::new(self.builtins.clone(), ComposerConfig { let (mut composer, _, _) = TopLevelComposer::new(self.builtins.clone(), ComposerConfig {
@ -451,17 +474,26 @@ impl Nac3 {
let builtins = PyModule::import(py, "builtins")?; let builtins = PyModule::import(py, "builtins")?;
let typings = PyModule::import(py, "typing")?; let typings = PyModule::import(py, "typing")?;
let id_fn = builtins.getattr("id")?; let id_fn = builtins.getattr("id")?;
let store_obj = embedding_map.getattr("store_object").unwrap().to_object(py);
let store_str = embedding_map.getattr("store_str").unwrap().to_object(py);
let store_fun = embedding_map
.getattr("store_function")
.unwrap()
.to_object(py);
let helper = PythonHelper { let helper = PythonHelper {
id_fn: builtins.getattr("id").unwrap().to_object(py), id_fn: builtins.getattr("id").unwrap().to_object(py),
len_fn: builtins.getattr("len").unwrap().to_object(py), len_fn: builtins.getattr("len").unwrap().to_object(py),
type_fn: builtins.getattr("type").unwrap().to_object(py), type_fn: builtins.getattr("type").unwrap().to_object(py),
origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py), origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py),
args_ty_fn: typings.getattr("get_args").unwrap().to_object(py), args_ty_fn: typings.getattr("get_args").unwrap().to_object(py),
store_obj,
store_str
}; };
let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new(); let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new();
let pyid_to_type = Arc::new(RwLock::new(HashMap::<u64, Type>::new())); let pyid_to_type = Arc::new(RwLock::new(HashMap::<u64, Type>::new()));
let global_value_ids = Arc::new(RwLock::new(HashSet::<u64>::new())); let global_value_ids = Arc::new(RwLock::new(HashSet::<u64>::new()));
let mut rpc_ids = vec![];
for (stmt, path, module) in self.top_levels.iter() { for (stmt, path, module) in self.top_levels.iter() {
let py_module: &PyAny = module.extract(py)?; let py_module: &PyAny = module.extract(py)?;
let module_id: u64 = id_fn.call1((py_module,))?.extract()?; let module_id: u64 = id_fn.call1((py_module,))?.extract()?;
@ -492,6 +524,7 @@ impl Nac3 {
id_to_primitive: Default::default(), id_to_primitive: Default::default(),
field_to_val: Default::default(), field_to_val: Default::default(),
helper, helper,
string_store: self.string_store.clone(),
}))) })))
as Arc<dyn SymbolResolver + Send + Sync>; as Arc<dyn SymbolResolver + Send + Sync>;
let name_to_pyid = Rc::new(name_to_pyid); let name_to_pyid = Rc::new(name_to_pyid);
@ -502,7 +535,30 @@ impl Nac3 {
let (name, def_id, ty) = composer let (name, def_id, ty) = composer
.register_top_level(stmt.clone(), Some(resolver.clone()), path.clone()) .register_top_level(stmt.clone(), Some(resolver.clone()), path.clone())
.map_err(|e| exceptions::PyRuntimeError::new_err(format!("nac3 compilation failure: {}", e)))?; .map_err(|e| {
exceptions::PyRuntimeError::new_err(format!("nac3 compilation failure: {}", e))
})?;
match &stmt.node {
StmtKind::FunctionDef { decorator_list, .. } => {
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string()).unwrap())).unwrap();
rpc_ids.push((None, def_id));
}
}
StmtKind::ClassDef { name, body, .. } => {
let class_obj = module.getattr(py, name.to_string()).unwrap();
for stmt in body.iter() {
if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node {
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
rpc_ids.push((Some((class_obj.clone(), *name)), def_id));
}
}
}
}
_ => ()
}
let id = *name_to_pyid.get(&name).unwrap(); let id = *name_to_pyid.get(&name).unwrap();
self.pyid_to_def.write().insert(id, def_id); self.pyid_to_def.write().insert(id, def_id);
{ {
@ -552,6 +608,7 @@ impl Nac3 {
name_to_pyid, name_to_pyid,
module: module.to_object(py), module: module.to_object(py),
helper, helper,
string_store: self.string_store.clone(),
}))) as Arc<dyn SymbolResolver + Send + Sync>; }))) as Arc<dyn SymbolResolver + Send + Sync>;
let (_, def_id, _) = composer let (_, def_id, _) = composer
.register_top_level( .register_top_level(
@ -595,6 +652,45 @@ impl Nac3 {
} }
} }
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
{
let rpc_codegen = rpc_codegen_callback();
let defs = top_level.definitions.read();
for (class_data, id) in rpc_ids.iter() {
let mut def = defs[id.0].write();
match &mut *def {
TopLevelDef::Function {
codegen_callback, ..
} => {
*codegen_callback = Some(rpc_codegen.clone());
}
TopLevelDef::Class { methods, .. } => {
let (class_def, method_name) = class_data.as_ref().unwrap();
for (name, _, id) in methods.iter() {
if name != method_name {
continue;
}
if let TopLevelDef::Function {
codegen_callback, ..
} = &mut *defs[id.0].write()
{
*codegen_callback = Some(rpc_codegen.clone());
store_fun
.call1(
py,
(
id.0.into_py(py),
class_def.getattr(py, name.to_string()).unwrap(),
),
)
.unwrap();
}
}
}
}
}
}
let instance = { let instance = {
let defs = top_level.definitions.read(); let defs = top_level.definitions.read();
let mut definition = defs[def_id.0].write(); let mut definition = defs[def_id.0].write();
@ -634,15 +730,17 @@ impl Nac3 {
let buffer = buffer.as_slice().into(); let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
}))); })));
let size_t = if self.isa == Isa::Host { let size_t = if self.isa == Isa::Host { 64 } else { 32 };
64
} else {
32
};
let thread_names: Vec<String> = (0..4).map(|_| "main".to_string()).collect(); let thread_names: Vec<String> = (0..4).map(|_| "main".to_string()).collect();
let threads: Vec<_> = thread_names let threads: Vec<_> = thread_names
.iter() .iter()
.map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), size_t, self.time_fns))) .map(|s| {
Box::new(ArtiqCodeGenerator::new(
s.to_string(),
size_t,
self.time_fns,
))
})
.collect(); .collect();
py.allow_threads(|| { py.allow_threads(|| {
@ -759,11 +857,12 @@ impl Nac3 {
obj: &PyAny, obj: &PyAny,
method_name: &str, method_name: &str,
args: Vec<&PyAny>, args: Vec<&PyAny>,
embedding_map: &PyAny,
py: Python, py: Python,
) -> PyResult<PyObject> { ) -> PyResult<PyObject> {
let filename_path = self.working_directory.path().join("module.elf"); let filename_path = self.working_directory.path().join("module.elf");
let filename = filename_path.to_str().unwrap(); let filename = filename_path.to_str().unwrap();
self.compile_method_to_file(obj, method_name, args, filename, py)?; self.compile_method_to_file(obj, method_name, args, filename, embedding_map, py)?;
Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into()) Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into())
} }
} }

View File

@ -42,6 +42,7 @@ pub struct InnerResolver {
pub pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>, pub pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>,
pub primitive_ids: PrimitivePythonId, pub primitive_ids: PrimitivePythonId,
pub helper: PythonHelper, pub helper: PythonHelper,
pub string_store: Arc<RwLock<HashMap<String, i32>>>,
// module specific // module specific
pub name_to_pyid: HashMap<StrRef, u64>, pub name_to_pyid: HashMap<StrRef, u64>,
pub module: PyObject, pub module: PyObject,
@ -56,11 +57,14 @@ pub struct PythonHelper {
pub id_fn: PyObject, pub id_fn: PyObject,
pub origin_ty_fn: PyObject, pub origin_ty_fn: PyObject,
pub args_ty_fn: PyObject, pub args_ty_fn: PyObject,
pub store_obj: PyObject,
pub store_str: PyObject,
} }
struct PythonValue { struct PythonValue {
id: u64, id: u64,
value: PyObject, value: PyObject,
store_obj: PyObject,
resolver: Arc<InnerResolver>, resolver: Arc<InnerResolver>,
} }
@ -69,6 +73,36 @@ impl StaticValue for PythonValue {
self.id self.id
} }
fn get_const_obj<'ctx, 'a>(
&self,
ctx: &mut CodeGenContext<'ctx, 'a>,
_: &mut dyn CodeGenerator,
) -> BasicValueEnum<'ctx> {
ctx.module
.get_global(self.id.to_string().as_str())
.map(|val| val.as_pointer_value().into())
.unwrap_or_else(|| {
Python::with_gil(|py| -> PyResult<BasicValueEnum<'ctx>> {
let id: u32 = self.store_obj.call1(py, (self.value.clone(),))?.extract(py)?;
let struct_type = ctx.ctx.struct_type(&[ctx.ctx.i32_type().into()], false);
let global =
ctx.module
.add_global(struct_type, None, format!("{}_const", self.id).as_str());
global.set_constant(true);
global.set_initializer(&ctx.ctx.const_struct(
&[ctx.ctx.i32_type().const_int(id as u64, false).into()],
false,
));
let global2 =
ctx.module
.add_global(struct_type.ptr_type(AddressSpace::Generic), None, format!("{}_const2", self.id).as_str());
global2.set_initializer(&global.as_pointer_value());
Ok(global2.as_pointer_value().into())
})
.unwrap()
})
}
fn to_basic_value_enum<'ctx, 'a>( fn to_basic_value_enum<'ctx, 'a>(
&self, &self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
@ -140,6 +174,7 @@ impl StaticValue for PythonValue {
ValueEnum::Static(Arc::new(PythonValue { ValueEnum::Static(Arc::new(PythonValue {
id, id,
value: obj, value: obj,
store_obj: self.store_obj.clone(),
resolver: self.resolver.clone(), resolver: self.resolver.clone(),
})) }))
}) })
@ -208,7 +243,9 @@ impl InnerResolver {
Ok(Ok((primitives.bool, true))) Ok(Ok((primitives.bool, true)))
} else if ty_id == self.primitive_ids.float { } else if ty_id == self.primitive_ids.float {
Ok(Ok((primitives.float, true))) Ok(Ok((primitives.float, true)))
} else if ty_id == self.primitive_ids.list { } else if ty_id == self.primitive_ids.exception {
Ok(Ok((primitives.exception, true)))
}else if ty_id == self.primitive_ids.list {
// do not handle type var param and concrete check here // do not handle type var param and concrete check here
let var = unifier.get_fresh_var().0; let var = unifier.get_fresh_var().0;
let list = unifier.add_ty(TypeEnum::TList { ty: var }); let list = unifier.add_ty(TypeEnum::TList { ty: var });
@ -755,9 +792,7 @@ impl InnerResolver {
.get_llvm_type(generator, ty) .get_llvm_type(generator, ty)
.into_pointer_type() .into_pointer_type()
.get_element_type() .get_element_type()
.into_struct_type() .into_struct_type();
.as_basic_type_enum();
{ {
if self.global_value_ids.read().contains(&id) { if self.global_value_ids.read().contains(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
@ -783,7 +818,7 @@ impl InnerResolver {
.collect(); .collect();
let values = values?; let values = values?;
if let Some(values) = values { if let Some(values) = values {
let val = ctx.ctx.const_struct(&values, false); let val = ty.const_named_struct(&values);
let global = ctx let global = ctx
.module .module
.add_global(ty, Some(AddressSpace::Generic), &id_str); .add_global(ty, Some(AddressSpace::Generic), &id_str);
@ -948,6 +983,7 @@ impl SymbolResolver for Resolver {
ValueEnum::Static(Arc::new(PythonValue { ValueEnum::Static(Arc::new(PythonValue {
id, id,
value: v, value: v,
store_obj: self.0.helper.store_obj.clone(),
resolver: self.0.clone(), resolver: self.0.clone(),
})) }))
}) })
@ -971,4 +1007,17 @@ impl SymbolResolver for Resolver {
result result
}) })
} }
fn get_string_id(&self, s: &str) -> i32 {
let mut string_store = self.0.string_store.write();
if let Some(id) = string_store.get(s) {
*id
} else {
let id = Python::with_gil(|py| -> PyResult<i32> {
self.0.helper.store_str.call1(py, (s, ))?.extract(py)
}).unwrap();
string_store.insert(s.into(), id);
id
}
}
} }

View File

@ -32,10 +32,16 @@ pub enum SymbolValue {
pub trait StaticValue { pub trait StaticValue {
fn get_unique_identifier(&self) -> u64; fn get_unique_identifier(&self) -> u64;
fn get_const_obj<'ctx, 'a>(
&self,
ctx: &mut CodeGenContext<'ctx, 'a>,
generator: &mut dyn CodeGenerator,
) -> BasicValueEnum<'ctx>;
fn to_basic_value_enum<'ctx, 'a>( fn to_basic_value_enum<'ctx, 'a>(
&self, &self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
generator: &mut dyn CodeGenerator generator: &mut dyn CodeGenerator,
) -> BasicValueEnum<'ctx>; ) -> BasicValueEnum<'ctx>;
fn get_field<'ctx, 'a>( fn get_field<'ctx, 'a>(