diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs
index d5a53b164..8c7809b64 100644
--- a/nac3artiq/src/codegen.rs
+++ b/nac3artiq/src/codegen.rs
@@ -2,7 +2,7 @@ use nac3core::{
     codegen::{
         classes::{
            ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType,
-            NDArrayValue, RangeValue, UntypedArrayLikeAccessor,
+            NDArrayValue, ProxyType, ProxyValue, RangeValue, UntypedArrayLikeAccessor,
         },
         expr::{destructure_range, gen_call},
         irrt::call_ndarray_calc_size,
@@ -22,7 +22,7 @@ use inkwell::{
     module::Linkage,
     types::{BasicType, IntType},
     values::{BasicValueEnum, PointerValue, StructValue},
-    AddressSpace, IntPredicate,
+    AddressSpace, IntPredicate, OptimizationLevel,
 };
 
 use pyo3::{
@@ -32,6 +32,7 @@ use pyo3::{
 
 use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
 
+use inkwell::values::IntValue;
 use itertools::Itertools;
 use std::{
     collections::{hash_map::DefaultHasher, HashMap},
@@ -486,13 +487,10 @@ fn format_rpc_arg<'ctx>(
     let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap();
     let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg"));
 
-    let ppdata = generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap();
-    ctx.builder.build_store(ppdata, llvm_arg.data().base_ptr(ctx, generator)).unwrap();
-
     call_memcpy_generic(
         ctx,
         buffer.base_ptr(ctx, generator),
-        ppdata,
+        llvm_arg.ptr_to_data(ctx),
         llvm_pdata_sizeof,
         llvm_i1.const_zero(),
     );
@@ -528,6 +526,298 @@ fn format_rpc_arg<'ctx>(
     arg_slot
 }
 
+/// Formats an RPC return value to conform to the format expected by NAC3.
+fn format_rpc_ret<'ctx>(
+    generator: &mut dyn CodeGenerator,
+    ctx: &mut CodeGenContext<'ctx, '_>,
+    ret_ty: Type,
+) -> Option<BasicValueEnum<'ctx>> {
+    // -- receive value:
+    // T result = {
+    //     void *ret_ptr = alloca(sizeof(T));
+    //     void *ptr = ret_ptr;
+    //     loop: int size = rpc_recv(ptr);
+    //     // Non-zero: Provide `size` bytes of extra storage for variable-length data.
+    //     if(size) { ptr = alloca(size); goto loop; }
+    //     else *(T*)ret_ptr
+    // }
+
+    let llvm_i8 = ctx.ctx.i8_type();
+    let llvm_i32 = ctx.ctx.i32_type();
+    let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false);
+    let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
+
+    let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
+        ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None)
+    });
+
+    if ctx.unifier.unioned(ret_ty, ctx.primitives.none) {
+        ctx.build_call_or_invoke(rpc_recv, &[llvm_pi8.const_null().into()], "rpc_recv");
+        return None;
+    }
+
+    let prehead_bb = ctx.builder.get_insert_block().unwrap();
+    let current_function = prehead_bb.get_parent().unwrap();
+    let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
+    let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
+    let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
+
+    let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty);
+
+    let result = match &*ctx.unifier.get_ty_immutable(ret_ty) {
+        TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
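+            // Receiving an ndarray takes several rounds: the first `rpc_recv` call
+            // fills a stack buffer with the header `[pdata, shape]` and returns the
+            // byte size of `data`; the loop further below then keeps providing
+            // storage until `rpc_recv` reports that no more bytes are needed.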
+            let llvm_i1 = ctx.ctx.bool_type();
+            let llvm_usize = generator.get_size_type(ctx.ctx);
+
+            // Rounds `val` up to the next multiple of `power_of_two`
+            let round_up = |ctx: &mut CodeGenContext<'ctx, '_>,
+                            val: IntValue<'ctx>,
+                            power_of_two: IntValue<'ctx>| {
+                debug_assert_eq!(
+                    val.get_type().get_bit_width(),
+                    power_of_two.get_type().get_bit_width()
+                );
+
+                let llvm_val_t = val.get_type();
+
+                let max_rem = ctx
+                    .builder
+                    .build_int_sub(power_of_two, llvm_val_t.const_int(1, false), "")
+                    .unwrap();
+                ctx.builder
+                    .build_and(
+                        ctx.builder.build_int_add(val, max_rem, "").unwrap(),
+                        ctx.builder.build_not(max_rem, "").unwrap(),
+                        "",
+                    )
+                    .unwrap()
+            };
+
+            // Set up types
+            let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
+            let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
+            let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
+
+            // Allocate the resulting ndarray.
+            // The condition after `format_rpc_ret` in the caller ensures this
+            // allocation will not be popped off the stack.
+            let ndarray = llvm_ret_ty.new_value(generator, ctx, Some("rpc.result"));
+
+            // Set up ndims
+            let ndims =
+                if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) {
+                    assert_eq!(values.len(), 1);
+
+                    u64::try_from(values[0].clone()).unwrap()
+                } else {
+                    unreachable!();
+                };
+            // Set `ndarray.ndims`
+            ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
+            // Allocate `ndarray.shape` [size_t; ndims]
+            ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx));
+
+            /*
+                ndarray now:
+                - .ndims: initialized
+                - .shape: allocated but uninitialized
+                - .data: uninitialized
+            */
+
+            let llvm_usize_sizeof = ctx
+                .builder
+                .build_int_truncate_or_bit_cast(llvm_usize.size_of(), llvm_usize, "")
+                .unwrap();
+            let llvm_pdata_sizeof = ctx
+                .builder
+                .build_int_truncate_or_bit_cast(
+                    llvm_ret_ty.element_type().size_of().unwrap(),
+                    llvm_usize,
+                    "",
+                )
+                .unwrap();
+            let llvm_elem_sizeof = ctx
+                .builder
+                .build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "")
+                .unwrap();
+
+            // Allocates a buffer for the initial RPC'ed object, which is guaranteed to be
+            // (4 + 4 * ndims) bytes with 8-byte alignment
+            let sizeof_dims =
+                ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
+            let unaligned_buffer_size =
+                ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap();
+            let buffer_size = round_up(ctx, unaligned_buffer_size, llvm_usize.const_int(8, false));
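+            // `buffer_size` is now a multiple of 8, so the division by 8 below to
+            // compute the number of `[i8 x 8]` slices is exact.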
+ "rpc.size.next", + ) + .map(BasicValueEnum::into_int_value) + .unwrap(); + + // debug_assert(ndarray_nbytes > 0) + if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { + ctx.make_assert( + generator, + ctx.builder + .build_int_compare( + IntPredicate::UGT, + ndarray_nbytes, + ndarray_nbytes.get_type().const_zero(), + "", + ) + .unwrap(), + "0:AssertionError", + "Unexpected RPC termination for ndarray - Expected data buffer next", + [None, None, None], + ctx.current_loc, + ); + } + + // Copy shape from the buffer to `ndarray.shape`. + let pbuffer_dims = + unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; + + call_memcpy_generic( + ctx, + ndarray.dim_sizes().base_ptr(ctx, generator), + pbuffer_dims, + sizeof_dims, + llvm_i1.const_zero(), + ); + // Restore stack from before allocation of buffer + call_stackrestore(ctx, stackptr); + + // Allocate `ndarray.data`. + // `ndarray.shape` must be initialized beforehand in this implementation + // (for ndarray.create_data() to know how many elements to allocate) + let num_elements = + call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None)); + + // debug_assert(nelems * sizeof(T) >= ndarray_nbytes) + if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { + let sizeof_data = + ctx.builder.build_int_mul(num_elements, llvm_elem_sizeof, "").unwrap(); + + ctx.make_assert( + generator, + ctx.builder.build_int_compare(IntPredicate::UGE, + sizeof_data, + ndarray_nbytes, + "", + ).unwrap(), + "0:AssertionError", + "Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes", + [Some(sizeof_data), Some(ndarray_nbytes), None], + ctx.current_loc, + ); + } + + ndarray.create_data(ctx, llvm_elem_ty, num_elements); + + let ndarray_data = ndarray.data().base_ptr(ctx, generator); + let ndarray_data_i8 = + ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap(); + + // NOTE: Currently on `prehead_bb` + ctx.builder.build_unconditional_branch(head_bb).unwrap(); + + // Inserting into `head_bb`. Do `rpc_recv` for `data` recursively. 
+            ctx.builder.position_at_end(head_bb);
+
+            let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
+            phi.add_incoming(&[(&ndarray_data_i8, prehead_bb)]);
+
+            let alloc_size = ctx
+                .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
+                .map(BasicValueEnum::into_int_value)
+                .unwrap();
+
+            let is_done = ctx
+                .builder
+                .build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
+                .unwrap();
+            ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
+
+            ctx.builder.position_at_end(alloc_bb);
+            // Align the allocation to sizeof(T)
+            let alloc_size = round_up(ctx, alloc_size, llvm_elem_sizeof);
+            let alloc_ptr = ctx
+                .builder
+                .build_array_alloca(
+                    llvm_elem_ty,
+                    ctx.builder.build_int_unsigned_div(alloc_size, llvm_elem_sizeof, "").unwrap(),
+                    "rpc.alloc",
+                )
+                .unwrap();
+            let alloc_ptr =
+                ctx.builder.build_pointer_cast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
+            phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
+            ctx.builder.build_unconditional_branch(head_bb).unwrap();
+
+            ctx.builder.position_at_end(tail_bb);
+            ndarray.as_base_value().into()
+        }
+
+        _ => {
+            let slot = ctx.builder.build_alloca(llvm_ret_ty, "rpc.ret.slot").unwrap();
+            let slotgen = ctx.builder.build_bitcast(slot, llvm_pi8, "rpc.ret.ptr").unwrap();
+            ctx.builder.build_unconditional_branch(head_bb).unwrap();
+            ctx.builder.position_at_end(head_bb);
+
+            let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
+            phi.add_incoming(&[(&slotgen, prehead_bb)]);
+            let alloc_size = ctx
+                .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
+                .unwrap()
+                .into_int_value();
+            let is_done = ctx
+                .builder
+                .build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
+                .unwrap();
+
+            ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
+            ctx.builder.position_at_end(alloc_bb);
+
+            let alloc_ptr =
+                ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
+            let alloc_ptr =
+                ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
+            phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
+            ctx.builder.build_unconditional_branch(head_bb).unwrap();
+
+            ctx.builder.position_at_end(tail_bb);
+            ctx.builder.build_load(slot, "rpc.result").unwrap()
+        }
+    };
+
+    Some(result)
+}
+
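
Aside: the `round_up` closure in `format_rpc_ret` above is the standard power-of-two
alignment trick, spelled out as IR builder calls. A minimal plain-integer sketch of the
same computation (an illustration only, not part of the patch; the free function and
test values are made up):

```rust
/// Rounds `val` up to the next multiple of `power_of_two`.
fn round_up(val: u64, power_of_two: u64) -> u64 {
    debug_assert!(power_of_two.is_power_of_two());
    let max_rem = power_of_two - 1;
    // Adding `max_rem` overshoots into the next aligned block unless `val` is
    // already aligned; masking off the low bits snaps it back to the boundary.
    (val + max_rem) & !max_rem
}

fn main() {
    assert_eq!(round_up(0, 8), 0);
    assert_eq!(round_up(1, 8), 8);
    assert_eq!(round_up(8, 8), 8);
    assert_eq!(round_up(13, 8), 16); // e.g. a 13-byte header padded to [i8 x 8] slices
}
```
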
 fn rpc_codegen_callback_fn<'ctx>(
     ctx: &mut CodeGenContext<'ctx, '_>,
     obj: Option<(Type, ValueEnum<'ctx>)>,
@@ -663,63 +953,14 @@ fn rpc_codegen_callback_fn<'ctx>(
     // reclaim stack space used by arguments
     call_stackrestore(ctx, stackptr);
 
-    // -- receive value:
-    // T result = {
-    //     void *ret_ptr = alloca(sizeof(T));
-    //     void *ptr = ret_ptr;
-    //     loop: int size = rpc_recv(ptr);
-    //     // Non-zero: Provide `size` bytes of extra storage for variable-length data.
-    //     if(size) { ptr = alloca(size); goto loop; }
-    //     else *(T*)ret_ptr
-    // }
-    let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
-        ctx.module.add_function("rpc_recv", int32.fn_type(&[ptr_type.into()], false), None)
-    });
+    let result = format_rpc_ret(generator, ctx, fun.0.ret);
 
-    if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) {
-        ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv");
-        return Ok(None);
-    }
-
-    let prehead_bb = ctx.builder.get_insert_block().unwrap();
-    let current_function = prehead_bb.get_parent().unwrap();
-    let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
-    let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
-    let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
-
-    let ret_ty = ctx.get_llvm_abi_type(generator, fun.0.ret);
-    let need_load = !ret_ty.is_pointer_type();
-    let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot").unwrap();
-    let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr").unwrap();
-    ctx.builder.build_unconditional_branch(head_bb).unwrap();
-    ctx.builder.position_at_end(head_bb);
-
-    let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr").unwrap();
-    phi.add_incoming(&[(&slotgen, prehead_bb)]);
-    let alloc_size = ctx
-        .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
-        .unwrap()
-        .into_int_value();
-    let is_done = ctx
-        .builder
-        .build_int_compare(inkwell::IntPredicate::EQ, int32.const_zero(), alloc_size, "rpc.done")
-        .unwrap();
-
-    ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
-    ctx.builder.position_at_end(alloc_bb);
-
-    let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc").unwrap();
-    let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr").unwrap();
-    phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
-    ctx.builder.build_unconditional_branch(head_bb).unwrap();
-
-    ctx.builder.position_at_end(tail_bb);
-
-    let result = ctx.builder.build_load(slot, "rpc.result").unwrap();
-    if need_load {
+    if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
+        // An RPC returning an NDArray will not reach this point.
         call_stackrestore(ctx, stackptr);
     }
-    Ok(Some(result))
+
+    Ok(result)
 }
 
 pub fn attributes_writeback(
diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs
index 52e9cca01..9ebac518e 100644
--- a/nac3core/src/codegen/classes.rs
+++ b/nac3core/src/codegen/classes.rs
@@ -1404,7 +1404,7 @@ impl<'ctx> NDArrayValue<'ctx> {
 
     /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
     /// on the field.
-    fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
+    pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
         let llvm_i32 = ctx.ctx.i32_type();
         let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
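
For reference, the receive loop that both the removed inline lowering and the new
`format_rpc_ret` emit follows the protocol from the pseudocode comment: each `rpc_recv`
call fills the pointer it is given and returns how many bytes of extra storage the next
piece of variable-length data needs, with zero meaning the value is complete. A
host-side model of that loop in plain Rust (a sketch only; `receive` and `recv_into`
are hypothetical stand-ins for the generated IR and the runtime's `rpc_recv`):

```rust
/// Model of the emitted receive loop: keep handing `recv_into` fresh buffers
/// until it reports that no further storage is needed.
fn receive<T: Copy>(mut recv_into: impl FnMut(&mut [u8]) -> usize) -> T {
    let mut slot = vec![0u8; std::mem::size_of::<T>()]; // "rpc.ret.slot"
    let mut extra: Vec<Vec<u8>> = Vec::new(); // keeps buffers alive, like the allocas
    let mut size = recv_into(&mut slot); // first call fills the result slot
    while size != 0 {
        // Non-zero: provide `size` bytes of extra storage and ask again.
        extra.push(vec![0u8; size]);
        size = recv_into(extra.last_mut().unwrap());
    }
    // else *(T*)ret_ptr
    unsafe { std::ptr::read_unaligned(slot.as_ptr().cast::<T>()) }
}

fn main() {
    // Simulate a runtime sending the fixed-size value 42i32 in a single round.
    let result: i32 = receive(|buf| {
        buf[..4].copy_from_slice(&42i32.to_ne_bytes());
        0 // fixed-size payload: no extra storage needed
    });
    assert_eq!(result, 42);
}
```

The ndarray path in `format_rpc_ret` is the same loop with a different first buffer:
the header `[pdata, shape]` is received into a temporary stack buffer, and the phi then
starts at `ndarray.data` instead of a scalar result slot.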