From 6030a60b1098304a42822b6aebd17c4ae3b61d3d Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 14 Aug 2024 15:53:27 +0800 Subject: [PATCH] WIP - [artiq] Fix RPC of ndarrays from host --- nac3artiq/src/codegen.rs | 234 +++++++++++++++++++++++++++++---------- 1 file changed, 178 insertions(+), 56 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 6a83d30a..56ca3b56 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -40,6 +40,7 @@ use std::{ mem, sync::Arc, }; +use nac3core::codegen::classes::ProxyType; /// The parallelism mode within a block. #[derive(Copy, Clone, Eq, PartialEq)] @@ -486,8 +487,7 @@ fn format_rpc_arg<'ctx>( let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap(); let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg")); - let ppdata = - generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap(); + let ppdata = generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap(); ctx.builder.build_store(ppdata, llvm_arg.data().base_ptr(ctx, generator)).unwrap(); call_memcpy_generic( @@ -529,6 +529,178 @@ fn format_rpc_arg<'ctx>( arg_slot } +/// Formats an RPC return value to conform to the expected format required by NAC3. +fn format_rpc_ret<'ctx>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, '_>, + ret_ty: Type, +) -> Option> { + // -- receive value: + // T result = { + // void *ret_ptr = alloca(sizeof(T)); + // void *ptr = ret_ptr; + // loop: int size = rpc_recv(ptr); + // // Non-zero: Provide `size` bytes of extra storage for variable-length data. + // if(size) { ptr = alloca(size); goto loop; } + // else *(T*)ret_ptr + // } + + let llvm_i8 = ctx.ctx.i8_type(); + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); + + let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| { + ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None) + }); + + if ctx.unifier.unioned(ret_ty, ctx.primitives.none) { + ctx.build_call_or_invoke(rpc_recv, &[llvm_pi8.const_null().into()], "rpc_recv"); + return None; + } + + let prehead_bb = ctx.builder.get_insert_block().unwrap(); + let current_function = prehead_bb.get_parent().unwrap(); + let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head"); + let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue"); + let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail"); + + let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty); + + let result = match &*ctx.unifier.get_ty_immutable(ret_ty) { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { + let llvm_i1 = ctx.ctx.bool_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty); + let llvm_ret_ty = + NDArrayType::new(generator, ctx.ctx, ctx.get_llvm_type(generator, elem_ty)); + + let llvm_usize_sizeof = ctx + .builder + .build_int_truncate_or_bit_cast(llvm_ret_ty.size_type().size_of(), llvm_usize, "") + .unwrap(); + let llvm_pdata_sizeof = ctx + .builder + .build_int_truncate_or_bit_cast( + llvm_ret_ty.element_type().ptr_type(AddressSpace::default()).size_of(), + llvm_usize, + "", + ) + .unwrap(); + + let slot_addr = ctx.builder.build_alloca(llvm_ret_ty.as_base_type(), "rpc.ret.slot.addr").unwrap(); + let slot = ctx.builder.build_load(slot_addr, "rpc.ret.slot").map(BasicValueEnum::into_pointer_value).unwrap(); + let llvm_ndarray = NDArrayValue::from_ptr_val(slot, llvm_usize, None); + + let ndims = if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) { + assert_eq!(values.len(), 1); + + u64::try_from(values[0].clone()).unwrap() + } else { + unreachable!(); + }; + llvm_ndarray.create_dim_sizes(ctx, llvm_usize, llvm_usize.const_int(ndims, false)); + + let dims_buf_sz = + ctx.builder.build_int_mul(llvm_ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); + + // TODO: This is either 4 + ndims * 4, or nelems * sizeof(T) + ndims * 4 + // let buffer_size = + // ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap(); + let buffer_size = llvm_usize.const_int(16, false); + + let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.buffer").unwrap(); + let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.buffer")); + + // recv [*data, dim_sz[..]] + let alloc_size = ctx + .build_call_or_invoke(rpc_recv, &[buffer.base_ptr(ctx, generator).into()], "rpc.size.next") + .map(BasicValueEnum::into_int_value) + .unwrap(); + + let ppdata = generator.gen_var_alloc(ctx, llvm_ret_ty.element_type(), None).unwrap(); + ctx.builder.build_store(ppdata, llvm_ndarray.data().base_ptr(ctx, generator)).unwrap(); + call_memcpy_generic( + ctx, + ppdata, + buffer.base_ptr(ctx, generator), + llvm_pdata_sizeof, + llvm_i1.const_zero(), + ); + + let pbuffer_dims_begin = + unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; + call_memcpy_generic( + ctx, + llvm_ndarray.dim_sizes().base_ptr(ctx, generator), + pbuffer_dims_begin, + dims_buf_sz, + llvm_i1.const_zero(), + ); + + let alloc_ptr = ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap(); + let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").map(BasicValueEnum::into_pointer_value).unwrap(); + + ctx.builder.build_unconditional_branch(head_bb).unwrap(); + + ctx.builder.position_at_end(head_bb); + let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap(); + phi.add_incoming(&[(&alloc_ptr, prehead_bb)]); + let alloc_size = ctx + .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") + .unwrap() + .into_int_value(); + let is_done = ctx + .builder + .build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done") + .unwrap(); + ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap(); + + ctx.builder.position_at_end(alloc_bb); + let alloc_ptr = ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap(); + let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap(); + phi.add_incoming(&[(&alloc_ptr, alloc_bb)]); + ctx.builder.build_unconditional_branch(head_bb).unwrap(); + + ctx.builder.position_at_end(tail_bb); + + ctx.builder.build_load(slot_addr, "rpc.result").unwrap() + } + + _ => { + let slot = ctx.builder.build_alloca(llvm_ret_ty, "rpc.ret.slot").unwrap(); + let slotgen = ctx.builder.build_bitcast(slot, llvm_pi8, "rpc.ret.ptr").unwrap(); + ctx.builder.build_unconditional_branch(head_bb).unwrap(); + ctx.builder.position_at_end(head_bb); + + let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap(); + phi.add_incoming(&[(&slotgen, prehead_bb)]); + let alloc_size = ctx + .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") + .unwrap() + .into_int_value(); + let is_done = ctx + .builder + .build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done") + .unwrap(); + + ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap(); + ctx.builder.position_at_end(alloc_bb); + + let alloc_ptr = ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap(); + let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap(); + phi.add_incoming(&[(&alloc_ptr, alloc_bb)]); + ctx.builder.build_unconditional_branch(head_bb).unwrap(); + + ctx.builder.position_at_end(tail_bb); + + ctx.builder.build_load(slot, "rpc.result").unwrap() + } + }; + + Some(result) +} + fn rpc_codegen_callback_fn<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, obj: Option<(Type, ValueEnum<'ctx>)>, @@ -664,63 +836,13 @@ fn rpc_codegen_callback_fn<'ctx>( // reclaim stack space used by arguments call_stackrestore(ctx, stackptr); - // -- receive value: - // T result = { - // void *ret_ptr = alloca(sizeof(T)); - // void *ptr = ret_ptr; - // loop: int size = rpc_recv(ptr); - // // Non-zero: Provide `size` bytes of extra storage for variable-length data. - // if(size) { ptr = alloca(size); goto loop; } - // else *(T*)ret_ptr - // } - let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| { - ctx.module.add_function("rpc_recv", int32.fn_type(&[ptr_type.into()], false), None) - }); + let result = format_rpc_ret(generator, ctx, fun.0.ret); - if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) { - ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv"); - return Ok(None); - } - - let prehead_bb = ctx.builder.get_insert_block().unwrap(); - let current_function = prehead_bb.get_parent().unwrap(); - let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head"); - let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue"); - let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail"); - - let ret_ty = ctx.get_llvm_abi_type(generator, fun.0.ret); - let need_load = !ret_ty.is_pointer_type(); - let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot").unwrap(); - let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr").unwrap(); - ctx.builder.build_unconditional_branch(head_bb).unwrap(); - ctx.builder.position_at_end(head_bb); - - let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr").unwrap(); - phi.add_incoming(&[(&slotgen, prehead_bb)]); - let alloc_size = ctx - .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") - .unwrap() - .into_int_value(); - let is_done = ctx - .builder - .build_int_compare(inkwell::IntPredicate::EQ, int32.const_zero(), alloc_size, "rpc.done") - .unwrap(); - - ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap(); - ctx.builder.position_at_end(alloc_bb); - - let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc").unwrap(); - let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr").unwrap(); - phi.add_incoming(&[(&alloc_ptr, alloc_bb)]); - ctx.builder.build_unconditional_branch(head_bb).unwrap(); - - ctx.builder.position_at_end(tail_bb); - - let result = ctx.builder.build_load(slot, "rpc.result").unwrap(); - if need_load { + if result.is_some_and(|res| !res.get_type().is_pointer_type()) { call_stackrestore(ctx, stackptr); } - Ok(Some(result)) + + Ok(result) } pub fn attributes_writeback(