From c3b7aa6386f174e844e21813631f5b9c82682ae4 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 14 Aug 2024 18:10:16 +0800 Subject: [PATCH] [artiq] Fix RPC of ndarrays from host --- nac3artiq/src/codegen.rs | 182 ++++++++++++++++++++------------ nac3core/src/codegen/classes.rs | 8 +- 2 files changed, 116 insertions(+), 74 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index cc3950be..4f16acf9 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -2,7 +2,7 @@ use nac3core::{ codegen::{ classes::{ ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType, - NDArrayValue, RangeValue, UntypedArrayLikeAccessor, + NDArrayValue, ProxyType, ProxyValue, RangeValue, UntypedArrayLikeAccessor, }, expr::{destructure_range, gen_call}, irrt::call_ndarray_calc_size, @@ -40,7 +40,6 @@ use std::{ mem, sync::Arc, }; -use nac3core::codegen::classes::ProxyType; /// The parallelism mode within a block. #[derive(Copy, Clone, Eq, PartialEq)] @@ -487,13 +486,10 @@ fn format_rpc_arg<'ctx>( let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap(); let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg")); - let ppdata = generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap(); - ctx.builder.build_store(ppdata, llvm_arg.data().base_ptr(ctx, generator)).unwrap(); - call_memcpy_generic( ctx, buffer.base_ptr(ctx, generator), - ppdata, + llvm_arg.ptr_to_data(ctx), llvm_pdata_sizeof, llvm_i1.const_zero(), ); @@ -548,6 +544,7 @@ fn format_rpc_ret<'ctx>( let llvm_i8 = ctx.ctx.i8_type(); let llvm_i32 = ctx.ctx.i32_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); + let llvm_ppi8 = llvm_pi8.ptr_type(AddressSpace::default()); let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| { ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None) @@ -572,8 +569,20 @@ fn format_rpc_ret<'ctx>( let llvm_usize = generator.get_size_type(ctx.ctx); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty); - let llvm_ret_ty = - NDArrayType::new(generator, ctx.ctx, ctx.get_llvm_type(generator, elem_ty)); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty); + let ndarray = llvm_ret_ty.new_value(generator, ctx, Some("rpc.result")); + + let ndims = + if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) { + assert_eq!(values.len(), 1); + + u64::try_from(values[0].clone()).unwrap() + } else { + unreachable!(); + }; + ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false)); + ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx)); let llvm_usize_sizeof = ctx .builder @@ -582,93 +591,124 @@ fn format_rpc_ret<'ctx>( let llvm_pdata_sizeof = ctx .builder .build_int_truncate_or_bit_cast( - llvm_ret_ty.element_type().ptr_type(AddressSpace::default()).size_of(), + llvm_ret_ty.element_type().size_of().unwrap(), llvm_usize, "", ) .unwrap(); - let slot_addr = ctx.builder.build_alloca(llvm_ret_ty.as_base_type(), "rpc.ret.slot.addr").unwrap(); - let slot = ctx.builder.build_load(slot_addr, "rpc.ret.slot").map(BasicValueEnum::into_pointer_value).unwrap(); - let llvm_ndarray = NDArrayValue::from_ptr_val(slot, llvm_usize, None); - - let ndims = if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) { - assert_eq!(values.len(), 1); - - u64::try_from(values[0].clone()).unwrap() - } else { - unreachable!(); - }; - llvm_ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false)); - llvm_ndarray.create_dim_sizes(ctx, llvm_usize, llvm_ndarray.load_ndims(ctx)); - let dims_buf_sz = - ctx.builder.build_int_mul(llvm_ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); + ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); let buffer_size = ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap(); - let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.buffer").unwrap(); - let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.buffer")); - - // recv [*data, dim_sz[..]] - let alloc_size = ctx - .build_call_or_invoke(rpc_recv, &[buffer.base_ptr(ctx, generator).into()], "rpc.size.next") - .map(BasicValueEnum::into_int_value) - .unwrap(); - - let ppdata = generator.gen_var_alloc(ctx, llvm_ret_ty.element_type(), None).unwrap(); - ctx.builder.build_store(ppdata, llvm_ndarray.data().base_ptr(ctx, generator)).unwrap(); - call_memcpy_generic( - ctx, - ppdata, - buffer.base_ptr(ctx, generator), - llvm_pdata_sizeof, - llvm_i1.const_zero(), - ); - - let pbuffer_dims_begin = - unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; - call_memcpy_generic( - ctx, - llvm_ndarray.dim_sizes().base_ptr(ctx, generator), - pbuffer_dims_begin, - dims_buf_sz, - llvm_i1.const_zero(), - ); - - let is_done = ctx + let buffer = + ctx.builder.build_array_alloca(llvm_pi8, buffer_size, "rpc.buffer").unwrap(); + let buffer = ctx .builder - .build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done") + .build_bitcast(buffer, llvm_pi8, "") + .map(BasicValueEnum::into_pointer_value) .unwrap(); + let buffer = ArraySliceValue::from_ptr_val( + buffer, + ctx.builder + .build_left_shift(buffer_size, llvm_usize.const_int(2, false), "") + .unwrap(), + Some("rpc.buffer.ptr"), + ); - let alloc_ptr = ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap(); - let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").map(BasicValueEnum::into_pointer_value).unwrap(); - - ctx.builder.build_conditional_branch(is_done, tail_bb, head_bb).unwrap(); + let i_addr = ctx.builder.build_alloca(llvm_usize, "i.addr").unwrap(); + ctx.builder.build_store(i_addr, llvm_usize.const_zero()).unwrap(); + ctx.builder.build_unconditional_branch(head_bb).unwrap(); ctx.builder.position_at_end(head_bb); let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap(); - phi.add_incoming(&[(&alloc_ptr, prehead_bb)]); + phi.add_incoming(&[(&buffer.base_ptr(ctx, generator), prehead_bb)]); let alloc_size = ctx .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") - .unwrap() - .into_int_value(); + .map(BasicValueEnum::into_int_value) + .unwrap(); + + // Parse metadata block for ndarrays + gen_if_callback( + generator, + ctx, + |_, ctx| { + let i = ctx + .builder + .build_load(i_addr, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + + Ok(ctx + .builder + .build_int_compare(IntPredicate::EQ, i, llvm_usize.const_zero(), "") + .unwrap()) + }, + |generator, ctx| { + let data_addr = phi.as_basic_value().into_pointer_value(); + let data_addr = ctx + .builder + .build_load( + ctx.builder + .build_bitcast(data_addr, llvm_ppi8, "") + .map(BasicValueEnum::into_pointer_value) + .unwrap(), + "", + ) + .map(BasicValueEnum::into_pointer_value) + .unwrap(); + + call_memcpy_generic( + ctx, + ndarray.ptr_to_data(ctx), + data_addr, + llvm_pdata_sizeof, + llvm_i1.const_zero(), + ); + + let pbuffer_dims_begin = unsafe { + buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) + }; + call_memcpy_generic( + ctx, + ndarray.dim_sizes().base_ptr(ctx, generator), + pbuffer_dims_begin, + dims_buf_sz, + llvm_i1.const_zero(), + ); + + Ok(()) + }, + |_, _| Ok(()), + ) + .unwrap(); + let is_done = ctx .builder .build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done") .unwrap(); - ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap(); + ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap(); ctx.builder.position_at_end(alloc_bb); - let alloc_ptr = ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap(); - let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap(); + + let alloc_ptr = + ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap(); + let alloc_ptr = ctx + .builder + .build_bitcast(alloc_ptr, llvm_pi8, "") + .map(BasicValueEnum::into_pointer_value) + .unwrap(); phi.add_incoming(&[(&alloc_ptr, alloc_bb)]); + let i = + ctx.builder.build_load(i_addr, "i").map(BasicValueEnum::into_int_value).unwrap(); + let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, false), "").unwrap(); + ctx.builder.build_store(i_addr, i).unwrap(); ctx.builder.build_unconditional_branch(head_bb).unwrap(); ctx.builder.position_at_end(tail_bb); - - ctx.builder.build_load(slot_addr, "rpc.result").unwrap() + ndarray.as_base_value().into() } _ => { @@ -691,8 +731,10 @@ fn format_rpc_ret<'ctx>( ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap(); ctx.builder.position_at_end(alloc_bb); - let alloc_ptr = ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap(); - let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap(); + let alloc_ptr = + ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap(); + let alloc_ptr = + ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap(); phi.add_incoming(&[(&alloc_ptr, alloc_bb)]); ctx.builder.build_unconditional_branch(head_bb).unwrap(); @@ -842,7 +884,7 @@ fn rpc_codegen_callback_fn<'ctx>( let result = format_rpc_ret(generator, ctx, fun.0.ret); - if result.is_some_and(|res| !res.get_type().is_pointer_type()) { + if !result.is_some_and(|res| res.get_type().is_pointer_type()) { call_stackrestore(ctx, stackptr); } diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 52e9cca0..b6d3f2fa 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -1330,7 +1330,7 @@ impl<'ctx> NDArrayValue<'ctx> { } /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. - fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + pub fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let llvm_i32 = ctx.ctx.i32_type(); let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default(); @@ -1366,7 +1366,7 @@ impl<'ctx> NDArrayValue<'ctx> { /// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr` /// on the field. - fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + pub fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let llvm_i32 = ctx.ctx.i32_type(); let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default(); @@ -1404,7 +1404,7 @@ impl<'ctx> NDArrayValue<'ctx> { /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// on the field. - fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let llvm_i32 = ctx.ctx.i32_type(); let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); @@ -1420,7 +1420,7 @@ impl<'ctx> NDArrayValue<'ctx> { } /// Stores the array of data elements `data` into this instance. - fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { + pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap(); }