diff --git a/nac3artiq/demo/embedding_map.py b/nac3artiq/demo/embedding_map.py index dc64d7df..b25ed325 100644 --- a/nac3artiq/demo/embedding_map.py +++ b/nac3artiq/demo/embedding_map.py @@ -5,6 +5,7 @@ class EmbeddingMap: self.string_map = {} self.string_reverse_map = {} self.function_map = {} + self.attributes_writeback = [] # preallocate exception names self.preallocate_runtime_exception_names(["RuntimeError", diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index ffc93674..3a31171a 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -6,7 +6,7 @@ use nac3core::{ }, symbol_resolver::ValueEnum, toplevel::{DefinitionId, GenCall}, - typecheck::typedef::{FunSignature, Type}, + typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum} }; use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}; @@ -15,7 +15,9 @@ use inkwell::{ context::Context, module::Linkage, types::IntType, values::BasicValueEnum, AddressSpace, }; -use crate::timeline::TimeFns; +use pyo3::{PyObject, PyResult, Python, types::{PyDict, PyList}}; + +use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; use std::{ collections::hash_map::DefaultHasher, @@ -270,8 +272,6 @@ fn gen_rpc_tag<'ctx, 'a>( buffer.push(b'l'); gen_rpc_tag(ctx, *ty, buffer)?; } - // we should return an error, this will be fixed after improving error message - // as this requires returning an error during codegen _ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))), } } @@ -291,7 +291,7 @@ fn rpc_codegen_callback_fn<'ctx, 'a>( let int32 = ctx.ctx.i32_type(); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); - let service_id = int32.const_int(fun.1 .0 as u64, false); + let service_id = int32.const_int(fun.1.0 as u64, false); // -- setup rpc tags let mut tag = Vec::new(); if obj.is_some() { @@ -486,6 +486,81 @@ fn rpc_codegen_callback_fn<'ctx, 'a>( Ok(Some(result)) } +pub fn attributes_writeback<'ctx, 'a>( + ctx: &mut CodeGenContext<'ctx, 'a>, + generator: &mut dyn CodeGenerator, + inner_resolver: &InnerResolver, + host_attributes: PyObject, +) -> Result<(), String> { + Python::with_gil(|py| -> PyResult> { + let host_attributes = host_attributes.cast_as::(py)?; + let top_levels = ctx.top_level.definitions.read(); + let globals = inner_resolver.global_value_ids.read(); + let int32 = ctx.ctx.i32_type(); + let zero = int32.const_zero(); + let mut values = Vec::new(); + let mut scratch_buffer = Vec::new(); + for (_, val) in globals.iter() { + let val = val.as_ref(py); + let ty = inner_resolver.get_obj_type(py, val, &mut ctx.unifier, &top_levels, &ctx.primitives)?; + if let Err(ty) = ty { + return Ok(Err(ty)) + } + let ty = ty.unwrap(); + match &*ctx.unifier.get_ty(ty) { + TypeEnum::TObj { fields, .. } => { + // we only care about primitive attributes + // for non-primitive attributes, they should be in another global + let mut attributes = Vec::new(); + let obj = inner_resolver.get_obj_value(py, val, ctx, generator)?.unwrap(); + for (name, (field_ty, is_mutable)) in fields.iter() { + if !is_mutable { + continue + } + if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { + attributes.push(name.to_string()); + let index = ctx.get_attr_index(ty, *name); + values.push((*field_ty, ctx.build_gep_and_load( + obj.into_pointer_value(), + &[zero, int32.const_int(index as u64, false)]))); + } + } + if !attributes.is_empty() { + let pydict = PyDict::new(py); + pydict.set_item("obj", val)?; + pydict.set_item("fields", attributes)?; + host_attributes.append(pydict)?; + } + }, + TypeEnum::TList { ty: elem_ty } => { + if gen_rpc_tag(ctx, *elem_ty, &mut scratch_buffer).is_ok() { + let pydict = PyDict::new(py); + pydict.set_item("obj", val)?; + host_attributes.append(pydict)?; + values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator)?.unwrap())); + } + }, + _ => {} + } + } + let fun = FunSignature { + args: values.iter().enumerate().map(|(i, (ty, _))| FuncArg { + name: i.to_string().into(), + ty: *ty, + default_value: None + }).collect(), + ret: ctx.primitives.none, + vars: Default::default() + }; + let args: Vec<_> = values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect(); + if let Err(e) = rpc_codegen_callback_fn(ctx, None, (&fun, DefinitionId(0)), args, generator) { + return Ok(Err(e)); + } + Ok(Ok(())) + }).unwrap()?; + Ok(()) +} + pub fn rpc_codegen_callback() -> Arc { Arc::new(GenCall::new(Box::new(|ctx, obj, fun, args, generator| { rpc_codegen_callback_fn(ctx, obj, fun, args, generator) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 16827588..57e91652 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -10,6 +10,7 @@ use inkwell::{ targets::*, OptimizationLevel, }; +use nac3core::codegen::gen_func_impl; use nac3core::toplevel::builtins::get_exn_constructor; use nac3core::typecheck::typedef::{TypeEnum, Unifier}; use nac3parser::{ @@ -36,6 +37,7 @@ use nac3core::{ use tempfile::{self, TempDir}; +use crate::codegen::attributes_writeback; use crate::{ codegen::{rpc_codegen_callback, ArtiqCodeGenerator}, symbol_resolver::{InnerResolver, PythonHelper, Resolver, DeferredEvaluationStore}, @@ -476,6 +478,8 @@ impl Nac3 { let store_obj = embedding_map.getattr("store_object").unwrap().to_object(py); let store_str = embedding_map.getattr("store_str").unwrap().to_object(py); let store_fun = embedding_map.getattr("store_function").unwrap().to_object(py); + let host_attributes = embedding_map.getattr("attributes_writeback").unwrap().to_object(py); + let global_value_ids: Arc>> = Arc::new(RwLock::new(HashMap::new())); let helper = PythonHelper { id_fn: builtins.getattr("id").unwrap().to_object(py), len_fn: builtins.getattr("len").unwrap().to_object(py), @@ -503,7 +507,6 @@ impl Nac3 { let mut module_to_resolver_cache: HashMap = HashMap::new(); - let global_value_ids = Arc::new(RwLock::new(HashSet::::new())); let mut rpc_ids = vec![]; for (stmt, path, module) in self.top_levels.iter() { let py_module: &PyAny = module.extract(py)?; @@ -617,7 +620,7 @@ impl Nac3 { }; let mut synthesized = parse_program(&synthesized, "__nac3_synthesized_modinit__".to_string().into()).unwrap(); - let resolver = Arc::new(Resolver(Arc::new(InnerResolver { + let inner_resolver = Arc::new(InnerResolver { id_to_type: builtins_ty.clone().into(), id_to_def: builtins_def.clone().into(), pyid_to_def: self.pyid_to_def.clone(), @@ -634,17 +637,18 @@ impl Nac3 { string_store: self.string_store.clone(), exception_ids: self.exception_ids.clone(), deferred_eval_store: self.deferred_eval_store.clone(), - }))) as Arc; + }); + let resolver = Arc::new(Resolver(inner_resolver.clone())) as Arc; let (_, def_id, _) = composer .register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "".into()) .unwrap(); - let signature = + let fun_signature = FunSignature { args: vec![], ret: self.primitive.none, vars: HashMap::new() }; let mut store = ConcreteTypeStore::new(); let mut cache = HashMap::new(); let signature = - store.from_signature(&mut composer.unifier, &self.primitive, &signature, &mut cache); + store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache); let signature = store.add_cty(signature); if let Err(e) = composer.start_analysis(true) { @@ -721,12 +725,29 @@ impl Nac3 { symbol_name: "__modinit__".to_string(), body: instance.body, signature, - resolver, + resolver: resolver.clone(), store, unifier_index: instance.unifier_id, calls: instance.calls, id: 0, }; + + let mut store = ConcreteTypeStore::new(); + let mut cache = HashMap::new(); + let signature = + store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache); + let signature = store.add_cty(signature); + let attributes_writeback_task = CodeGenTask { + subst: Default::default(), + symbol_name: "attributes_writeback".to_string(), + body: Arc::new(Default::default()), + signature, + resolver, + store, + unifier_index: instance.unifier_id, + calls: Arc::new(Default::default()), + id: 0, + }; let isa = self.isa; let working_directory = self.working_directory.path().to_owned(); @@ -746,14 +767,27 @@ impl Nac3 { .map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), size_t, self.time_fns))) .collect(); + let membuffer = membuffers.clone(); py.allow_threads(|| { let (registry, handles) = WorkerRegistry::create_workers(threads, top_level.clone(), f); registry.add_task(task); registry.wait_tasks_complete(handles); + + let mut generator = ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns); + let context = inkwell::context::Context::create(); + let module = context.create_module("attributes_writeback"); + let builder = context.create_builder(); + let (_, module, _) = gen_func_impl(&context, &mut generator, ®istry, builder, module, + attributes_writeback_task, |generator, ctx| { + attributes_writeback(ctx, generator, inner_resolver.as_ref(), host_attributes) + }).unwrap(); + let buffer = module.write_bitcode_to_memory(); + let buffer = buffer.as_slice().into(); + membuffer.lock().push(buffer); }); - let buffers = membuffers.lock(); let context = inkwell::context::Context::create(); + let buffers = membuffers.lock(); let main = context .create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main")) .unwrap(); @@ -765,6 +799,11 @@ impl Nac3 { main.link_in_module(other) .map_err(|err| CompileError::new_err(err.to_string()))?; } + let builder = context.create_builder(); + let modinit_return = main.get_function("__modinit__").unwrap().get_last_basic_block().unwrap().get_terminator().unwrap(); + builder.position_before(&modinit_return); + builder.build_call(main.get_function("attributes_writeback").unwrap(), &[], "attributes_writeback"); + main.link_in_module(load_irrt(&context)) .map_err(|err| CompileError::new_err(err.to_string()))?; diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index e177b1ff..b2aa1ddf 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -15,7 +15,7 @@ use pyo3::{ PyAny, PyObject, PyResult, Python, }; use std::{ - collections::{HashMap, HashSet}, + collections::HashMap, sync::{ Arc, atomic::{AtomicBool, Ordering::Relaxed} @@ -54,7 +54,7 @@ pub struct InnerResolver { pub id_to_pyval: RwLock>, pub id_to_primitive: RwLock>, pub field_to_val: RwLock>>, - pub global_value_ids: Arc>>, + pub global_value_ids: Arc>>, pub class_names: Mutex>, pub pyid_to_def: Arc>>, pub pyid_to_type: Arc>>, @@ -503,7 +503,7 @@ impl InnerResolver { } } - fn get_obj_type( + pub fn get_obj_type( &self, py: Python, obj: &PyAny, @@ -605,7 +605,7 @@ impl InnerResolver { unreachable!("must be tobj") } } - + let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? { Ok(t) => t, Err(e) => { @@ -686,7 +686,7 @@ impl InnerResolver { } } - fn get_obj_value<'ctx, 'a>( + pub fn get_obj_value<'ctx, 'a>( &self, py: Python, obj: &PyAny, @@ -754,13 +754,13 @@ impl InnerResolver { .struct_type(&[ty.ptr_type(AddressSpace::Generic).into(), size_t.into()], false); { - if self.global_value_ids.read().contains(&id) { + if self.global_value_ids.read().contains_key(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { ctx.module.add_global(arr_ty, Some(AddressSpace::Generic), &id_str) }); return Ok(Some(global.as_pointer_value().into())); } else { - self.global_value_ids.write().insert(id); + self.global_value_ids.write().insert(id, obj.into()); } } @@ -834,13 +834,13 @@ impl InnerResolver { let ty = ctx.ctx.struct_type(&types, false); { - if self.global_value_ids.read().contains(&id) { + if self.global_value_ids.read().contains_key(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str) }); return Ok(Some(global.as_pointer_value().into())); } else { - self.global_value_ids.write().insert(id); + self.global_value_ids.write().insert(id, obj.into()); } } @@ -869,13 +869,13 @@ impl InnerResolver { Some(v) => { let global_str = format!("{}_option", id); { - if self.global_value_ids.read().contains(&id) { + if self.global_value_ids.read().contains_key(&id) { let global = ctx.module.get_global(&global_str).unwrap_or_else(|| { ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str) }); return Ok(Some(global.as_pointer_value().into())); } else { - self.global_value_ids.write().insert(id); + self.global_value_ids.write().insert(id, obj.into()); } } let global = ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str); @@ -902,13 +902,13 @@ impl InnerResolver { .get_element_type() .into_struct_type(); { - if self.global_value_ids.read().contains(&id) { + if self.global_value_ids.read().contains_key(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str) }); return Ok(Some(global.as_pointer_value().into())); } else { - self.global_value_ids.write().insert(id); + self.global_value_ids.write().insert(id, obj.into()); } } // should be classes diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 6066a969..6c3a09b9 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -355,9 +355,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } } - pub fn gen_string>( + pub fn gen_string>( &mut self, - generator: &mut G, + generator: &mut dyn CodeGenerator, s: S, ) -> BasicValueEnum<'ctx> { self.gen_const(generator, &nac3parser::ast::Constant::Str(s.into()), self.primitives.str) diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index bce050c2..7e0b0f92 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -360,13 +360,14 @@ fn need_sret<'ctx>(ctx: &'ctx Context, ty: BasicTypeEnum<'ctx>) -> bool { need_sret_impl(ctx, ty, true) } -pub fn gen_func<'ctx, G: CodeGenerator>( +pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenContext) -> Result<(), String>> ( context: &'ctx Context, generator: &mut G, registry: &WorkerRegistry, builder: Builder<'ctx>, module: Module<'ctx>, task: CodeGenTask, + codegen_function: F ) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> { let top_level_ctx = registry.top_level_ctx.clone(); let static_value_store = registry.static_value_store.clone(); @@ -572,25 +573,34 @@ pub fn gen_func<'ctx, G: CodeGenerator>( need_sret: has_sret }; - let mut err = None; - for stmt in task.body.iter() { - if let Err(e) = generator.gen_stmt(&mut code_gen_context, stmt) { - err = Some(e); - break; - } - if code_gen_context.is_terminated() { - break; - } - } + let result = codegen_function(generator, &mut code_gen_context); + // after static analysis, only void functions can have no return at the end. if !code_gen_context.is_terminated() { code_gen_context.builder.build_return(None); } let CodeGenContext { builder, module, .. } = code_gen_context; - if let Some(e) = err { + if let Err(e) = result { return Err((builder, e)); } Ok((builder, module, fn_val)) } + +pub fn gen_func<'ctx, G: CodeGenerator>( + context: &'ctx Context, + generator: &mut G, + registry: &WorkerRegistry, + builder: Builder<'ctx>, + module: Module<'ctx>, + task: CodeGenTask, +) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> { + let body = task.body.clone(); + gen_func_impl(context, generator, registry, builder, module, task, |generator, ctx| { + for stmt in body.iter() { + generator.gen_stmt(ctx, stmt)?; + } + Ok(()) + }) +}