nac3artiq: RPC support
This commit is contained in:
parent
e303248261
commit
bf52e294ee
57
nac3artiq/demo/embedding_map.py
Normal file
57
nac3artiq/demo/embedding_map.py
Normal file
@ -0,0 +1,57 @@
|
||||
class EmbeddingMap:
|
||||
def __init__(self):
|
||||
self.object_inverse_map = {}
|
||||
self.object_map = {}
|
||||
self.string_map = {}
|
||||
self.string_reverse_map = {}
|
||||
self.function_map = {}
|
||||
|
||||
# preallocate exception names
|
||||
self.preallocate_runtime_exception_names(["RuntimeError",
|
||||
"RTIOUnderflow",
|
||||
"RTIOOverflow",
|
||||
"RTIODestinationUnreachable",
|
||||
"DMAError",
|
||||
"I2CError",
|
||||
"CacheError",
|
||||
"SPIError",
|
||||
"0:ZeroDivisionError",
|
||||
"0:IndexError"])
|
||||
|
||||
def preallocate_runtime_exception_names(self, names):
|
||||
for i, name in enumerate(names):
|
||||
if ":" not in name:
|
||||
name = "0:artiq.coredevice.exceptions." + name
|
||||
exn_id = self.store_str(name)
|
||||
assert exn_id == i
|
||||
|
||||
def store_function(self, key, fun):
|
||||
self.function_map[key] = fun
|
||||
return key
|
||||
|
||||
def store_object(self, obj):
|
||||
obj_id = id(obj)
|
||||
if obj_id in self.object_inverse_map:
|
||||
return self.object_inverse_map[obj_id]
|
||||
key = len(self.object_map)
|
||||
self.object_map[key] = obj
|
||||
self.object_inverse_map[obj_id] = key
|
||||
return key
|
||||
|
||||
def store_str(self, s):
|
||||
if s in self.string_reverse_map:
|
||||
return self.string_reverse_map[s]
|
||||
key = len(self.string_map)
|
||||
self.string_map[key] = s
|
||||
self.string_reverse_map[s] = key
|
||||
return key
|
||||
|
||||
def retrieve_function(self, key):
|
||||
return self.function_map[key]
|
||||
|
||||
def retrieve_object(self, key):
|
||||
return self.object_map[key]
|
||||
|
||||
def retrieve_str(self, key):
|
||||
return self.string_map[key]
|
||||
|
@ -6,13 +6,14 @@ from typing import Generic, TypeVar
|
||||
from math import floor, ceil
|
||||
|
||||
import nac3artiq
|
||||
from embedding_map import EmbeddingMap
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Kernel", "KernelInvariant", "virtual",
|
||||
"round64", "floor64", "ceil64",
|
||||
"extern", "kernel", "portable", "nac3",
|
||||
"ms", "us", "ns",
|
||||
"rpc", "ms", "us", "ns",
|
||||
"print_int32", "print_int64",
|
||||
"Core", "TTLOut",
|
||||
"parallel", "sequential"
|
||||
@ -65,6 +66,10 @@ def extern(function):
|
||||
register_function(function)
|
||||
return function
|
||||
|
||||
def rpc(function):
|
||||
"""Decorates a function declaration defined by the core device runtime."""
|
||||
register_function(function)
|
||||
return function
|
||||
|
||||
def kernel(function_or_method):
|
||||
"""Decorates a function or method to be executed on the core device."""
|
||||
@ -146,6 +151,9 @@ class Core:
|
||||
|
||||
def run(self, method, *args, **kwargs):
|
||||
global allow_registration
|
||||
|
||||
embedding = EmbeddingMap()
|
||||
|
||||
if allow_registration:
|
||||
compiler.analyze(registered_functions, registered_classes)
|
||||
allow_registration = False
|
||||
@ -157,7 +165,7 @@ class Core:
|
||||
obj = method
|
||||
name = ""
|
||||
|
||||
compiler.compile_method_to_file(obj, name, args, "module.elf")
|
||||
compiler.compile_method_to_file(obj, name, args, "module.elf", embedding)
|
||||
|
||||
@kernel
|
||||
def reset(self):
|
||||
|
@ -1,16 +1,30 @@
|
||||
use nac3core::{
|
||||
codegen::{expr::gen_call, stmt::gen_with, CodeGenContext, CodeGenerator},
|
||||
codegen::{
|
||||
expr::gen_call,
|
||||
stmt::{gen_block, gen_with},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::DefinitionId,
|
||||
toplevel::{DefinitionId, GenCall},
|
||||
typecheck::typedef::{FunSignature, Type},
|
||||
};
|
||||
|
||||
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
|
||||
|
||||
use inkwell::{context::Context, types::IntType, values::BasicValueEnum};
|
||||
use inkwell::{
|
||||
context::Context, module::Linkage, types::IntType, values::BasicValueEnum, AddressSpace,
|
||||
};
|
||||
|
||||
use crate::timeline::TimeFns;
|
||||
|
||||
use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
collections::HashMap,
|
||||
convert::TryInto,
|
||||
hash::{Hash, Hasher},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
pub struct ArtiqCodeGenerator<'a> {
|
||||
name: String,
|
||||
size_t: u32,
|
||||
@ -21,16 +35,13 @@ pub struct ArtiqCodeGenerator<'a> {
|
||||
}
|
||||
|
||||
impl<'a> ArtiqCodeGenerator<'a> {
|
||||
pub fn new(name: String, size_t: u32, timeline: &'a (dyn TimeFns + Sync)) -> ArtiqCodeGenerator<'a> {
|
||||
pub fn new(
|
||||
name: String,
|
||||
size_t: u32,
|
||||
timeline: &'a (dyn TimeFns + Sync),
|
||||
) -> ArtiqCodeGenerator<'a> {
|
||||
assert!(size_t == 32 || size_t == 64);
|
||||
ArtiqCodeGenerator {
|
||||
name,
|
||||
size_t,
|
||||
name_counter: 0,
|
||||
start: None,
|
||||
end: None,
|
||||
timeline,
|
||||
}
|
||||
ArtiqCodeGenerator { name, size_t, name_counter: 0, start: None, end: None, timeline }
|
||||
}
|
||||
}
|
||||
|
||||
@ -86,7 +97,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
stmt: &Stmt<Option<Type>>,
|
||||
) -> bool {
|
||||
) {
|
||||
if let StmtKind::With { items, body, .. } = &stmt.node {
|
||||
if items.len() == 1 && items[0].optional_vars.is_none() {
|
||||
let item = &items[0];
|
||||
@ -108,9 +119,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
|
||||
let old_start = self.start.take();
|
||||
let old_end = self.end.take();
|
||||
let now = if let Some(old_start) = &old_start {
|
||||
self.gen_expr(ctx, old_start)
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, self)
|
||||
self.gen_expr(ctx, old_start).unwrap().to_basic_value_enum(ctx, self)
|
||||
} else {
|
||||
self.timeline.emit_now_mu(ctx)
|
||||
};
|
||||
@ -126,10 +135,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
|
||||
let start_expr = Located {
|
||||
// location does not matter at this point
|
||||
location: stmt.location,
|
||||
node: ExprKind::Name {
|
||||
id: start,
|
||||
ctx: name_ctx.clone(),
|
||||
},
|
||||
node: ExprKind::Name { id: start, ctx: name_ctx.clone() },
|
||||
custom: Some(ctx.primitives.int64),
|
||||
};
|
||||
let start = self.gen_store_target(ctx, &start_expr);
|
||||
@ -140,40 +146,41 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
|
||||
let end_expr = Located {
|
||||
// location does not matter at this point
|
||||
location: stmt.location,
|
||||
node: ExprKind::Name {
|
||||
id: end,
|
||||
ctx: name_ctx.clone(),
|
||||
},
|
||||
node: ExprKind::Name { id: end, ctx: name_ctx.clone() },
|
||||
custom: Some(ctx.primitives.int64),
|
||||
};
|
||||
let end = self.gen_store_target(ctx, &end_expr);
|
||||
ctx.builder.build_store(end, now);
|
||||
self.end = Some(end_expr);
|
||||
self.name_counter += 1;
|
||||
let mut exited = false;
|
||||
for stmt in body.iter() {
|
||||
if self.gen_stmt(ctx, stmt) {
|
||||
exited = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
gen_block(self, ctx, body.iter());
|
||||
let current = ctx.builder.get_insert_block().unwrap();
|
||||
// if the current block is terminated, move before the terminator
|
||||
// we want to set the timeline before reaching the terminator
|
||||
// TODO: This may be unsound if there are multiple exit paths in the
|
||||
// block... e.g.
|
||||
// if ...:
|
||||
// return
|
||||
// Perhaps we can fix this by using actual with block?
|
||||
let reset_position = if let Some(terminator) = current.get_terminator() {
|
||||
ctx.builder.position_before(&terminator);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
};
|
||||
// set duration
|
||||
let end_expr = self.end.take().unwrap();
|
||||
let end_val = self
|
||||
.gen_expr(ctx, &end_expr)
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, self);
|
||||
let end_val =
|
||||
self.gen_expr(ctx, &end_expr).unwrap().to_basic_value_enum(ctx, self);
|
||||
|
||||
// inside an sequential block
|
||||
// inside a sequential block
|
||||
if old_start.is_none() {
|
||||
self.timeline.emit_at_mu(ctx, end_val);
|
||||
}
|
||||
// inside a parallel block, should update the outer max now_mu
|
||||
if let Some(old_end) = &old_end {
|
||||
let outer_end_val = self
|
||||
.gen_expr(ctx, old_end)
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, self);
|
||||
let outer_end_val =
|
||||
self.gen_expr(ctx, old_end).unwrap().to_basic_value_enum(ctx, self);
|
||||
let smax =
|
||||
ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| {
|
||||
let i64 = ctx.ctx.i64_type();
|
||||
@ -194,24 +201,294 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
|
||||
}
|
||||
self.start = old_start;
|
||||
self.end = old_end;
|
||||
return exited;
|
||||
if reset_position {
|
||||
ctx.builder.position_at_end(current);
|
||||
}
|
||||
return;
|
||||
} else if id == &"sequential".into() {
|
||||
let start = self.start.take();
|
||||
for stmt in body.iter() {
|
||||
if self.gen_stmt(ctx, stmt) {
|
||||
self.start = start;
|
||||
return true;
|
||||
self.gen_stmt(ctx, stmt);
|
||||
if ctx.is_terminated() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
self.start = start;
|
||||
return false;
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// not parallel/sequential
|
||||
gen_with(self, ctx, stmt)
|
||||
gen_with(self, ctx, stmt);
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_rpc_tag<'ctx, 'a>(ctx: &mut CodeGenContext<'ctx, 'a>, ty: Type, buffer: &mut Vec<u8>) {
|
||||
use nac3core::typecheck::typedef::TypeEnum::*;
|
||||
|
||||
let int32 = ctx.primitives.int32;
|
||||
let int64 = ctx.primitives.int64;
|
||||
let float = ctx.primitives.float;
|
||||
let bool = ctx.primitives.bool;
|
||||
let str = ctx.primitives.str;
|
||||
let none = ctx.primitives.none;
|
||||
|
||||
if ctx.unifier.unioned(ty, int32) {
|
||||
buffer.push(b'i');
|
||||
} else if ctx.unifier.unioned(ty, int64) {
|
||||
buffer.push(b'I');
|
||||
} else if ctx.unifier.unioned(ty, float) {
|
||||
buffer.push(b'f');
|
||||
} else if ctx.unifier.unioned(ty, bool) {
|
||||
buffer.push(b'b');
|
||||
} else if ctx.unifier.unioned(ty, str) {
|
||||
buffer.push(b's');
|
||||
} else if ctx.unifier.unioned(ty, none) {
|
||||
buffer.push(b'n');
|
||||
} else {
|
||||
let ty = ctx.unifier.get_ty(ty);
|
||||
match &*ty {
|
||||
TTuple { ty } => {
|
||||
buffer.push(b't');
|
||||
buffer.push(ty.len() as u8);
|
||||
for ty in ty {
|
||||
gen_rpc_tag(ctx, *ty, buffer);
|
||||
}
|
||||
}
|
||||
TList { ty } => {
|
||||
buffer.push(b'l');
|
||||
gen_rpc_tag(ctx, *ty, buffer);
|
||||
}
|
||||
// we should return an error, this will be fixed after improving error message
|
||||
// as this requires returning an error during codegen
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn rpc_codegen_callback_fn<'ctx, 'a>(
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Option<BasicValueEnum<'ctx>> {
|
||||
let ptr_type = ctx.ctx.i8_type().ptr_type(inkwell::AddressSpace::Generic);
|
||||
let size_type = generator.get_size_type(ctx.ctx);
|
||||
let int8 = ctx.ctx.i8_type();
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
|
||||
|
||||
let service_id = int32.const_int(fun.1.0 as u64, false);
|
||||
// -- setup rpc tags
|
||||
let mut tag = Vec::new();
|
||||
if obj.is_some() {
|
||||
tag.push(b'O');
|
||||
}
|
||||
for arg in fun.0.args.iter() {
|
||||
gen_rpc_tag(ctx, arg.ty, &mut tag);
|
||||
}
|
||||
tag.push(b':');
|
||||
gen_rpc_tag(ctx, fun.0.ret, &mut tag);
|
||||
|
||||
let mut hasher = DefaultHasher::new();
|
||||
tag.hash(&mut hasher);
|
||||
let hash = format!("{}", hasher.finish());
|
||||
|
||||
let tag_ptr = ctx
|
||||
.module
|
||||
.get_global(hash.as_str())
|
||||
.unwrap_or_else(|| {
|
||||
let tag_arr_ptr = ctx.module.add_global(
|
||||
int8.array_type(tag.len() as u32),
|
||||
None,
|
||||
format!("tagptr{}", fun.1 .0).as_str(),
|
||||
);
|
||||
tag_arr_ptr.set_initializer(&int8.const_array(
|
||||
&tag.iter().map(|v| int8.const_int(*v as u64, false)).collect::<Vec<_>>(),
|
||||
));
|
||||
tag_arr_ptr.set_linkage(Linkage::Private);
|
||||
let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash);
|
||||
tag_ptr.set_linkage(Linkage::Private);
|
||||
tag_ptr.set_initializer(&ctx.ctx.const_struct(
|
||||
&[
|
||||
tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(),
|
||||
size_type.const_int(tag.len() as u64, false).into(),
|
||||
],
|
||||
false,
|
||||
));
|
||||
tag_ptr
|
||||
})
|
||||
.as_pointer_value();
|
||||
|
||||
let arg_length = args.len() + if obj.is_some() { 1 } else { 0 };
|
||||
|
||||
let stacksave = ctx.module.get_function("llvm.stacksave").unwrap_or_else(|| {
|
||||
ctx.module.add_function("llvm.stacksave", ptr_type.fn_type(&[], false), None)
|
||||
});
|
||||
let stackrestore = ctx.module.get_function("llvm.stackrestore").unwrap_or_else(|| {
|
||||
ctx.module.add_function(
|
||||
"llvm.stackrestore",
|
||||
ctx.ctx.void_type().fn_type(&[ptr_type.into()], false),
|
||||
None,
|
||||
)
|
||||
});
|
||||
|
||||
let stackptr = ctx.builder.build_call(stacksave, &[], "rpc.stack");
|
||||
let args_ptr = ctx.builder.build_array_alloca(
|
||||
ptr_type,
|
||||
ctx.ctx.i32_type().const_int(arg_length as u64, false),
|
||||
"argptr",
|
||||
);
|
||||
|
||||
// -- rpc args handling
|
||||
let mut keys = fun.0.args.clone();
|
||||
let mut mapping = HashMap::new();
|
||||
for (key, value) in args.into_iter() {
|
||||
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
|
||||
}
|
||||
// default value handling
|
||||
for k in keys.into_iter() {
|
||||
mapping.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap()).into());
|
||||
}
|
||||
// reorder the parameters
|
||||
let mut real_params = fun
|
||||
.0
|
||||
.args
|
||||
.iter()
|
||||
.map(|arg| mapping.remove(&arg.name).unwrap().to_basic_value_enum(ctx, generator))
|
||||
.collect::<Vec<_>>();
|
||||
if let Some(obj) = obj {
|
||||
if let ValueEnum::Static(obj) = obj.1 {
|
||||
real_params.insert(0, obj.get_const_obj(ctx, generator));
|
||||
} else {
|
||||
// should be an error here...
|
||||
panic!("only host object is allowed");
|
||||
}
|
||||
}
|
||||
|
||||
for (i, arg) in real_params.iter().enumerate() {
|
||||
let arg_slot = if arg.is_pointer_value() {
|
||||
arg.into_pointer_value()
|
||||
} else {
|
||||
let arg_slot = ctx.builder.build_alloca(arg.get_type(), &format!("rpc.arg{}", i));
|
||||
ctx.builder.build_store(arg_slot, *arg);
|
||||
arg_slot
|
||||
};
|
||||
let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg");
|
||||
let arg_ptr = unsafe {
|
||||
ctx.builder.build_gep(
|
||||
args_ptr,
|
||||
&[int32.const_int(i as u64, false)],
|
||||
&format!("rpc.arg{}", i),
|
||||
)
|
||||
};
|
||||
ctx.builder.build_store(arg_ptr, arg_slot);
|
||||
}
|
||||
|
||||
// call
|
||||
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
|
||||
ctx.module.add_function(
|
||||
"rpc_send",
|
||||
ctx.ctx.void_type().fn_type(
|
||||
&[
|
||||
int32.into(),
|
||||
tag_ptr_type.ptr_type(AddressSpace::Generic).into(),
|
||||
ptr_type.ptr_type(AddressSpace::Generic).into(),
|
||||
],
|
||||
false,
|
||||
),
|
||||
None,
|
||||
)
|
||||
});
|
||||
ctx.builder.build_call(
|
||||
rpc_send,
|
||||
&[service_id.into(), tag_ptr.into(), args_ptr.into()],
|
||||
"rpc.send",
|
||||
);
|
||||
|
||||
// reclaim stack space used by arguments
|
||||
ctx.builder.build_call(
|
||||
stackrestore,
|
||||
&[stackptr.try_as_basic_value().unwrap_left().into()],
|
||||
"rpc.stackrestore",
|
||||
);
|
||||
|
||||
// -- receive value:
|
||||
// T result = {
|
||||
// void *ret_ptr = alloca(sizeof(T));
|
||||
// void *ptr = ret_ptr;
|
||||
// loop: int size = rpc_recv(ptr);
|
||||
// // Non-zero: Provide `size` bytes of extra storage for variable-length data.
|
||||
// if(size) { ptr = alloca(size); goto loop; }
|
||||
// else *(T*)ret_ptr
|
||||
// }
|
||||
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
|
||||
ctx.module.add_function("rpc_recv", int32.fn_type(&[ptr_type.into()], false), None)
|
||||
});
|
||||
|
||||
if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) {
|
||||
ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv");
|
||||
return None
|
||||
}
|
||||
|
||||
let prehead_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let current_function = prehead_bb.get_parent().unwrap();
|
||||
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
|
||||
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
|
||||
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
|
||||
|
||||
let mut ret_ty = ctx.get_llvm_type(generator, fun.0.ret);
|
||||
let need_load = !ret_ty.is_pointer_type();
|
||||
if ret_ty.is_pointer_type() {
|
||||
ret_ty = ret_ty.into_pointer_type().get_element_type().try_into().unwrap();
|
||||
}
|
||||
let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot");
|
||||
let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr");
|
||||
ctx.builder.build_unconditional_branch(head_bb);
|
||||
ctx.builder.position_at_end(head_bb);
|
||||
|
||||
let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr");
|
||||
phi.add_incoming(&[(&slotgen, prehead_bb)]);
|
||||
let alloc_size = ctx
|
||||
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
let is_done = ctx.builder.build_int_compare(
|
||||
inkwell::IntPredicate::EQ,
|
||||
int32.const_zero(),
|
||||
alloc_size,
|
||||
"rpc.done",
|
||||
);
|
||||
|
||||
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb);
|
||||
ctx.builder.position_at_end(alloc_bb);
|
||||
|
||||
let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc");
|
||||
let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr");
|
||||
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
||||
ctx.builder.build_unconditional_branch(head_bb);
|
||||
|
||||
ctx.builder.position_at_end(tail_bb);
|
||||
|
||||
if need_load {
|
||||
let result = ctx.builder.build_load(slot, "rpc.result");
|
||||
ctx.builder.build_call(
|
||||
stackrestore,
|
||||
&[stackptr.try_as_basic_value().unwrap_left().into()],
|
||||
"rpc.stackrestore",
|
||||
);
|
||||
Some(result)
|
||||
} else {
|
||||
Some(slot.into())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rpc_codegen_callback() -> Arc<GenCall> {
|
||||
Arc::new(GenCall::new(Box::new(|ctx, obj, fun, args, generator| {
|
||||
rpc_codegen_callback_fn(ctx, obj, fun, args, generator)
|
||||
})))
|
||||
}
|
||||
|
@ -12,7 +12,7 @@ use inkwell::{
|
||||
};
|
||||
use nac3core::typecheck::typedef::{Unifier, TypeEnum};
|
||||
use nac3parser::{
|
||||
ast::{self, Stmt, StrRef},
|
||||
ast::{self, ExprKind, Stmt, StmtKind, StrRef},
|
||||
parser::{self, parse_program},
|
||||
};
|
||||
use pyo3::prelude::*;
|
||||
@ -24,7 +24,10 @@ use nac3core::{
|
||||
codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry},
|
||||
codegen::irrt::load_irrt,
|
||||
symbol_resolver::SymbolResolver,
|
||||
toplevel::{composer::{TopLevelComposer, ComposerConfig}, DefinitionId, GenCall, TopLevelDef},
|
||||
toplevel::{
|
||||
composer::{ComposerConfig, TopLevelComposer},
|
||||
DefinitionId, GenCall, TopLevelDef,
|
||||
},
|
||||
typecheck::typedef::{FunSignature, FuncArg},
|
||||
typecheck::{type_inferencer::PrimitiveStore, typedef::Type},
|
||||
};
|
||||
@ -32,7 +35,7 @@ use nac3core::{
|
||||
use tempfile::{self, TempDir};
|
||||
|
||||
use crate::{
|
||||
codegen::ArtiqCodeGenerator,
|
||||
codegen::{rpc_codegen_callback, ArtiqCodeGenerator},
|
||||
symbol_resolver::{InnerResolver, PythonHelper, Resolver},
|
||||
};
|
||||
|
||||
@ -61,6 +64,7 @@ pub struct PrimitivePythonId {
|
||||
tuple: u64,
|
||||
typevar: u64,
|
||||
none: u64,
|
||||
exception: u64,
|
||||
generic_alias: (u64, u64),
|
||||
virtual_id: u64,
|
||||
}
|
||||
@ -81,6 +85,7 @@ struct Nac3 {
|
||||
primitive_ids: PrimitivePythonId,
|
||||
working_directory: TempDir,
|
||||
top_levels: Vec<TopLevelComponent>,
|
||||
string_store: Arc<RwLock<HashMap<String, i32>>>,
|
||||
}
|
||||
|
||||
impl Nac3 {
|
||||
@ -127,9 +132,13 @@ impl Nac3 {
|
||||
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
|
||||
match &base.node {
|
||||
ast::ExprKind::Name { id, .. } => {
|
||||
let base_obj = module.getattr(py, id.to_string())?;
|
||||
let base_id = id_fn.call1((base_obj,))?.extract()?;
|
||||
Ok(registered_class_ids.contains(&base_id))
|
||||
if *id == "Exception".into() {
|
||||
Ok(true)
|
||||
} else {
|
||||
let base_obj = module.getattr(py, id.to_string())?;
|
||||
let base_id = id_fn.call1((base_obj,))?.extract()?;
|
||||
Ok(registered_class_ids.contains(&base_id))
|
||||
}
|
||||
}
|
||||
_ => Ok(true),
|
||||
}
|
||||
@ -143,7 +152,9 @@ impl Nac3 {
|
||||
{
|
||||
decorator_list.iter().any(|decorator| {
|
||||
if let ast::ExprKind::Name { id, .. } = decorator.node {
|
||||
id.to_string() == "kernel" || id.to_string() == "portable"
|
||||
id.to_string() == "kernel"
|
||||
|| id.to_string() == "portable"
|
||||
|| id.to_string() == "rpc"
|
||||
} else {
|
||||
false
|
||||
}
|
||||
@ -159,7 +170,7 @@ impl Nac3 {
|
||||
} => decorator_list.iter().any(|decorator| {
|
||||
if let ast::ExprKind::Name { id, .. } = decorator.node {
|
||||
let id = id.to_string();
|
||||
id == "extern" || id == "portable" || id == "kernel"
|
||||
id == "extern" || id == "portable" || id == "kernel" || id == "rpc"
|
||||
} else {
|
||||
false
|
||||
}
|
||||
@ -188,7 +199,7 @@ impl Nac3 {
|
||||
Ok(ty) => ty,
|
||||
Err(e) => return Some(format!("type error inside object launching kernel: {}", e))
|
||||
};
|
||||
|
||||
|
||||
let fun_ty = if method_name.is_empty() {
|
||||
base_ty
|
||||
} else if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(base_ty) {
|
||||
@ -201,7 +212,7 @@ impl Nac3 {
|
||||
} else {
|
||||
return Some("cannot launch kernel by calling a non-callable".into())
|
||||
};
|
||||
|
||||
|
||||
if let TypeEnum::TFunc(sig) = &*unifier.get_ty(fun_ty) {
|
||||
let FunSignature { args, .. } = &*sig.borrow();
|
||||
if arg_names.len() > args.len() {
|
||||
@ -269,7 +280,7 @@ impl Nac3 {
|
||||
ret: primitive.int64,
|
||||
vars: HashMap::new(),
|
||||
},
|
||||
Arc::new(GenCall::new(Box::new(move |ctx, _, _, _| {
|
||||
Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| {
|
||||
Some(time_fns.emit_now_mu(ctx))
|
||||
}))),
|
||||
),
|
||||
@ -284,8 +295,9 @@ impl Nac3 {
|
||||
ret: primitive.none,
|
||||
vars: HashMap::new(),
|
||||
},
|
||||
Arc::new(GenCall::new(Box::new(move |ctx, _, _, args| {
|
||||
time_fns.emit_at_mu(ctx, args[0].1);
|
||||
Arc::new(GenCall::new(Box::new(move |ctx, _, _, args, generator| {
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator);
|
||||
time_fns.emit_at_mu(ctx, arg);
|
||||
None
|
||||
}))),
|
||||
),
|
||||
@ -300,16 +312,20 @@ impl Nac3 {
|
||||
ret: primitive.none,
|
||||
vars: HashMap::new(),
|
||||
},
|
||||
Arc::new(GenCall::new(Box::new(move |ctx, _, _, args| {
|
||||
time_fns.emit_delay_mu(ctx, args[0].1);
|
||||
Arc::new(GenCall::new(Box::new(move |ctx, _, _, args, generator| {
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator);
|
||||
time_fns.emit_delay_mu(ctx, arg);
|
||||
None
|
||||
}))),
|
||||
),
|
||||
];
|
||||
let (_, builtins_def, builtins_ty) = TopLevelComposer::new(builtins.clone(), ComposerConfig {
|
||||
kernel_ann: Some("Kernel"),
|
||||
kernel_invariant_ann: "KernelInvariant"
|
||||
});
|
||||
let (_, builtins_def, builtins_ty) = TopLevelComposer::new(
|
||||
builtins.clone(),
|
||||
ComposerConfig {
|
||||
kernel_ann: Some("Kernel"),
|
||||
kernel_invariant_ann: "KernelInvariant",
|
||||
},
|
||||
);
|
||||
|
||||
let builtins_mod = PyModule::import(py, "builtins").unwrap();
|
||||
let id_fn = builtins_mod.getattr("id").unwrap();
|
||||
@ -385,6 +401,11 @@ impl Nac3 {
|
||||
.unwrap()
|
||||
.extract()
|
||||
.unwrap(),
|
||||
exception: id_fn
|
||||
.call1((builtins_mod.getattr("tuple").unwrap(),))
|
||||
.unwrap()
|
||||
.extract()
|
||||
.unwrap(),
|
||||
};
|
||||
|
||||
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
|
||||
@ -405,6 +426,7 @@ impl Nac3 {
|
||||
top_levels: Default::default(),
|
||||
pyid_to_def: Default::default(),
|
||||
working_directory,
|
||||
string_store: Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
@ -441,6 +463,7 @@ impl Nac3 {
|
||||
method_name: &str,
|
||||
args: Vec<&PyAny>,
|
||||
filename: &str,
|
||||
embedding_map: &PyAny,
|
||||
py: Python,
|
||||
) -> PyResult<()> {
|
||||
let (mut composer, _, _) = TopLevelComposer::new(self.builtins.clone(), ComposerConfig {
|
||||
@ -451,17 +474,26 @@ impl Nac3 {
|
||||
let builtins = PyModule::import(py, "builtins")?;
|
||||
let typings = PyModule::import(py, "typing")?;
|
||||
let id_fn = builtins.getattr("id")?;
|
||||
let store_obj = embedding_map.getattr("store_object").unwrap().to_object(py);
|
||||
let store_str = embedding_map.getattr("store_str").unwrap().to_object(py);
|
||||
let store_fun = embedding_map
|
||||
.getattr("store_function")
|
||||
.unwrap()
|
||||
.to_object(py);
|
||||
let helper = PythonHelper {
|
||||
id_fn: builtins.getattr("id").unwrap().to_object(py),
|
||||
len_fn: builtins.getattr("len").unwrap().to_object(py),
|
||||
type_fn: builtins.getattr("type").unwrap().to_object(py),
|
||||
origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py),
|
||||
args_ty_fn: typings.getattr("get_args").unwrap().to_object(py),
|
||||
store_obj,
|
||||
store_str
|
||||
};
|
||||
let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new();
|
||||
|
||||
let pyid_to_type = Arc::new(RwLock::new(HashMap::<u64, Type>::new()));
|
||||
let global_value_ids = Arc::new(RwLock::new(HashSet::<u64>::new()));
|
||||
let mut rpc_ids = vec![];
|
||||
for (stmt, path, module) in self.top_levels.iter() {
|
||||
let py_module: &PyAny = module.extract(py)?;
|
||||
let module_id: u64 = id_fn.call1((py_module,))?.extract()?;
|
||||
@ -492,6 +524,7 @@ impl Nac3 {
|
||||
id_to_primitive: Default::default(),
|
||||
field_to_val: Default::default(),
|
||||
helper,
|
||||
string_store: self.string_store.clone(),
|
||||
})))
|
||||
as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
let name_to_pyid = Rc::new(name_to_pyid);
|
||||
@ -502,7 +535,30 @@ impl Nac3 {
|
||||
|
||||
let (name, def_id, ty) = composer
|
||||
.register_top_level(stmt.clone(), Some(resolver.clone()), path.clone())
|
||||
.map_err(|e| exceptions::PyRuntimeError::new_err(format!("nac3 compilation failure: {}", e)))?;
|
||||
.map_err(|e| {
|
||||
exceptions::PyRuntimeError::new_err(format!("nac3 compilation failure: {}", e))
|
||||
})?;
|
||||
|
||||
match &stmt.node {
|
||||
StmtKind::FunctionDef { decorator_list, .. } => {
|
||||
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
|
||||
store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string()).unwrap())).unwrap();
|
||||
rpc_ids.push((None, def_id));
|
||||
}
|
||||
}
|
||||
StmtKind::ClassDef { name, body, .. } => {
|
||||
let class_obj = module.getattr(py, name.to_string()).unwrap();
|
||||
for stmt in body.iter() {
|
||||
if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node {
|
||||
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
|
||||
rpc_ids.push((Some((class_obj.clone(), *name)), def_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => ()
|
||||
}
|
||||
|
||||
let id = *name_to_pyid.get(&name).unwrap();
|
||||
self.pyid_to_def.write().insert(id, def_id);
|
||||
{
|
||||
@ -552,6 +608,7 @@ impl Nac3 {
|
||||
name_to_pyid,
|
||||
module: module.to_object(py),
|
||||
helper,
|
||||
string_store: self.string_store.clone(),
|
||||
}))) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
let (_, def_id, _) = composer
|
||||
.register_top_level(
|
||||
@ -595,6 +652,45 @@ impl Nac3 {
|
||||
}
|
||||
}
|
||||
let top_level = Arc::new(composer.make_top_level_context());
|
||||
|
||||
{
|
||||
let rpc_codegen = rpc_codegen_callback();
|
||||
let defs = top_level.definitions.read();
|
||||
for (class_data, id) in rpc_ids.iter() {
|
||||
let mut def = defs[id.0].write();
|
||||
match &mut *def {
|
||||
TopLevelDef::Function {
|
||||
codegen_callback, ..
|
||||
} => {
|
||||
*codegen_callback = Some(rpc_codegen.clone());
|
||||
}
|
||||
TopLevelDef::Class { methods, .. } => {
|
||||
let (class_def, method_name) = class_data.as_ref().unwrap();
|
||||
for (name, _, id) in methods.iter() {
|
||||
if name != method_name {
|
||||
continue;
|
||||
}
|
||||
if let TopLevelDef::Function {
|
||||
codegen_callback, ..
|
||||
} = &mut *defs[id.0].write()
|
||||
{
|
||||
*codegen_callback = Some(rpc_codegen.clone());
|
||||
store_fun
|
||||
.call1(
|
||||
py,
|
||||
(
|
||||
id.0.into_py(py),
|
||||
class_def.getattr(py, name.to_string()).unwrap(),
|
||||
),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let instance = {
|
||||
let defs = top_level.definitions.read();
|
||||
let mut definition = defs[def_id.0].write();
|
||||
@ -634,15 +730,17 @@ impl Nac3 {
|
||||
let buffer = buffer.as_slice().into();
|
||||
membuffer.lock().push(buffer);
|
||||
})));
|
||||
let size_t = if self.isa == Isa::Host {
|
||||
64
|
||||
} else {
|
||||
32
|
||||
};
|
||||
let size_t = if self.isa == Isa::Host { 64 } else { 32 };
|
||||
let thread_names: Vec<String> = (0..4).map(|_| "main".to_string()).collect();
|
||||
let threads: Vec<_> = thread_names
|
||||
.iter()
|
||||
.map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), size_t, self.time_fns)))
|
||||
.map(|s| {
|
||||
Box::new(ArtiqCodeGenerator::new(
|
||||
s.to_string(),
|
||||
size_t,
|
||||
self.time_fns,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
py.allow_threads(|| {
|
||||
@ -759,11 +857,12 @@ impl Nac3 {
|
||||
obj: &PyAny,
|
||||
method_name: &str,
|
||||
args: Vec<&PyAny>,
|
||||
embedding_map: &PyAny,
|
||||
py: Python,
|
||||
) -> PyResult<PyObject> {
|
||||
let filename_path = self.working_directory.path().join("module.elf");
|
||||
let filename = filename_path.to_str().unwrap();
|
||||
self.compile_method_to_file(obj, method_name, args, filename, py)?;
|
||||
self.compile_method_to_file(obj, method_name, args, filename, embedding_map, py)?;
|
||||
Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into())
|
||||
}
|
||||
}
|
||||
|
@ -42,6 +42,7 @@ pub struct InnerResolver {
|
||||
pub pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>,
|
||||
pub primitive_ids: PrimitivePythonId,
|
||||
pub helper: PythonHelper,
|
||||
pub string_store: Arc<RwLock<HashMap<String, i32>>>,
|
||||
// module specific
|
||||
pub name_to_pyid: HashMap<StrRef, u64>,
|
||||
pub module: PyObject,
|
||||
@ -56,11 +57,14 @@ pub struct PythonHelper {
|
||||
pub id_fn: PyObject,
|
||||
pub origin_ty_fn: PyObject,
|
||||
pub args_ty_fn: PyObject,
|
||||
pub store_obj: PyObject,
|
||||
pub store_str: PyObject,
|
||||
}
|
||||
|
||||
struct PythonValue {
|
||||
id: u64,
|
||||
value: PyObject,
|
||||
store_obj: PyObject,
|
||||
resolver: Arc<InnerResolver>,
|
||||
}
|
||||
|
||||
@ -69,6 +73,36 @@ impl StaticValue for PythonValue {
|
||||
self.id
|
||||
}
|
||||
|
||||
fn get_const_obj<'ctx, 'a>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
_: &mut dyn CodeGenerator,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
ctx.module
|
||||
.get_global(self.id.to_string().as_str())
|
||||
.map(|val| val.as_pointer_value().into())
|
||||
.unwrap_or_else(|| {
|
||||
Python::with_gil(|py| -> PyResult<BasicValueEnum<'ctx>> {
|
||||
let id: u32 = self.store_obj.call1(py, (self.value.clone(),))?.extract(py)?;
|
||||
let struct_type = ctx.ctx.struct_type(&[ctx.ctx.i32_type().into()], false);
|
||||
let global =
|
||||
ctx.module
|
||||
.add_global(struct_type, None, format!("{}_const", self.id).as_str());
|
||||
global.set_constant(true);
|
||||
global.set_initializer(&ctx.ctx.const_struct(
|
||||
&[ctx.ctx.i32_type().const_int(id as u64, false).into()],
|
||||
false,
|
||||
));
|
||||
let global2 =
|
||||
ctx.module
|
||||
.add_global(struct_type.ptr_type(AddressSpace::Generic), None, format!("{}_const2", self.id).as_str());
|
||||
global2.set_initializer(&global.as_pointer_value());
|
||||
Ok(global2.as_pointer_value().into())
|
||||
})
|
||||
.unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
fn to_basic_value_enum<'ctx, 'a>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
@ -140,6 +174,7 @@ impl StaticValue for PythonValue {
|
||||
ValueEnum::Static(Arc::new(PythonValue {
|
||||
id,
|
||||
value: obj,
|
||||
store_obj: self.store_obj.clone(),
|
||||
resolver: self.resolver.clone(),
|
||||
}))
|
||||
})
|
||||
@ -208,7 +243,9 @@ impl InnerResolver {
|
||||
Ok(Ok((primitives.bool, true)))
|
||||
} else if ty_id == self.primitive_ids.float {
|
||||
Ok(Ok((primitives.float, true)))
|
||||
} else if ty_id == self.primitive_ids.list {
|
||||
} else if ty_id == self.primitive_ids.exception {
|
||||
Ok(Ok((primitives.exception, true)))
|
||||
}else if ty_id == self.primitive_ids.list {
|
||||
// do not handle type var param and concrete check here
|
||||
let var = unifier.get_fresh_var().0;
|
||||
let list = unifier.add_ty(TypeEnum::TList { ty: var });
|
||||
@ -755,9 +792,7 @@ impl InnerResolver {
|
||||
.get_llvm_type(generator, ty)
|
||||
.into_pointer_type()
|
||||
.get_element_type()
|
||||
.into_struct_type()
|
||||
.as_basic_type_enum();
|
||||
|
||||
.into_struct_type();
|
||||
{
|
||||
if self.global_value_ids.read().contains(&id) {
|
||||
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
|
||||
@ -783,7 +818,7 @@ impl InnerResolver {
|
||||
.collect();
|
||||
let values = values?;
|
||||
if let Some(values) = values {
|
||||
let val = ctx.ctx.const_struct(&values, false);
|
||||
let val = ty.const_named_struct(&values);
|
||||
let global = ctx
|
||||
.module
|
||||
.add_global(ty, Some(AddressSpace::Generic), &id_str);
|
||||
@ -948,6 +983,7 @@ impl SymbolResolver for Resolver {
|
||||
ValueEnum::Static(Arc::new(PythonValue {
|
||||
id,
|
||||
value: v,
|
||||
store_obj: self.0.helper.store_obj.clone(),
|
||||
resolver: self.0.clone(),
|
||||
}))
|
||||
})
|
||||
@ -971,4 +1007,17 @@ impl SymbolResolver for Resolver {
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
fn get_string_id(&self, s: &str) -> i32 {
|
||||
let mut string_store = self.0.string_store.write();
|
||||
if let Some(id) = string_store.get(s) {
|
||||
*id
|
||||
} else {
|
||||
let id = Python::with_gil(|py| -> PyResult<i32> {
|
||||
self.0.helper.store_str.call1(py, (s, ))?.extract(py)
|
||||
}).unwrap();
|
||||
string_store.insert(s.into(), id);
|
||||
id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -32,10 +32,16 @@ pub enum SymbolValue {
|
||||
pub trait StaticValue {
|
||||
fn get_unique_identifier(&self) -> u64;
|
||||
|
||||
fn get_const_obj<'ctx, 'a>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> BasicValueEnum<'ctx>;
|
||||
|
||||
fn to_basic_value_enum<'ctx, 'a>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
generator: &mut dyn CodeGenerator
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> BasicValueEnum<'ctx>;
|
||||
|
||||
fn get_field<'ctx, 'a>(
|
||||
|
Loading…
Reference in New Issue
Block a user