diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 7c43d885..0cc24e45 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -990,11 +990,12 @@ fn rpc_codegen_callback_fn<'ctx>( } } -pub fn attributes_writeback( - ctx: &mut CodeGenContext<'_, '_>, +pub fn attributes_writeback<'ctx>( + ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut dyn CodeGenerator, inner_resolver: &InnerResolver, host_attributes: &PyObject, + return_obj: Option<(Type, ValueEnum<'ctx>)>, ) -> Result<(), String> { Python::with_gil(|py| -> PyResult> { let host_attributes: &PyList = host_attributes.downcast(py)?; @@ -1004,6 +1005,11 @@ pub fn attributes_writeback( let zero = int32.const_zero(); let mut values = Vec::new(); let mut scratch_buffer = Vec::new(); + + if let Some((ty, obj)) = return_obj { + values.push((ty, obj.to_basic_value_enum(ctx, generator, ty).unwrap())); + } + for val in (*globals).values() { let val = val.as_ref(py); let ty = inner_resolver.get_obj_type( diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 6e80fd03..f3668194 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -37,7 +37,7 @@ use tempfile::{self, TempDir}; use nac3core::{ codegen::{ concrete_type::ConcreteTypeStore, gen_func_impl, irrt::load_irrt, CodeGenLLVMOptions, - CodeGenTargetMachineOptions, CodeGenTask, WithCall, WorkerRegistry, + CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator, WithCall, WorkerRegistry, }, inkwell::{ context::Context, @@ -673,33 +673,12 @@ impl Nac3 { let task = CodeGenTask { subst: Vec::default(), symbol_name: "__modinit__".to_string(), - body: instance.body, - signature, - resolver: resolver.clone(), - store, - unifier_index: instance.unifier_id, - calls: instance.calls, - id: 0, - }; - - let mut store = ConcreteTypeStore::new(); - let mut cache = HashMap::new(); - let signature = store.from_signature( - &mut composer.unifier, - &self.primitive, - &fun_signature, - &mut cache, - ); - let signature = store.add_cty(signature); - let attributes_writeback_task = CodeGenTask { - subst: Vec::default(), - symbol_name: "attributes_writeback".to_string(), body: Arc::new(Vec::default()), signature, resolver, store, unifier_index: instance.unifier_id, - calls: Arc::new(HashMap::default()), + calls: instance.calls, id: 0, }; @@ -723,16 +702,14 @@ impl Nac3 { .collect(); let membuffer = membuffers.clone(); + let mut has_return = false; py.allow_threads(|| { let (registry, handles) = WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f); - registry.add_task(task); - registry.wait_tasks_complete(handles); - let mut generator = - ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns); + let mut generator = ArtiqCodeGenerator::new("main".to_string(), size_t, self.time_fns); let context = Context::create(); - let module = context.create_module("attributes_writeback"); + let module = context.create_module("main"); let target_machine = self.llvm_options.create_target_machine().unwrap(); module.set_data_layout(&target_machine.get_target_data().get_data_layout()); module.set_triple(&target_machine.get_triple()); @@ -743,9 +720,27 @@ impl Nac3 { ®istry, builder, module, - attributes_writeback_task, + task, |generator, ctx| { - attributes_writeback(ctx, generator, inner_resolver.as_ref(), &host_attributes) + assert_eq!(instance.body.len(), 1, "toplevel module should have 1 statement"); + let StmtKind::Expr { value: ref expr, .. } = instance.body[0].node else { + unreachable!("toplevel statement must be an expression") + }; + let ExprKind::Call { .. } = expr.node else { + unreachable!("toplevel expression must be a function call") + }; + + let return_obj = + generator.gen_expr(ctx, &expr)?.map(|value| (expr.custom.unwrap(), value)); + has_return = return_obj.is_some(); + registry.wait_tasks_complete(handles); + attributes_writeback( + ctx, + generator, + inner_resolver.as_ref(), + &host_attributes, + return_obj, + ) }, ) .unwrap(); @@ -754,6 +749,8 @@ impl Nac3 { membuffer.lock().push(buffer); }); + embedding_map.setattr("expects_return", has_return).unwrap(); + // Link all modules into `main`. let buffers = membuffers.lock(); let main = context @@ -766,23 +763,6 @@ impl Nac3 { main.link_in_module(other).map_err(|err| CompileError::new_err(err.to_string()))?; } - let builder = context.create_builder(); - let modinit_return = main - .get_function("__modinit__") - .unwrap() - .get_last_basic_block() - .unwrap() - .get_terminator() - .unwrap(); - builder.position_before(&modinit_return); - builder - .build_call( - main.get_function("attributes_writeback").unwrap(), - &[], - "attributes_writeback", - ) - .unwrap(); - main.link_in_module(irrt).map_err(|err| CompileError::new_err(err.to_string()))?; let mut function_iter = main.get_first_function();