From bf52e294ee8e902e8d8053f37ea346f058c28764 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Sat, 12 Feb 2022 21:17:37 +0800 Subject: [PATCH] nac3artiq: RPC support --- nac3artiq/demo/embedding_map.py | 57 +++++ nac3artiq/demo/min_artiq.py | 12 +- nac3artiq/src/codegen.rs | 369 +++++++++++++++++++++++++++---- nac3artiq/src/lib.rs | 153 ++++++++++--- nac3artiq/src/symbol_resolver.rs | 59 ++++- nac3core/src/symbol_resolver.rs | 8 +- 6 files changed, 577 insertions(+), 81 deletions(-) create mode 100644 nac3artiq/demo/embedding_map.py diff --git a/nac3artiq/demo/embedding_map.py b/nac3artiq/demo/embedding_map.py new file mode 100644 index 00000000..9f4b877a --- /dev/null +++ b/nac3artiq/demo/embedding_map.py @@ -0,0 +1,57 @@ +class EmbeddingMap: + def __init__(self): + self.object_inverse_map = {} + self.object_map = {} + self.string_map = {} + self.string_reverse_map = {} + self.function_map = {} + + # preallocate exception names + self.preallocate_runtime_exception_names(["RuntimeError", + "RTIOUnderflow", + "RTIOOverflow", + "RTIODestinationUnreachable", + "DMAError", + "I2CError", + "CacheError", + "SPIError", + "0:ZeroDivisionError", + "0:IndexError"]) + + def preallocate_runtime_exception_names(self, names): + for i, name in enumerate(names): + if ":" not in name: + name = "0:artiq.coredevice.exceptions." + name + exn_id = self.store_str(name) + assert exn_id == i + + def store_function(self, key, fun): + self.function_map[key] = fun + return key + + def store_object(self, obj): + obj_id = id(obj) + if obj_id in self.object_inverse_map: + return self.object_inverse_map[obj_id] + key = len(self.object_map) + self.object_map[key] = obj + self.object_inverse_map[obj_id] = key + return key + + def store_str(self, s): + if s in self.string_reverse_map: + return self.string_reverse_map[s] + key = len(self.string_map) + self.string_map[key] = s + self.string_reverse_map[s] = key + return key + + def retrieve_function(self, key): + return self.function_map[key] + + def retrieve_object(self, key): + return self.object_map[key] + + def retrieve_str(self, key): + return self.string_map[key] + diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index 64f41c73..5302cf19 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -6,13 +6,14 @@ from typing import Generic, TypeVar from math import floor, ceil import nac3artiq +from embedding_map import EmbeddingMap __all__ = [ "Kernel", "KernelInvariant", "virtual", "round64", "floor64", "ceil64", "extern", "kernel", "portable", "nac3", - "ms", "us", "ns", + "rpc", "ms", "us", "ns", "print_int32", "print_int64", "Core", "TTLOut", "parallel", "sequential" @@ -65,6 +66,10 @@ def extern(function): register_function(function) return function +def rpc(function): + """Decorates a function declaration defined by the core device runtime.""" + register_function(function) + return function def kernel(function_or_method): """Decorates a function or method to be executed on the core device.""" @@ -146,6 +151,9 @@ class Core: def run(self, method, *args, **kwargs): global allow_registration + + embedding = EmbeddingMap() + if allow_registration: compiler.analyze(registered_functions, registered_classes) allow_registration = False @@ -157,7 +165,7 @@ class Core: obj = method name = "" - compiler.compile_method_to_file(obj, name, args, "module.elf") + compiler.compile_method_to_file(obj, name, args, "module.elf", embedding) @kernel def reset(self): diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 38d9ecd3..a799fa40 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -1,16 +1,30 @@ use nac3core::{ - codegen::{expr::gen_call, stmt::gen_with, CodeGenContext, CodeGenerator}, + codegen::{ + expr::gen_call, + stmt::{gen_block, gen_with}, + CodeGenContext, CodeGenerator, + }, symbol_resolver::ValueEnum, - toplevel::DefinitionId, + toplevel::{DefinitionId, GenCall}, typecheck::typedef::{FunSignature, Type}, }; use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}; -use inkwell::{context::Context, types::IntType, values::BasicValueEnum}; +use inkwell::{ + context::Context, module::Linkage, types::IntType, values::BasicValueEnum, AddressSpace, +}; use crate::timeline::TimeFns; +use std::{ + collections::hash_map::DefaultHasher, + collections::HashMap, + convert::TryInto, + hash::{Hash, Hasher}, + sync::Arc, +}; + pub struct ArtiqCodeGenerator<'a> { name: String, size_t: u32, @@ -21,16 +35,13 @@ pub struct ArtiqCodeGenerator<'a> { } impl<'a> ArtiqCodeGenerator<'a> { - pub fn new(name: String, size_t: u32, timeline: &'a (dyn TimeFns + Sync)) -> ArtiqCodeGenerator<'a> { + pub fn new( + name: String, + size_t: u32, + timeline: &'a (dyn TimeFns + Sync), + ) -> ArtiqCodeGenerator<'a> { assert!(size_t == 32 || size_t == 64); - ArtiqCodeGenerator { - name, - size_t, - name_counter: 0, - start: None, - end: None, - timeline, - } + ArtiqCodeGenerator { name, size_t, name_counter: 0, start: None, end: None, timeline } } } @@ -86,7 +97,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { &mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmt: &Stmt>, - ) -> bool { + ) { if let StmtKind::With { items, body, .. } = &stmt.node { if items.len() == 1 && items[0].optional_vars.is_none() { let item = &items[0]; @@ -108,9 +119,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { let old_start = self.start.take(); let old_end = self.end.take(); let now = if let Some(old_start) = &old_start { - self.gen_expr(ctx, old_start) - .unwrap() - .to_basic_value_enum(ctx, self) + self.gen_expr(ctx, old_start).unwrap().to_basic_value_enum(ctx, self) } else { self.timeline.emit_now_mu(ctx) }; @@ -126,10 +135,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { let start_expr = Located { // location does not matter at this point location: stmt.location, - node: ExprKind::Name { - id: start, - ctx: name_ctx.clone(), - }, + node: ExprKind::Name { id: start, ctx: name_ctx.clone() }, custom: Some(ctx.primitives.int64), }; let start = self.gen_store_target(ctx, &start_expr); @@ -140,40 +146,41 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { let end_expr = Located { // location does not matter at this point location: stmt.location, - node: ExprKind::Name { - id: end, - ctx: name_ctx.clone(), - }, + node: ExprKind::Name { id: end, ctx: name_ctx.clone() }, custom: Some(ctx.primitives.int64), }; let end = self.gen_store_target(ctx, &end_expr); ctx.builder.build_store(end, now); self.end = Some(end_expr); self.name_counter += 1; - let mut exited = false; - for stmt in body.iter() { - if self.gen_stmt(ctx, stmt) { - exited = true; - break; - } - } + gen_block(self, ctx, body.iter()); + let current = ctx.builder.get_insert_block().unwrap(); + // if the current block is terminated, move before the terminator + // we want to set the timeline before reaching the terminator + // TODO: This may be unsound if there are multiple exit paths in the + // block... e.g. + // if ...: + // return + // Perhaps we can fix this by using actual with block? + let reset_position = if let Some(terminator) = current.get_terminator() { + ctx.builder.position_before(&terminator); + true + } else { + false + }; // set duration let end_expr = self.end.take().unwrap(); - let end_val = self - .gen_expr(ctx, &end_expr) - .unwrap() - .to_basic_value_enum(ctx, self); + let end_val = + self.gen_expr(ctx, &end_expr).unwrap().to_basic_value_enum(ctx, self); - // inside an sequential block + // inside a sequential block if old_start.is_none() { self.timeline.emit_at_mu(ctx, end_val); } // inside a parallel block, should update the outer max now_mu if let Some(old_end) = &old_end { - let outer_end_val = self - .gen_expr(ctx, old_end) - .unwrap() - .to_basic_value_enum(ctx, self); + let outer_end_val = + self.gen_expr(ctx, old_end).unwrap().to_basic_value_enum(ctx, self); let smax = ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| { let i64 = ctx.ctx.i64_type(); @@ -194,24 +201,294 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { } self.start = old_start; self.end = old_end; - return exited; + if reset_position { + ctx.builder.position_at_end(current); + } + return; } else if id == &"sequential".into() { let start = self.start.take(); for stmt in body.iter() { - if self.gen_stmt(ctx, stmt) { - self.start = start; - return true; + self.gen_stmt(ctx, stmt); + if ctx.is_terminated() { + break; } } self.start = start; - return false; + return } } } // not parallel/sequential - gen_with(self, ctx, stmt) + gen_with(self, ctx, stmt); } else { unreachable!() } } } + +fn gen_rpc_tag<'ctx, 'a>(ctx: &mut CodeGenContext<'ctx, 'a>, ty: Type, buffer: &mut Vec) { + use nac3core::typecheck::typedef::TypeEnum::*; + + let int32 = ctx.primitives.int32; + let int64 = ctx.primitives.int64; + let float = ctx.primitives.float; + let bool = ctx.primitives.bool; + let str = ctx.primitives.str; + let none = ctx.primitives.none; + + if ctx.unifier.unioned(ty, int32) { + buffer.push(b'i'); + } else if ctx.unifier.unioned(ty, int64) { + buffer.push(b'I'); + } else if ctx.unifier.unioned(ty, float) { + buffer.push(b'f'); + } else if ctx.unifier.unioned(ty, bool) { + buffer.push(b'b'); + } else if ctx.unifier.unioned(ty, str) { + buffer.push(b's'); + } else if ctx.unifier.unioned(ty, none) { + buffer.push(b'n'); + } else { + let ty = ctx.unifier.get_ty(ty); + match &*ty { + TTuple { ty } => { + buffer.push(b't'); + buffer.push(ty.len() as u8); + for ty in ty { + gen_rpc_tag(ctx, *ty, buffer); + } + } + TList { ty } => { + buffer.push(b'l'); + gen_rpc_tag(ctx, *ty, buffer); + } + // we should return an error, this will be fixed after improving error message + // as this requires returning an error during codegen + _ => unimplemented!(), + } + } +} + +fn rpc_codegen_callback_fn<'ctx, 'a>( + ctx: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, ValueEnum<'ctx>)>, + generator: &mut dyn CodeGenerator, +) -> Option> { + let ptr_type = ctx.ctx.i8_type().ptr_type(inkwell::AddressSpace::Generic); + let size_type = generator.get_size_type(ctx.ctx); + let int8 = ctx.ctx.i8_type(); + let int32 = ctx.ctx.i32_type(); + let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); + + let service_id = int32.const_int(fun.1.0 as u64, false); + // -- setup rpc tags + let mut tag = Vec::new(); + if obj.is_some() { + tag.push(b'O'); + } + for arg in fun.0.args.iter() { + gen_rpc_tag(ctx, arg.ty, &mut tag); + } + tag.push(b':'); + gen_rpc_tag(ctx, fun.0.ret, &mut tag); + + let mut hasher = DefaultHasher::new(); + tag.hash(&mut hasher); + let hash = format!("{}", hasher.finish()); + + let tag_ptr = ctx + .module + .get_global(hash.as_str()) + .unwrap_or_else(|| { + let tag_arr_ptr = ctx.module.add_global( + int8.array_type(tag.len() as u32), + None, + format!("tagptr{}", fun.1 .0).as_str(), + ); + tag_arr_ptr.set_initializer(&int8.const_array( + &tag.iter().map(|v| int8.const_int(*v as u64, false)).collect::>(), + )); + tag_arr_ptr.set_linkage(Linkage::Private); + let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash); + tag_ptr.set_linkage(Linkage::Private); + tag_ptr.set_initializer(&ctx.ctx.const_struct( + &[ + tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(), + size_type.const_int(tag.len() as u64, false).into(), + ], + false, + )); + tag_ptr + }) + .as_pointer_value(); + + let arg_length = args.len() + if obj.is_some() { 1 } else { 0 }; + + let stacksave = ctx.module.get_function("llvm.stacksave").unwrap_or_else(|| { + ctx.module.add_function("llvm.stacksave", ptr_type.fn_type(&[], false), None) + }); + let stackrestore = ctx.module.get_function("llvm.stackrestore").unwrap_or_else(|| { + ctx.module.add_function( + "llvm.stackrestore", + ctx.ctx.void_type().fn_type(&[ptr_type.into()], false), + None, + ) + }); + + let stackptr = ctx.builder.build_call(stacksave, &[], "rpc.stack"); + let args_ptr = ctx.builder.build_array_alloca( + ptr_type, + ctx.ctx.i32_type().const_int(arg_length as u64, false), + "argptr", + ); + + // -- rpc args handling + let mut keys = fun.0.args.clone(); + let mut mapping = HashMap::new(); + for (key, value) in args.into_iter() { + mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value); + } + // default value handling + for k in keys.into_iter() { + mapping.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap()).into()); + } + // reorder the parameters + let mut real_params = fun + .0 + .args + .iter() + .map(|arg| mapping.remove(&arg.name).unwrap().to_basic_value_enum(ctx, generator)) + .collect::>(); + if let Some(obj) = obj { + if let ValueEnum::Static(obj) = obj.1 { + real_params.insert(0, obj.get_const_obj(ctx, generator)); + } else { + // should be an error here... + panic!("only host object is allowed"); + } + } + + for (i, arg) in real_params.iter().enumerate() { + let arg_slot = if arg.is_pointer_value() { + arg.into_pointer_value() + } else { + let arg_slot = ctx.builder.build_alloca(arg.get_type(), &format!("rpc.arg{}", i)); + ctx.builder.build_store(arg_slot, *arg); + arg_slot + }; + let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg"); + let arg_ptr = unsafe { + ctx.builder.build_gep( + args_ptr, + &[int32.const_int(i as u64, false)], + &format!("rpc.arg{}", i), + ) + }; + ctx.builder.build_store(arg_ptr, arg_slot); + } + + // call + let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| { + ctx.module.add_function( + "rpc_send", + ctx.ctx.void_type().fn_type( + &[ + int32.into(), + tag_ptr_type.ptr_type(AddressSpace::Generic).into(), + ptr_type.ptr_type(AddressSpace::Generic).into(), + ], + false, + ), + None, + ) + }); + ctx.builder.build_call( + rpc_send, + &[service_id.into(), tag_ptr.into(), args_ptr.into()], + "rpc.send", + ); + + // reclaim stack space used by arguments + ctx.builder.build_call( + stackrestore, + &[stackptr.try_as_basic_value().unwrap_left().into()], + "rpc.stackrestore", + ); + + // -- receive value: + // T result = { + // void *ret_ptr = alloca(sizeof(T)); + // void *ptr = ret_ptr; + // loop: int size = rpc_recv(ptr); + // // Non-zero: Provide `size` bytes of extra storage for variable-length data. + // if(size) { ptr = alloca(size); goto loop; } + // else *(T*)ret_ptr + // } + let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| { + ctx.module.add_function("rpc_recv", int32.fn_type(&[ptr_type.into()], false), None) + }); + + if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) { + ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv"); + return None + } + + let prehead_bb = ctx.builder.get_insert_block().unwrap(); + let current_function = prehead_bb.get_parent().unwrap(); + let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head"); + let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue"); + let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail"); + + let mut ret_ty = ctx.get_llvm_type(generator, fun.0.ret); + let need_load = !ret_ty.is_pointer_type(); + if ret_ty.is_pointer_type() { + ret_ty = ret_ty.into_pointer_type().get_element_type().try_into().unwrap(); + } + let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot"); + let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr"); + ctx.builder.build_unconditional_branch(head_bb); + ctx.builder.position_at_end(head_bb); + + let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr"); + phi.add_incoming(&[(&slotgen, prehead_bb)]); + let alloc_size = ctx + .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") + .unwrap() + .into_int_value(); + let is_done = ctx.builder.build_int_compare( + inkwell::IntPredicate::EQ, + int32.const_zero(), + alloc_size, + "rpc.done", + ); + + ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb); + ctx.builder.position_at_end(alloc_bb); + + let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc"); + let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr"); + phi.add_incoming(&[(&alloc_ptr, alloc_bb)]); + ctx.builder.build_unconditional_branch(head_bb); + + ctx.builder.position_at_end(tail_bb); + + if need_load { + let result = ctx.builder.build_load(slot, "rpc.result"); + ctx.builder.build_call( + stackrestore, + &[stackptr.try_as_basic_value().unwrap_left().into()], + "rpc.stackrestore", + ); + Some(result) + } else { + Some(slot.into()) + } +} + +pub fn rpc_codegen_callback() -> Arc { + Arc::new(GenCall::new(Box::new(|ctx, obj, fun, args, generator| { + rpc_codegen_callback_fn(ctx, obj, fun, args, generator) + }))) +} diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 1d7c476b..085151e6 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -12,7 +12,7 @@ use inkwell::{ }; use nac3core::typecheck::typedef::{Unifier, TypeEnum}; use nac3parser::{ - ast::{self, Stmt, StrRef}, + ast::{self, ExprKind, Stmt, StmtKind, StrRef}, parser::{self, parse_program}, }; use pyo3::prelude::*; @@ -24,7 +24,10 @@ use nac3core::{ codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry}, codegen::irrt::load_irrt, symbol_resolver::SymbolResolver, - toplevel::{composer::{TopLevelComposer, ComposerConfig}, DefinitionId, GenCall, TopLevelDef}, + toplevel::{ + composer::{ComposerConfig, TopLevelComposer}, + DefinitionId, GenCall, TopLevelDef, + }, typecheck::typedef::{FunSignature, FuncArg}, typecheck::{type_inferencer::PrimitiveStore, typedef::Type}, }; @@ -32,7 +35,7 @@ use nac3core::{ use tempfile::{self, TempDir}; use crate::{ - codegen::ArtiqCodeGenerator, + codegen::{rpc_codegen_callback, ArtiqCodeGenerator}, symbol_resolver::{InnerResolver, PythonHelper, Resolver}, }; @@ -61,6 +64,7 @@ pub struct PrimitivePythonId { tuple: u64, typevar: u64, none: u64, + exception: u64, generic_alias: (u64, u64), virtual_id: u64, } @@ -81,6 +85,7 @@ struct Nac3 { primitive_ids: PrimitivePythonId, working_directory: TempDir, top_levels: Vec, + string_store: Arc>>, } impl Nac3 { @@ -127,9 +132,13 @@ impl Nac3 { let id_fn = PyModule::import(py, "builtins")?.getattr("id")?; match &base.node { ast::ExprKind::Name { id, .. } => { - let base_obj = module.getattr(py, id.to_string())?; - let base_id = id_fn.call1((base_obj,))?.extract()?; - Ok(registered_class_ids.contains(&base_id)) + if *id == "Exception".into() { + Ok(true) + } else { + let base_obj = module.getattr(py, id.to_string())?; + let base_id = id_fn.call1((base_obj,))?.extract()?; + Ok(registered_class_ids.contains(&base_id)) + } } _ => Ok(true), } @@ -143,7 +152,9 @@ impl Nac3 { { decorator_list.iter().any(|decorator| { if let ast::ExprKind::Name { id, .. } = decorator.node { - id.to_string() == "kernel" || id.to_string() == "portable" + id.to_string() == "kernel" + || id.to_string() == "portable" + || id.to_string() == "rpc" } else { false } @@ -159,7 +170,7 @@ impl Nac3 { } => decorator_list.iter().any(|decorator| { if let ast::ExprKind::Name { id, .. } = decorator.node { let id = id.to_string(); - id == "extern" || id == "portable" || id == "kernel" + id == "extern" || id == "portable" || id == "kernel" || id == "rpc" } else { false } @@ -188,7 +199,7 @@ impl Nac3 { Ok(ty) => ty, Err(e) => return Some(format!("type error inside object launching kernel: {}", e)) }; - + let fun_ty = if method_name.is_empty() { base_ty } else if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(base_ty) { @@ -201,7 +212,7 @@ impl Nac3 { } else { return Some("cannot launch kernel by calling a non-callable".into()) }; - + if let TypeEnum::TFunc(sig) = &*unifier.get_ty(fun_ty) { let FunSignature { args, .. } = &*sig.borrow(); if arg_names.len() > args.len() { @@ -269,7 +280,7 @@ impl Nac3 { ret: primitive.int64, vars: HashMap::new(), }, - Arc::new(GenCall::new(Box::new(move |ctx, _, _, _| { + Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| { Some(time_fns.emit_now_mu(ctx)) }))), ), @@ -284,8 +295,9 @@ impl Nac3 { ret: primitive.none, vars: HashMap::new(), }, - Arc::new(GenCall::new(Box::new(move |ctx, _, _, args| { - time_fns.emit_at_mu(ctx, args[0].1); + Arc::new(GenCall::new(Box::new(move |ctx, _, _, args, generator| { + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); + time_fns.emit_at_mu(ctx, arg); None }))), ), @@ -300,16 +312,20 @@ impl Nac3 { ret: primitive.none, vars: HashMap::new(), }, - Arc::new(GenCall::new(Box::new(move |ctx, _, _, args| { - time_fns.emit_delay_mu(ctx, args[0].1); + Arc::new(GenCall::new(Box::new(move |ctx, _, _, args, generator| { + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator); + time_fns.emit_delay_mu(ctx, arg); None }))), ), ]; - let (_, builtins_def, builtins_ty) = TopLevelComposer::new(builtins.clone(), ComposerConfig { - kernel_ann: Some("Kernel"), - kernel_invariant_ann: "KernelInvariant" - }); + let (_, builtins_def, builtins_ty) = TopLevelComposer::new( + builtins.clone(), + ComposerConfig { + kernel_ann: Some("Kernel"), + kernel_invariant_ann: "KernelInvariant", + }, + ); let builtins_mod = PyModule::import(py, "builtins").unwrap(); let id_fn = builtins_mod.getattr("id").unwrap(); @@ -385,6 +401,11 @@ impl Nac3 { .unwrap() .extract() .unwrap(), + exception: id_fn + .call1((builtins_mod.getattr("tuple").unwrap(),)) + .unwrap() + .extract() + .unwrap(), }; let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap(); @@ -405,6 +426,7 @@ impl Nac3 { top_levels: Default::default(), pyid_to_def: Default::default(), working_directory, + string_store: Default::default() }) } @@ -441,6 +463,7 @@ impl Nac3 { method_name: &str, args: Vec<&PyAny>, filename: &str, + embedding_map: &PyAny, py: Python, ) -> PyResult<()> { let (mut composer, _, _) = TopLevelComposer::new(self.builtins.clone(), ComposerConfig { @@ -451,17 +474,26 @@ impl Nac3 { let builtins = PyModule::import(py, "builtins")?; let typings = PyModule::import(py, "typing")?; let id_fn = builtins.getattr("id")?; + let store_obj = embedding_map.getattr("store_object").unwrap().to_object(py); + let store_str = embedding_map.getattr("store_str").unwrap().to_object(py); + let store_fun = embedding_map + .getattr("store_function") + .unwrap() + .to_object(py); let helper = PythonHelper { id_fn: builtins.getattr("id").unwrap().to_object(py), len_fn: builtins.getattr("len").unwrap().to_object(py), type_fn: builtins.getattr("type").unwrap().to_object(py), origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py), args_ty_fn: typings.getattr("get_args").unwrap().to_object(py), + store_obj, + store_str }; let mut module_to_resolver_cache: HashMap = HashMap::new(); let pyid_to_type = Arc::new(RwLock::new(HashMap::::new())); let global_value_ids = Arc::new(RwLock::new(HashSet::::new())); + let mut rpc_ids = vec![]; for (stmt, path, module) in self.top_levels.iter() { let py_module: &PyAny = module.extract(py)?; let module_id: u64 = id_fn.call1((py_module,))?.extract()?; @@ -492,6 +524,7 @@ impl Nac3 { id_to_primitive: Default::default(), field_to_val: Default::default(), helper, + string_store: self.string_store.clone(), }))) as Arc; let name_to_pyid = Rc::new(name_to_pyid); @@ -502,7 +535,30 @@ impl Nac3 { let (name, def_id, ty) = composer .register_top_level(stmt.clone(), Some(resolver.clone()), path.clone()) - .map_err(|e| exceptions::PyRuntimeError::new_err(format!("nac3 compilation failure: {}", e)))?; + .map_err(|e| { + exceptions::PyRuntimeError::new_err(format!("nac3 compilation failure: {}", e)) + })?; + + match &stmt.node { + StmtKind::FunctionDef { decorator_list, .. } => { + if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) { + store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string()).unwrap())).unwrap(); + rpc_ids.push((None, def_id)); + } + } + StmtKind::ClassDef { name, body, .. } => { + let class_obj = module.getattr(py, name.to_string()).unwrap(); + for stmt in body.iter() { + if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node { + if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) { + rpc_ids.push((Some((class_obj.clone(), *name)), def_id)); + } + } + } + } + _ => () + } + let id = *name_to_pyid.get(&name).unwrap(); self.pyid_to_def.write().insert(id, def_id); { @@ -552,6 +608,7 @@ impl Nac3 { name_to_pyid, module: module.to_object(py), helper, + string_store: self.string_store.clone(), }))) as Arc; let (_, def_id, _) = composer .register_top_level( @@ -595,6 +652,45 @@ impl Nac3 { } } let top_level = Arc::new(composer.make_top_level_context()); + + { + let rpc_codegen = rpc_codegen_callback(); + let defs = top_level.definitions.read(); + for (class_data, id) in rpc_ids.iter() { + let mut def = defs[id.0].write(); + match &mut *def { + TopLevelDef::Function { + codegen_callback, .. + } => { + *codegen_callback = Some(rpc_codegen.clone()); + } + TopLevelDef::Class { methods, .. } => { + let (class_def, method_name) = class_data.as_ref().unwrap(); + for (name, _, id) in methods.iter() { + if name != method_name { + continue; + } + if let TopLevelDef::Function { + codegen_callback, .. + } = &mut *defs[id.0].write() + { + *codegen_callback = Some(rpc_codegen.clone()); + store_fun + .call1( + py, + ( + id.0.into_py(py), + class_def.getattr(py, name.to_string()).unwrap(), + ), + ) + .unwrap(); + } + } + } + } + } + } + let instance = { let defs = top_level.definitions.read(); let mut definition = defs[def_id.0].write(); @@ -634,15 +730,17 @@ impl Nac3 { let buffer = buffer.as_slice().into(); membuffer.lock().push(buffer); }))); - let size_t = if self.isa == Isa::Host { - 64 - } else { - 32 - }; + let size_t = if self.isa == Isa::Host { 64 } else { 32 }; let thread_names: Vec = (0..4).map(|_| "main".to_string()).collect(); let threads: Vec<_> = thread_names .iter() - .map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), size_t, self.time_fns))) + .map(|s| { + Box::new(ArtiqCodeGenerator::new( + s.to_string(), + size_t, + self.time_fns, + )) + }) .collect(); py.allow_threads(|| { @@ -759,11 +857,12 @@ impl Nac3 { obj: &PyAny, method_name: &str, args: Vec<&PyAny>, + embedding_map: &PyAny, py: Python, ) -> PyResult { let filename_path = self.working_directory.path().join("module.elf"); let filename = filename_path.to_str().unwrap(); - self.compile_method_to_file(obj, method_name, args, filename, py)?; + self.compile_method_to_file(obj, method_name, args, filename, embedding_map, py)?; Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into()) } } diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 33608116..36427631 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -42,6 +42,7 @@ pub struct InnerResolver { pub pyid_to_type: Arc>>, pub primitive_ids: PrimitivePythonId, pub helper: PythonHelper, + pub string_store: Arc>>, // module specific pub name_to_pyid: HashMap, pub module: PyObject, @@ -56,11 +57,14 @@ pub struct PythonHelper { pub id_fn: PyObject, pub origin_ty_fn: PyObject, pub args_ty_fn: PyObject, + pub store_obj: PyObject, + pub store_str: PyObject, } struct PythonValue { id: u64, value: PyObject, + store_obj: PyObject, resolver: Arc, } @@ -69,6 +73,36 @@ impl StaticValue for PythonValue { self.id } + fn get_const_obj<'ctx, 'a>( + &self, + ctx: &mut CodeGenContext<'ctx, 'a>, + _: &mut dyn CodeGenerator, + ) -> BasicValueEnum<'ctx> { + ctx.module + .get_global(self.id.to_string().as_str()) + .map(|val| val.as_pointer_value().into()) + .unwrap_or_else(|| { + Python::with_gil(|py| -> PyResult> { + let id: u32 = self.store_obj.call1(py, (self.value.clone(),))?.extract(py)?; + let struct_type = ctx.ctx.struct_type(&[ctx.ctx.i32_type().into()], false); + let global = + ctx.module + .add_global(struct_type, None, format!("{}_const", self.id).as_str()); + global.set_constant(true); + global.set_initializer(&ctx.ctx.const_struct( + &[ctx.ctx.i32_type().const_int(id as u64, false).into()], + false, + )); + let global2 = + ctx.module + .add_global(struct_type.ptr_type(AddressSpace::Generic), None, format!("{}_const2", self.id).as_str()); + global2.set_initializer(&global.as_pointer_value()); + Ok(global2.as_pointer_value().into()) + }) + .unwrap() + }) + } + fn to_basic_value_enum<'ctx, 'a>( &self, ctx: &mut CodeGenContext<'ctx, 'a>, @@ -140,6 +174,7 @@ impl StaticValue for PythonValue { ValueEnum::Static(Arc::new(PythonValue { id, value: obj, + store_obj: self.store_obj.clone(), resolver: self.resolver.clone(), })) }) @@ -208,7 +243,9 @@ impl InnerResolver { Ok(Ok((primitives.bool, true))) } else if ty_id == self.primitive_ids.float { Ok(Ok((primitives.float, true))) - } else if ty_id == self.primitive_ids.list { + } else if ty_id == self.primitive_ids.exception { + Ok(Ok((primitives.exception, true))) + }else if ty_id == self.primitive_ids.list { // do not handle type var param and concrete check here let var = unifier.get_fresh_var().0; let list = unifier.add_ty(TypeEnum::TList { ty: var }); @@ -755,9 +792,7 @@ impl InnerResolver { .get_llvm_type(generator, ty) .into_pointer_type() .get_element_type() - .into_struct_type() - .as_basic_type_enum(); - + .into_struct_type(); { if self.global_value_ids.read().contains(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { @@ -783,7 +818,7 @@ impl InnerResolver { .collect(); let values = values?; if let Some(values) = values { - let val = ctx.ctx.const_struct(&values, false); + let val = ty.const_named_struct(&values); let global = ctx .module .add_global(ty, Some(AddressSpace::Generic), &id_str); @@ -948,6 +983,7 @@ impl SymbolResolver for Resolver { ValueEnum::Static(Arc::new(PythonValue { id, value: v, + store_obj: self.0.helper.store_obj.clone(), resolver: self.0.clone(), })) }) @@ -971,4 +1007,17 @@ impl SymbolResolver for Resolver { result }) } + + fn get_string_id(&self, s: &str) -> i32 { + let mut string_store = self.0.string_store.write(); + if let Some(id) = string_store.get(s) { + *id + } else { + let id = Python::with_gil(|py| -> PyResult { + self.0.helper.store_str.call1(py, (s, ))?.extract(py) + }).unwrap(); + string_store.insert(s.into(), id); + id + } + } } diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 6509cc40..8d537bc6 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -32,10 +32,16 @@ pub enum SymbolValue { pub trait StaticValue { fn get_unique_identifier(&self) -> u64; + fn get_const_obj<'ctx, 'a>( + &self, + ctx: &mut CodeGenContext<'ctx, 'a>, + generator: &mut dyn CodeGenerator, + ) -> BasicValueEnum<'ctx>; + fn to_basic_value_enum<'ctx, 'a>( &self, ctx: &mut CodeGenContext<'ctx, 'a>, - generator: &mut dyn CodeGenerator + generator: &mut dyn CodeGenerator, ) -> BasicValueEnum<'ctx>; fn get_field<'ctx, 'a>(