use std::{ collections::{hash_map::DefaultHasher, HashMap}, hash::{Hash, Hasher}, iter::once, mem, sync::Arc, }; use itertools::Itertools; use pyo3::{ types::{PyDict, PyList}, PyObject, PyResult, Python, }; use nac3core::{ codegen::{ classes::{ ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType, NDArrayValue, ProxyType, ProxyValue, RangeValue, UntypedArrayLikeAccessor, }, expr::{destructure_range, gen_call}, irrt::call_ndarray_calc_size, llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with}, CodeGenContext, CodeGenerator, }, inkwell::{ context::Context, module::Linkage, types::{BasicType, IntType}, values::{BasicValueEnum, IntValue, PointerValue, StructValue}, AddressSpace, IntPredicate, OptimizationLevel, }, nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}, symbol_resolver::ValueEnum, toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, GenCall}, typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap}, }; use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; /// The parallelism mode within a block. #[derive(Copy, Clone, Eq, PartialEq)] enum ParallelMode { /// No parallelism is currently registered for this context. None, /// Legacy (or shallow) parallelism. Default before NAC3. /// /// Each statement within the `with` block is treated as statements to be executed in parallel. Legacy, /// Deep parallelism. Default since NAC3. /// /// Each function call within the `with` block (except those within a nested `sequential` block) /// are treated to be executed in parallel. Deep, } pub struct ArtiqCodeGenerator<'a> { name: String, /// The size of a `size_t` variable in bits. size_t: u32, /// Monotonic counter for naming `start`/`stop` variables used by `with parallel` blocks. name_counter: u32, /// Variable for tracking the start of a `with parallel` block. start: Option>>, /// Variable for tracking the end of a `with parallel` block. end: Option>>, timeline: &'a (dyn TimeFns + Sync), /// The [`ParallelMode`] of the current parallel context. /// /// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel` /// statement, which is used to determine when and how the timeline should be updated. parallel_mode: ParallelMode, } impl<'a> ArtiqCodeGenerator<'a> { pub fn new( name: String, size_t: u32, timeline: &'a (dyn TimeFns + Sync), ) -> ArtiqCodeGenerator<'a> { assert!(size_t == 32 || size_t == 64); ArtiqCodeGenerator { name, size_t, name_counter: 0, start: None, end: None, timeline, parallel_mode: ParallelMode::None, } } /// If the generator is currently in a direct-`parallel` block context, emits IR that resets the /// position of the timeline to the initial timeline position before entering the `parallel` /// block. /// /// Direct-`parallel` block context refers to when the generator is generating statements whose /// closest parent `with` statement is a `with parallel` block. fn timeline_reset_start(&mut self, ctx: &mut CodeGenContext<'_, '_>) -> Result<(), String> { if let Some(start) = self.start.clone() { let start_val = self.gen_expr(ctx, &start)?.unwrap().to_basic_value_enum( ctx, self, start.custom.unwrap(), )?; self.timeline.emit_at_mu(ctx, start_val); } Ok(()) } /// If the generator is currently in a `parallel` block context, emits IR that updates the /// maximum end position of the `parallel` block as specified by the timeline `end` value. /// /// In general the `end` parameter should be set to `self.end` for updating the maximum end /// position for the current `parallel` block. Other values can be passed in to update the /// maximum end position for other `parallel` blocks. /// /// `parallel`-block context refers to when the generator is generating statements within a /// (possibly indirect) `parallel` block. /// /// * `store_name` - The LLVM value name for the pointer to `end`. `.addr` will be appended to /// the end of the provided value name. fn timeline_update_end_max( &mut self, ctx: &mut CodeGenContext<'_, '_>, end: Option>>, store_name: Option<&str>, ) -> Result<(), String> { if let Some(end) = end { let old_end = self.gen_expr(ctx, &end)?.unwrap().to_basic_value_enum( ctx, self, end.custom.unwrap(), )?; let now = self.timeline.emit_now_mu(ctx); let max = call_int_smax(ctx, old_end.into_int_value(), now.into_int_value(), Some("smax")); let end_store = self .gen_store_target( ctx, &end, store_name.map(|name| format!("{name}.addr")).as_deref(), )? .unwrap(); ctx.builder.build_store(end_store, max).unwrap(); } Ok(()) } } impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { fn get_name(&self) -> &str { &self.name } fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> { if self.size_t == 32 { ctx.i32_type() } else { ctx.i64_type() } } fn gen_block<'ctx, 'a, 'c, I: Iterator>>>( &mut self, ctx: &mut CodeGenContext<'ctx, 'a>, stmts: I, ) -> Result<(), String> where Self: Sized, { // Legacy parallel emits timeline end-update/timeline-reset after each top-level statement // in the parallel block if self.parallel_mode == ParallelMode::Legacy { for stmt in stmts { self.gen_stmt(ctx, stmt)?; if ctx.is_terminated() { break; } self.timeline_update_end_max(ctx, self.end.clone(), Some("end"))?; self.timeline_reset_start(ctx)?; } Ok(()) } else { gen_block(self, ctx, stmts) } } fn gen_call<'ctx>( &mut self, ctx: &mut CodeGenContext<'ctx, '_>, obj: Option<(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), params: Vec<(Option, ValueEnum<'ctx>)>, ) -> Result>, String> { let result = gen_call(self, ctx, obj, fun, params)?; // Deep parallel emits timeline end-update/timeline-reset after each function call if self.parallel_mode == ParallelMode::Deep { self.timeline_update_end_max(ctx, self.end.clone(), Some("end"))?; self.timeline_reset_start(ctx)?; } Ok(result) } fn gen_with( &mut self, ctx: &mut CodeGenContext<'_, '_>, stmt: &Stmt>, ) -> Result<(), String> { let StmtKind::With { items, body, .. } = &stmt.node else { unreachable!() }; if items.len() == 1 && items[0].optional_vars.is_none() { let item = &items[0]; // Behavior of parallel and sequential: // Each function call (indirectly, can be inside a sequential block) within a parallel // block will update the end variable to the maximum now_mu in the block. // Each function call directly inside a parallel block will reset the timeline after // execution. A parallel block within a sequential block (or not within any block) will // set the timeline to the max now_mu within the block (and the outer max now_mu will also // be updated). // // Implementation: We track the start and end separately. // - If there is a start variable, it indicates that we are directly inside a // parallel block and we have to reset the timeline after every function call. // - If there is a end variable, it indicates that we are (indirectly) inside a // parallel block, and we should update the max end value. if let ExprKind::Name { id, ctx: name_ctx } = &item.context_expr.node { if id == &"parallel".into() || id == &"legacy_parallel".into() { let old_start = self.start.take(); let old_end = self.end.take(); let old_parallel_mode = self.parallel_mode; let now = if let Some(old_start) = &old_start { self.gen_expr(ctx, old_start)?.unwrap().to_basic_value_enum( ctx, self, old_start.custom.unwrap(), )? } else { self.timeline.emit_now_mu(ctx) }; // Emulate variable allocation, as we need to use the CodeGenContext // HashMap to store our variable due to lifetime limitation // Note: we should be able to store variables directly if generic // associative type is used by limiting the lifetime of CodeGenerator to // the LLVM Context. // The name is guaranteed to be unique as users cannot use this as variable // name. self.start = old_start.clone().map_or_else( || { let start = format!("with-{}-start", self.name_counter).into(); let start_expr = Located { // location does not matter at this point location: stmt.location, node: ExprKind::Name { id: start, ctx: *name_ctx }, custom: Some(ctx.primitives.int64), }; let start = self .gen_store_target(ctx, &start_expr, Some("start.addr"))? .unwrap(); ctx.builder.build_store(start, now).unwrap(); Ok(Some(start_expr)) as Result<_, String> }, |v| Ok(Some(v)), )?; let end = format!("with-{}-end", self.name_counter).into(); let end_expr = Located { // location does not matter at this point location: stmt.location, node: ExprKind::Name { id: end, ctx: *name_ctx }, custom: Some(ctx.primitives.int64), }; let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap(); ctx.builder.build_store(end, now).unwrap(); self.end = Some(end_expr); self.name_counter += 1; self.parallel_mode = match id.to_string().as_str() { "parallel" => ParallelMode::Deep, "legacy_parallel" => ParallelMode::Legacy, _ => unreachable!(), }; self.gen_block(ctx, body.iter())?; let current = ctx.builder.get_insert_block().unwrap(); // if the current block is terminated, move before the terminator // we want to set the timeline before reaching the terminator // TODO: This may be unsound if there are multiple exit paths in the // block... e.g. // if ...: // return // Perhaps we can fix this by using actual with block? let reset_position = if let Some(terminator) = current.get_terminator() { ctx.builder.position_before(&terminator); true } else { false }; // set duration let end_expr = self.end.take().unwrap(); let end_val = self.gen_expr(ctx, &end_expr)?.unwrap().to_basic_value_enum( ctx, self, end_expr.custom.unwrap(), )?; // inside a sequential block if old_start.is_none() { self.timeline.emit_at_mu(ctx, end_val); } // inside a parallel block, should update the outer max now_mu self.timeline_update_end_max(ctx, old_end.clone(), Some("outer.end"))?; self.parallel_mode = old_parallel_mode; self.end = old_end; self.start = old_start; if reset_position { ctx.builder.position_at_end(current); } return Ok(()); } else if id == &"sequential".into() { // For deep parallel, temporarily take away start to avoid function calls in // the block from resetting the timeline. // This does not affect legacy parallel, as the timeline will be reset after // this block finishes execution. let start = self.start.take(); self.gen_block(ctx, body.iter())?; self.start = start; // Reset the timeline when we are exiting the sequential block // Legacy parallel does not need this, since it will be reset after codegen // for this statement is completed if self.parallel_mode == ParallelMode::Deep { self.timeline_reset_start(ctx)?; } return Ok(()); } } } // not parallel/sequential gen_with(self, ctx, stmt) } } fn gen_rpc_tag( ctx: &mut CodeGenContext<'_, '_>, ty: Type, buffer: &mut Vec, ) -> Result<(), String> { use nac3core::typecheck::typedef::TypeEnum::*; let int32 = ctx.primitives.int32; let int64 = ctx.primitives.int64; let float = ctx.primitives.float; let bool = ctx.primitives.bool; let str = ctx.primitives.str; let none = ctx.primitives.none; if ctx.unifier.unioned(ty, int32) { buffer.push(b'i'); } else if ctx.unifier.unioned(ty, int64) { buffer.push(b'I'); } else if ctx.unifier.unioned(ty, float) { buffer.push(b'f'); } else if ctx.unifier.unioned(ty, bool) { buffer.push(b'b'); } else if ctx.unifier.unioned(ty, str) { buffer.push(b's'); } else if ctx.unifier.unioned(ty, none) { buffer.push(b'n'); } else { let ty_enum = ctx.unifier.get_ty(ty); match &*ty_enum { TTuple { ty, is_vararg_ctx: false } => { buffer.push(b't'); buffer.push(ty.len() as u8); for ty in ty { gen_rpc_tag(ctx, *ty, buffer)?; } } TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { let ty = iter_type_vars(params).next().unwrap().ty; buffer.push(b'l'); gen_rpc_tag(ctx, ty, buffer)?; } TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let ndarray_ndims = if let TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndarray_ndims) { if values.len() != 1 { return Err(format!("NDArray types with multiple literal bounds for ndims is not supported: {}", ctx.unifier.stringify(ty))); } let value = values[0].clone(); u64::try_from(value.clone()) .map_err(|()| format!("Expected u64 for ndarray.ndims, got {value}"))? } else { unreachable!() }; assert!( (0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims), "Only NDArrays of sizes between 0 and 255 can be RPCed" ); buffer.push(b'a'); buffer.push((ndarray_ndims & 0xFF) as u8); gen_rpc_tag(ctx, ndarray_dtype, buffer)?; } _ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))), } } Ok(()) } /// Formats an RPC argument to conform to the expected format required by `send_value`. /// /// See `artiq/firmware/libproto_artiq/rpc_proto.rs` for the expected format. fn format_rpc_arg<'ctx>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, '_>, (arg, arg_ty, arg_idx): (BasicValueEnum<'ctx>, Type, usize), ) -> PointerValue<'ctx> { let llvm_i8 = ctx.ctx.i8_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let arg_slot = match &*ctx.unifier.get_ty_immutable(arg_ty) { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { // NAC3: NDArray = { usize, usize*, T* } // libproto_artiq: NDArray = [data[..], dim_sz[..]] let llvm_i1 = ctx.ctx.bool_type(); let llvm_usize = generator.get_size_type(ctx.ctx); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty); let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None); let llvm_usize_sizeof = ctx .builder .build_int_truncate_or_bit_cast(llvm_arg_ty.size_type().size_of(), llvm_usize, "") .unwrap(); let llvm_pdata_sizeof = ctx .builder .build_int_truncate_or_bit_cast( llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(), llvm_usize, "", ) .unwrap(); let dims_buf_sz = ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); let buffer_size = ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap(); let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap(); let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg")); call_memcpy_generic( ctx, buffer.base_ptr(ctx, generator), llvm_arg.ptr_to_data(ctx), llvm_pdata_sizeof, llvm_i1.const_zero(), ); let pbuffer_dims_begin = unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; call_memcpy_generic( ctx, pbuffer_dims_begin, llvm_arg.dim_sizes().base_ptr(ctx, generator), dims_buf_sz, llvm_i1.const_zero(), ); buffer.base_ptr(ctx, generator) } _ => { let arg_slot = generator .gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{arg_idx}"))) .unwrap(); ctx.builder.build_store(arg_slot, arg).unwrap(); ctx.builder .build_bit_cast(arg_slot, llvm_pi8, "rpc.arg") .map(BasicValueEnum::into_pointer_value) .unwrap() } }; debug_assert_eq!(arg_slot.get_type(), llvm_pi8); arg_slot } /// Formats an RPC return value to conform to the expected format required by NAC3. fn format_rpc_ret<'ctx>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, '_>, ret_ty: Type, ) -> Option> { // -- receive value: // T result = { // void *ret_ptr = alloca(sizeof(T)); // void *ptr = ret_ptr; // loop: int size = rpc_recv(ptr); // // Non-zero: Provide `size` bytes of extra storage for variable-length data. // if(size) { ptr = alloca(size); goto loop; } // else *(T*)ret_ptr // } let llvm_i8 = ctx.ctx.i8_type(); let llvm_i32 = ctx.ctx.i32_type(); let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| { ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None) }); if ctx.unifier.unioned(ret_ty, ctx.primitives.none) { ctx.build_call_or_invoke(rpc_recv, &[llvm_pi8.const_null().into()], "rpc_recv"); return None; } let prehead_bb = ctx.builder.get_insert_block().unwrap(); let current_function = prehead_bb.get_parent().unwrap(); let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head"); let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue"); let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail"); let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty); let result = match &*ctx.unifier.get_ty_immutable(ret_ty) { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let llvm_i1 = ctx.ctx.bool_type(); let llvm_usize = generator.get_size_type(ctx.ctx); // Round `val` up to its modulo `power_of_two` let round_up = |ctx: &mut CodeGenContext<'ctx, '_>, val: IntValue<'ctx>, power_of_two: IntValue<'ctx>| { debug_assert_eq!( val.get_type().get_bit_width(), power_of_two.get_type().get_bit_width() ); let llvm_val_t = val.get_type(); let max_rem = ctx .builder .build_int_sub(power_of_two, llvm_val_t.const_int(1, false), "") .unwrap(); ctx.builder .build_and( ctx.builder.build_int_add(val, max_rem, "").unwrap(), ctx.builder.build_not(max_rem, "").unwrap(), "", ) .unwrap() }; // Setup types let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty); // Allocate the resulting ndarray // A condition after format_rpc_ret ensures this will not be popped this off. let ndarray = llvm_ret_ty.new_value(generator, ctx, Some("rpc.result")); // Setup ndims let ndims = if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) { assert_eq!(values.len(), 1); u64::try_from(values[0].clone()).unwrap() } else { unreachable!(); }; // Set `ndarray.ndims` ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false)); // Allocate `ndarray.shape` [size_t; ndims] ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx)); /* ndarray now: - .ndims: initialized - .shape: allocated but uninitialized .shape - .data: uninitialized */ let llvm_usize_sizeof = ctx .builder .build_int_truncate_or_bit_cast(llvm_usize.size_of(), llvm_usize, "") .unwrap(); let llvm_pdata_sizeof = ctx .builder .build_int_truncate_or_bit_cast( llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(), llvm_usize, "", ) .unwrap(); let llvm_elem_sizeof = ctx .builder .build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "") .unwrap(); // Allocates a buffer for the initial RPC'ed object, which is guaranteed to be // (4 + 4 * ndims) bytes with 8-byte alignment let sizeof_dims = ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); let unaligned_buffer_size = ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap(); let buffer_size = round_up(ctx, unaligned_buffer_size, llvm_usize.const_int(8, false)); let stackptr = call_stacksave(ctx, None); // Just to be absolutely sure, alloca in [i8 x 8] slices to force 8-byte alignment let buffer = ctx .builder .build_array_alloca( llvm_i8_8, ctx.builder .build_int_unsigned_div(buffer_size, llvm_usize.const_int(8, false), "") .unwrap(), "rpc.buffer", ) .unwrap(); let buffer = ctx .builder .build_bit_cast(buffer, llvm_pi8, "") .map(BasicValueEnum::into_pointer_value) .unwrap(); let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None); // The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape] // // The returned value is the number of bytes for `ndarray.data`. let ndarray_nbytes = ctx .build_call_or_invoke( rpc_recv, &[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]. NOTE: We are allocated [size_t; ndims]. "rpc.size.next", ) .map(BasicValueEnum::into_int_value) .unwrap(); // debug_assert(ndarray_nbytes > 0) if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { ctx.make_assert( generator, ctx.builder .build_int_compare( IntPredicate::UGT, ndarray_nbytes, ndarray_nbytes.get_type().const_zero(), "", ) .unwrap(), "0:AssertionError", "Unexpected RPC termination for ndarray - Expected data buffer next", [None, None, None], ctx.current_loc, ); } // Copy shape from the buffer to `ndarray.shape`. let pbuffer_dims = unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; call_memcpy_generic( ctx, ndarray.dim_sizes().base_ptr(ctx, generator), pbuffer_dims, sizeof_dims, llvm_i1.const_zero(), ); // Restore stack from before allocation of buffer call_stackrestore(ctx, stackptr); // Allocate `ndarray.data`. // `ndarray.shape` must be initialized beforehand in this implementation // (for ndarray.create_data() to know how many elements to allocate) let num_elements = call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None)); // debug_assert(nelems * sizeof(T) >= ndarray_nbytes) if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { let sizeof_data = ctx.builder.build_int_mul(num_elements, llvm_elem_sizeof, "").unwrap(); ctx.make_assert( generator, ctx.builder.build_int_compare(IntPredicate::UGE, sizeof_data, ndarray_nbytes, "", ).unwrap(), "0:AssertionError", "Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes", [Some(sizeof_data), Some(ndarray_nbytes), None], ctx.current_loc, ); } ndarray.create_data(ctx, llvm_elem_ty, num_elements); let ndarray_data = ndarray.data().base_ptr(ctx, generator); let ndarray_data_i8 = ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap(); // NOTE: Currently on `prehead_bb` ctx.builder.build_unconditional_branch(head_bb).unwrap(); // Inserting into `head_bb`. Do `rpc_recv` for `data` recursively. ctx.builder.position_at_end(head_bb); let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap(); phi.add_incoming(&[(&ndarray_data_i8, prehead_bb)]); let alloc_size = ctx .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") .map(BasicValueEnum::into_int_value) .unwrap(); let is_done = ctx .builder .build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done") .unwrap(); ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap(); ctx.builder.position_at_end(alloc_bb); // Align the allocation to sizeof(T) let alloc_size = round_up(ctx, alloc_size, llvm_elem_sizeof); let alloc_ptr = ctx .builder .build_array_alloca( llvm_elem_ty, ctx.builder.build_int_unsigned_div(alloc_size, llvm_elem_sizeof, "").unwrap(), "rpc.alloc", ) .unwrap(); let alloc_ptr = ctx.builder.build_pointer_cast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap(); phi.add_incoming(&[(&alloc_ptr, alloc_bb)]); ctx.builder.build_unconditional_branch(head_bb).unwrap(); ctx.builder.position_at_end(tail_bb); ndarray.as_base_value().into() } _ => { let slot = ctx.builder.build_alloca(llvm_ret_ty, "rpc.ret.slot").unwrap(); let slotgen = ctx.builder.build_bit_cast(slot, llvm_pi8, "rpc.ret.ptr").unwrap(); ctx.builder.build_unconditional_branch(head_bb).unwrap(); ctx.builder.position_at_end(head_bb); let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap(); phi.add_incoming(&[(&slotgen, prehead_bb)]); let alloc_size = ctx .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") .unwrap() .into_int_value(); let is_done = ctx .builder .build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done") .unwrap(); ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap(); ctx.builder.position_at_end(alloc_bb); let alloc_ptr = ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap(); let alloc_ptr = ctx.builder.build_bit_cast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap(); phi.add_incoming(&[(&alloc_ptr, alloc_bb)]); ctx.builder.build_unconditional_branch(head_bb).unwrap(); ctx.builder.position_at_end(tail_bb); ctx.builder.build_load(slot, "rpc.result").unwrap() } }; Some(result) } fn rpc_codegen_callback_fn<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, obj: Option<(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), args: Vec<(Option, ValueEnum<'ctx>)>, generator: &mut dyn CodeGenerator, is_async: bool, ) -> Result>, String> { let int8 = ctx.ctx.i8_type(); let int32 = ctx.ctx.i32_type(); let size_type = generator.get_size_type(ctx.ctx); let ptr_type = int8.ptr_type(AddressSpace::default()); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); let service_id = int32.const_int(fun.1 .0 as u64, false); // -- setup rpc tags let mut tag = Vec::new(); if obj.is_some() { tag.push(b'O'); } for arg in &fun.0.args { gen_rpc_tag(ctx, arg.ty, &mut tag)?; } tag.push(b':'); gen_rpc_tag(ctx, fun.0.ret, &mut tag)?; let mut hasher = DefaultHasher::new(); tag.hash(&mut hasher); let hash = format!("{}", hasher.finish()); let tag_ptr = ctx .module .get_global(hash.as_str()) .unwrap_or_else(|| { let tag_arr_ptr = ctx.module.add_global( int8.array_type(tag.len() as u32), None, format!("tagptr{}", fun.1 .0).as_str(), ); tag_arr_ptr.set_initializer(&int8.const_array( &tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::>(), )); tag_arr_ptr.set_linkage(Linkage::Private); let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash); tag_ptr.set_linkage(Linkage::Private); tag_ptr.set_initializer(&ctx.ctx.const_struct( &[ tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(), size_type.const_int(tag.len() as u64, false).into(), ], false, )); tag_ptr }) .as_pointer_value(); let arg_length = args.len() + usize::from(obj.is_some()); let stackptr = call_stacksave(ctx, Some("rpc.stack")); let args_ptr = ctx .builder .build_array_alloca( ptr_type, ctx.ctx.i32_type().const_int(arg_length as u64, false), "argptr", ) .unwrap(); // -- rpc args handling let mut keys = fun.0.args.clone(); let mut mapping = HashMap::new(); for (key, value) in args { mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value); } // default value handling for k in keys { mapping .insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into()); } // reorder the parameters let mut real_params = fun .0 .args .iter() .map(|arg| { mapping .remove(&arg.name) .unwrap() .to_basic_value_enum(ctx, generator, arg.ty) .map(|llvm_val| (llvm_val, arg.ty)) }) .collect::, _>>()?; if let Some(obj) = obj { if let ValueEnum::Static(obj_val) = obj.1 { real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0)); } else { // should be an error here... panic!("only host object is allowed"); } } for (i, (arg, arg_ty)) in real_params.iter().enumerate() { let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i)); let arg_ptr = unsafe { ctx.builder.build_gep( args_ptr, &[int32.const_int(i as u64, false)], &format!("rpc.arg{i}"), ) } .unwrap(); ctx.builder.build_store(arg_ptr, arg_slot).unwrap(); } // call if is_async { let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| { ctx.module.add_function( "rpc_send_async", ctx.ctx.void_type().fn_type( &[ int32.into(), tag_ptr_type.ptr_type(AddressSpace::default()).into(), ptr_type.ptr_type(AddressSpace::default()).into(), ], false, ), None, ) }); ctx.builder .build_call( rpc_send_async, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send", ) .unwrap(); } else { let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| { ctx.module.add_function( "rpc_send", ctx.ctx.void_type().fn_type( &[ int32.into(), tag_ptr_type.ptr_type(AddressSpace::default()).into(), ptr_type.ptr_type(AddressSpace::default()).into(), ], false, ), None, ) }); ctx.builder .build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send") .unwrap(); } // reclaim stack space used by arguments call_stackrestore(ctx, stackptr); if is_async { // async RPCs do not return any values Ok(None) } else { let result = format_rpc_ret(generator, ctx, fun.0.ret); if !result.is_some_and(|res| res.get_type().is_pointer_type()) { // An RPC returning an NDArray would not touch here. call_stackrestore(ctx, stackptr); } Ok(result) } } pub fn attributes_writeback( ctx: &mut CodeGenContext<'_, '_>, generator: &mut dyn CodeGenerator, inner_resolver: &InnerResolver, host_attributes: &PyObject, ) -> Result<(), String> { Python::with_gil(|py| -> PyResult> { let host_attributes: &PyList = host_attributes.downcast(py)?; let top_levels = ctx.top_level.definitions.read(); let globals = inner_resolver.global_value_ids.read(); let int32 = ctx.ctx.i32_type(); let zero = int32.const_zero(); let mut values = Vec::new(); let mut scratch_buffer = Vec::new(); for val in (*globals).values() { let val = val.as_ref(py); let ty = inner_resolver.get_obj_type( py, val, &mut ctx.unifier, &top_levels, &ctx.primitives, )?; if let Err(ty) = ty { return Ok(Err(ty)); } let ty = ty.unwrap(); match &*ctx.unifier.get_ty(ty) { TypeEnum::TObj { fields, obj_id, .. } if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => { // we only care about primitive attributes // for non-primitive attributes, they should be in another global let mut attributes = Vec::new(); let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(); for (name, (field_ty, is_mutable)) in fields { if !is_mutable { continue; } if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { attributes.push(name.to_string()); let (index, _) = ctx.get_attr_index(ty, *name); values.push(( *field_ty, ctx.build_gep_and_load( obj.into_pointer_value(), &[zero, int32.const_int(index as u64, false)], None, ), )); } } if !attributes.is_empty() { let pydict = PyDict::new(py); pydict.set_item("obj", val)?; pydict.set_item("fields", attributes)?; host_attributes.append(pydict)?; } } TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { let elem_ty = iter_type_vars(params).next().unwrap().ty; if gen_rpc_tag(ctx, elem_ty, &mut scratch_buffer).is_ok() { let pydict = PyDict::new(py); pydict.set_item("obj", val)?; host_attributes.append(pydict)?; values.push(( ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(), )); } } _ => {} } } let fun = FunSignature { args: values .iter() .enumerate() .map(|(i, (ty, _))| FuncArg { name: i.to_string().into(), ty: *ty, default_value: None, is_vararg: false, }) .collect(), ret: ctx.primitives.none, vars: VarMap::default(), }; let args: Vec<_> = values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect(); if let Err(e) = rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator, false) { return Ok(Err(e)); } Ok(Ok(())) }) .unwrap()?; Ok(()) } pub fn rpc_codegen_callback(is_async: bool) -> Arc { Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| { rpc_codegen_callback_fn(ctx, obj, fun, args, generator, is_async) }))) } /// Returns the `fprintf` format constant for the given [`llvm_int_t`][`IntType`] on a platform with /// [`llvm_usize`] as its native word size. /// /// Note that, similar to format constants in ``, these constants need to be prepended /// with `%`. #[must_use] fn get_fprintf_format_constant<'ctx>( llvm_usize: IntType<'ctx>, llvm_int_t: IntType<'ctx>, is_unsigned: bool, ) -> String { debug_assert!(matches!(llvm_usize.get_bit_width(), 8 | 16 | 32 | 64)); let conv_spec = if is_unsigned { 'u' } else { 'd' }; // https://en.cppreference.com/w/c/language/arithmetic_types // Note that NAC3 does **not** support LP32 and LLP64 configurations match llvm_int_t.get_bit_width() { 8 => format!("hh{conv_spec}"), 16 => format!("h{conv_spec}"), 32 => conv_spec.to_string(), 64 => format!("{}{conv_spec}", if llvm_usize.get_bit_width() == 64 { "l" } else { "ll" }), _ => todo!( "Not yet implemented for i{} on {}-bit platform", llvm_int_t.get_bit_width(), llvm_usize.get_bit_width() ), } } /// Prints one or more `values` to `core_log` or `rtio_log`. /// /// * `separator` - The separator between multiple values. /// * `suffix` - String to terminate the printed string, if any. /// * `as_repr` - Whether the `repr()` output of values instead of `str()`. /// * `as_rtio` - Whether to print to `rtio_log` instead of `core_log`. fn polymorphic_print<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut dyn CodeGenerator, values: &[(Type, ValueEnum<'ctx>)], separator: &str, suffix: Option<&str>, as_repr: bool, as_rtio: bool, ) -> Result<(), String> { let printf = |ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut dyn CodeGenerator, fmt: String, args: Vec>| { debug_assert!(!fmt.is_empty()); debug_assert_eq!(fmt.as_bytes().last().unwrap(), &0u8); let fn_name = if as_rtio { "rtio_log" } else { "core_log" }; let print_fn = ctx.module.get_function(fn_name).unwrap_or_else(|| { let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); let fn_t = if as_rtio { let llvm_void = ctx.ctx.void_type(); llvm_void.fn_type(&[llvm_pi8.into()], true) } else { let llvm_i32 = ctx.ctx.i32_type(); llvm_i32.fn_type(&[llvm_pi8.into()], true) }; ctx.module.add_function(fn_name, fn_t, None) }); let fmt = ctx.gen_string(generator, fmt); let fmt = unsafe { fmt.get_field_at_index_unchecked(0) }.into_pointer_value(); ctx.builder .build_call( print_fn, &once(fmt.into()).chain(args).map(BasicValueEnum::into).collect_vec(), "", ) .unwrap(); }; let llvm_i32 = ctx.ctx.i32_type(); let llvm_i64 = ctx.ctx.i64_type(); let llvm_usize = generator.get_size_type(ctx.ctx); let suffix = suffix.unwrap_or_default(); let mut fmt = String::new(); let mut args = Vec::new(); let flush = |ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut dyn CodeGenerator, fmt: &mut String, args: &mut Vec>| { if !fmt.is_empty() { fmt.push('\0'); printf(ctx, generator, mem::take(fmt), mem::take(args)); } }; for (ty, value) in values { let ty = *ty; let value = value.clone().to_basic_value_enum(ctx, generator, ty).unwrap(); if !fmt.is_empty() { fmt.push_str(separator); } match &*ctx.unifier.get_ty_immutable(ty) { TypeEnum::TTuple { ty: tys, is_vararg_ctx: false } => { let pvalue = { let pvalue = generator.gen_var_alloc(ctx, value.get_type(), None).unwrap(); ctx.builder.build_store(pvalue, value).unwrap(); pvalue }; fmt.push('('); flush(ctx, generator, &mut fmt, &mut args); let tuple_vals = tys .iter() .enumerate() .map(|(i, ty)| { (*ty, { let pfield = ctx.builder.build_struct_gep(pvalue, i as u32, "").unwrap(); ValueEnum::from(ctx.builder.build_load(pfield, "").unwrap()) }) }) .collect_vec(); polymorphic_print(ctx, generator, &tuple_vals, ", ", None, true, as_rtio)?; if tuple_vals.len() == 1 { fmt.push_str(",)"); } else { fmt.push(')'); } } TypeEnum::TFunc { .. } => todo!(), TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::None.id() => { fmt.push_str("None"); } TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Bool.id() => { fmt.push_str("%.*s"); let true_str = ctx.gen_string(generator, "True"); let true_data = unsafe { true_str.get_field_at_index_unchecked(0) }.into_pointer_value(); let true_len = unsafe { true_str.get_field_at_index_unchecked(1) }.into_int_value(); let false_str = ctx.gen_string(generator, "False"); let false_data = unsafe { false_str.get_field_at_index_unchecked(0) }.into_pointer_value(); let false_len = unsafe { false_str.get_field_at_index_unchecked(1) }.into_int_value(); let bool_val = generator.bool_to_i1(ctx, value.into_int_value()); args.extend([ ctx.builder.build_select(bool_val, true_len, false_len, "").unwrap(), ctx.builder.build_select(bool_val, true_data, false_data, "").unwrap(), ]); } TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Int32.id() || *obj_id == PrimDef::Int64.id() || *obj_id == PrimDef::UInt32.id() || *obj_id == PrimDef::UInt64.id() => { let is_unsigned = *obj_id == PrimDef::UInt32.id() || *obj_id == PrimDef::UInt64.id(); let llvm_int_t = value.get_type().into_int_type(); debug_assert!(matches!(llvm_usize.get_bit_width(), 32 | 64)); debug_assert!(matches!(llvm_int_t.get_bit_width(), 32 | 64)); let fmt_spec = format!( "%{}", get_fprintf_format_constant(llvm_usize, llvm_int_t, is_unsigned) ); fmt.push_str(fmt_spec.as_str()); args.push(value); } TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Float.id() => { fmt.push_str("%g"); args.push(value); } TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Str.id() => { if as_repr { fmt.push_str("\"%.*s\""); } else { fmt.push_str("%.*s"); } let str = value.into_struct_value(); let str_data = unsafe { str.get_field_at_index_unchecked(0) }.into_pointer_value(); let str_len = unsafe { str.get_field_at_index_unchecked(1) }.into_int_value(); args.extend(&[str_len.into(), str_data.into()]); } TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { let elem_ty = *params.iter().next().unwrap().1; fmt.push('['); flush(ctx, generator, &mut fmt, &mut args); let val = ListValue::from_ptr_val(value.into_pointer_value(), llvm_usize, None); let len = val.load_size(ctx, None); let last = ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap(); gen_for_callback_incrementing( generator, ctx, None, llvm_usize.const_zero(), (len, false), |generator, ctx, _, i| { let elem = unsafe { val.data().get_unchecked(ctx, generator, &i, None) }; polymorphic_print( ctx, generator, &[(elem_ty, elem.into())], "", None, true, as_rtio, )?; gen_if_callback( generator, ctx, |_, ctx| { Ok(ctx .builder .build_int_compare(IntPredicate::ULT, i, last, "") .unwrap()) }, |generator, ctx| { printf(ctx, generator, ", \0".into(), Vec::default()); Ok(()) }, |_, _| Ok(()), )?; Ok(()) }, llvm_usize.const_int(1, false), )?; fmt.push(']'); flush(ctx, generator, &mut fmt, &mut args); } TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); fmt.push_str("array(["); flush(ctx, generator, &mut fmt, &mut args); let val = NDArrayValue::from_ptr_val(value.into_pointer_value(), llvm_usize, None); let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None)); let last = ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap(); gen_for_callback_incrementing( generator, ctx, None, llvm_usize.const_zero(), (len, false), |generator, ctx, _, i| { let elem = unsafe { val.data().get_unchecked(ctx, generator, &i, None) }; polymorphic_print( ctx, generator, &[(elem_ty, elem.into())], "", None, true, as_rtio, )?; gen_if_callback( generator, ctx, |_, ctx| { Ok(ctx .builder .build_int_compare(IntPredicate::ULT, i, last, "") .unwrap()) }, |generator, ctx| { printf(ctx, generator, ", \0".into(), Vec::default()); Ok(()) }, |_, _| Ok(()), )?; Ok(()) }, llvm_usize.const_int(1, false), )?; fmt.push_str(")]"); flush(ctx, generator, &mut fmt, &mut args); } TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Range.id() => { fmt.push_str("range("); flush(ctx, generator, &mut fmt, &mut args); let val = RangeValue::from_ptr_val(value.into_pointer_value(), None); let (start, stop, step) = destructure_range(ctx, val); polymorphic_print( ctx, generator, &[ (ctx.primitives.int32, start.into()), (ctx.primitives.int32, stop.into()), (ctx.primitives.int32, step.into()), ], ", ", None, false, as_rtio, )?; fmt.push(')'); } TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::Exception.id() => { let fmt_str = format!( "%{}(%{}, %{1:}, %{1:})", get_fprintf_format_constant(llvm_usize, llvm_i32, false), get_fprintf_format_constant(llvm_usize, llvm_i64, false), ); let exn = value.into_pointer_value(); let name = ctx .build_in_bounds_gep_and_load( exn, &[llvm_i32.const_zero(), llvm_i32.const_zero()], None, ) .into_int_value(); let param0 = ctx .build_in_bounds_gep_and_load( exn, &[llvm_i32.const_zero(), llvm_i32.const_int(6, false)], None, ) .into_int_value(); let param1 = ctx .build_in_bounds_gep_and_load( exn, &[llvm_i32.const_zero(), llvm_i32.const_int(7, false)], None, ) .into_int_value(); let param2 = ctx .build_in_bounds_gep_and_load( exn, &[llvm_i32.const_zero(), llvm_i32.const_int(8, false)], None, ) .into_int_value(); fmt.push_str(fmt_str.as_str()); args.extend_from_slice(&[name.into(), param0.into(), param1.into(), param2.into()]); } _ => unreachable!( "Unsupported object type for polymorphic_print: {}", ctx.unifier.stringify(ty) ), } } fmt.push_str(suffix); flush(ctx, generator, &mut fmt, &mut args); Ok(()) } /// Invokes the `core_log` intrinsic function. pub fn call_core_log_impl<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut dyn CodeGenerator, arg: (Type, BasicValueEnum<'ctx>), ) -> Result<(), String> { let (arg_ty, arg_val) = arg; polymorphic_print(ctx, generator, &[(arg_ty, arg_val.into())], " ", Some("\n"), false, false)?; Ok(()) } /// Invokes the `rtio_log` intrinsic function. pub fn call_rtio_log_impl<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut dyn CodeGenerator, channel: StructValue<'ctx>, arg: (Type, BasicValueEnum<'ctx>), ) -> Result<(), String> { let (arg_ty, arg_val) = arg; polymorphic_print( ctx, generator, &[(ctx.primitives.str, channel.into())], " ", Some("\x1E"), false, true, )?; polymorphic_print(ctx, generator, &[(arg_ty, arg_val.into())], " ", Some("\x1D"), false, true)?; Ok(()) } /// Generates a call to `core_log`. pub fn gen_core_log<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, obj: &Option<(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), args: &[(Option, ValueEnum<'ctx>)], generator: &mut dyn CodeGenerator, ) -> Result<(), String> { assert!(obj.is_none()); assert_eq!(args.len(), 1); let value_ty = fun.0.args[0].ty; let value_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, value_ty)?; call_core_log_impl(ctx, generator, (value_ty, value_arg)) } /// Generates a call to `rtio_log`. pub fn gen_rtio_log<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, obj: &Option<(Type, ValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), args: &[(Option, ValueEnum<'ctx>)], generator: &mut dyn CodeGenerator, ) -> Result<(), String> { assert!(obj.is_none()); assert_eq!(args.len(), 2); let channel_ty = fun.0.args[0].ty; assert!(ctx.unifier.unioned(channel_ty, ctx.primitives.str)); let channel_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, channel_ty)?.into_struct_value(); let value_ty = fun.0.args[1].ty; let value_arg = args[1].1.clone().to_basic_value_enum(ctx, generator, value_ty)?; call_rtio_log_impl(ctx, generator, channel_arg, (value_ty, value_arg)) }