From f680bee7a6c85eda03ead47bbffc0112ebc64779 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 14 Aug 2024 18:12:13 +0800 Subject: [PATCH] [artiq] Fix RPC of ndarrays from host to device --- nac3artiq/src/codegen.rs | 241 +++++++++++++++--------- nac3core/src/codegen/classes.rs | 8 +- nac3core/src/codegen/llvm_intrinsics.rs | 34 ++++ 3 files changed, 186 insertions(+), 97 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index b5f41159..6d2d9e42 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -33,6 +33,8 @@ use pyo3::{ use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; use itertools::Itertools; +use nac3core::codegen::classes::{ProxyType, ProxyValue, TypedArrayLikeMutator}; +use nac3core::codegen::llvm_intrinsics::call_memset; use std::{ collections::{hash_map::DefaultHasher, HashMap}, hash::{Hash, Hasher}, @@ -40,7 +42,6 @@ use std::{ mem, sync::Arc, }; -use nac3core::codegen::classes::{ProxyType, TypedArrayLikeMutator}; /// The parallelism mode within a block. #[derive(Copy, Clone, Eq, PartialEq)] @@ -487,13 +488,13 @@ fn format_rpc_arg<'ctx>( let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap(); let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg")); - let ppdata = generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap(); - ctx.builder.build_store(ppdata, llvm_arg.data().base_ptr(ctx, generator)).unwrap(); + // let ppdata = generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap(); + // ctx.builder.build_store(ppdata, llvm_arg.data().base_ptr(ctx, generator)).unwrap(); call_memcpy_generic( ctx, buffer.base_ptr(ctx, generator), - ppdata, + llvm_arg.ptr_to_data(ctx), llvm_pdata_sizeof, llvm_i1.const_zero(), ); @@ -548,6 +549,7 @@ fn format_rpc_ret<'ctx>( let llvm_i8 = ctx.ctx.i8_type(); let llvm_i32 = ctx.ctx.i32_type(); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); + let llvm_ppi8 = llvm_pi8.ptr_type(AddressSpace::default()); let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| { ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None) @@ -572,9 +574,22 @@ fn format_rpc_ret<'ctx>( let llvm_usize = generator.get_size_type(ctx.ctx); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty); - let llvm_ret_ty = - NDArrayType::new(generator, ctx.ctx, ctx.get_llvm_type(generator, elem_ty)); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty); + let ndarray = llvm_ret_ty.new_value(generator, ctx, Some("rpc.result")); + let ndims = + if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) { + assert_eq!(values.len(), 1); + + u64::try_from(values[0].clone()).unwrap() + } else { + unreachable!(); + }; + ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false)); + ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx)); + + // TODO: Add alignment let llvm_usize_sizeof = ctx .builder .build_int_truncate_or_bit_cast(llvm_ret_ty.size_type().size_of(), llvm_usize, "") @@ -582,101 +597,139 @@ fn format_rpc_ret<'ctx>( let llvm_pdata_sizeof = ctx .builder .build_int_truncate_or_bit_cast( - llvm_ret_ty.element_type().ptr_type(AddressSpace::default()).size_of(), + llvm_ret_ty.element_type().size_of().unwrap(), llvm_usize, "", ) .unwrap(); + let llvm_elem_sizeof = ctx + .builder + .build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "") + .unwrap(); - let slot_addr = ctx.builder.build_alloca(llvm_ret_ty.as_base_type(), "rpc.ret.slot.addr").unwrap(); - let slot = ctx.builder.build_load(slot_addr, "rpc.ret.slot").map(BasicValueEnum::into_pointer_value).unwrap(); - let llvm_ndarray = NDArrayValue::from_ptr_val(slot, llvm_usize, None); + let dims_buf_sz = + ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); - let ndims = if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) { - assert_eq!(values.len(), 1); + let buffer_size = + ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap(); - u64::try_from(values[0].clone()).unwrap() - } else { - unreachable!(); - }; - llvm_ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false)); - llvm_ndarray.create_dim_sizes(ctx, llvm_usize, llvm_ndarray.load_ndims(ctx)); + let buffer = + ctx.builder.build_array_alloca(llvm_pi8, buffer_size, "rpc.buffer").unwrap(); + let buffer = ctx + .builder + .build_bitcast(buffer, llvm_pi8, "") + .map(BasicValueEnum::into_pointer_value) + .unwrap(); + let buffer = ArraySliceValue::from_ptr_val( + buffer, + ctx.builder + .build_left_shift(buffer_size, llvm_usize.const_int(2, false), "") + .unwrap(), + Some("rpc.buffer.ptr"), + ); - unsafe { - llvm_ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), llvm_usize.const_int(1, false)); - } + let i_addr = ctx.builder.build_alloca(llvm_usize, "i.addr").unwrap(); + ctx.builder.build_store(i_addr, llvm_usize.const_zero()).unwrap(); + ctx.builder.build_unconditional_branch(head_bb).unwrap(); - ctx.builder.build_unconditional_branch(tail_bb).unwrap(); - - // let dims_buf_sz = - // ctx.builder.build_int_mul(llvm_ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); - // - // let buffer_size = - // ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap(); - // - // let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.buffer").unwrap(); - // let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.buffer")); - // - // // recv [*data, dim_sz[..]] - // let alloc_size = ctx - // .build_call_or_invoke(rpc_recv, &[buffer.base_ptr(ctx, generator).into()], "rpc.size.next") - // .map(BasicValueEnum::into_int_value) - // .unwrap(); - // - // let ppdata = generator.gen_var_alloc(ctx, llvm_ret_ty.element_type(), None).unwrap(); - // ctx.builder.build_store(ppdata, llvm_ndarray.data().base_ptr(ctx, generator)).unwrap(); - // call_memcpy_generic( - // ctx, - // ppdata, - // buffer.base_ptr(ctx, generator), - // llvm_pdata_sizeof, - // llvm_i1.const_zero(), - // ); - // - // let pbuffer_dims_begin = - // unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; - // call_memcpy_generic( - // ctx, - // llvm_ndarray.dim_sizes().base_ptr(ctx, generator), - // pbuffer_dims_begin, - // dims_buf_sz, - // llvm_i1.const_zero(), - // ); - // - // let is_done = ctx - // .builder - // .build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done") - // .unwrap(); - // - // let alloc_ptr = ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap(); - // let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").map(BasicValueEnum::into_pointer_value).unwrap(); - // - // ctx.builder.build_conditional_branch(is_done, tail_bb, head_bb).unwrap(); - // ctx.builder.position_at_end(head_bb); - ctx.builder.build_unreachable().unwrap(); - // let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap(); - // phi.add_incoming(&[(&alloc_ptr, prehead_bb)]); - // let alloc_size = ctx - // .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") - // .unwrap() - // .into_int_value(); - // let is_done = ctx - // .builder - // .build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done") - // .unwrap(); - // ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap(); - // - ctx.builder.position_at_end(alloc_bb); - ctx.builder.build_unreachable().unwrap(); - // let alloc_ptr = ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap(); - // let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap(); - // phi.add_incoming(&[(&alloc_ptr, alloc_bb)]); - // ctx.builder.build_unconditional_branch(head_bb).unwrap(); - // - ctx.builder.position_at_end(tail_bb); + let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap(); + phi.add_incoming(&[(&buffer.base_ptr(ctx, generator), prehead_bb)]); + let alloc_size = ctx + .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") + .map(BasicValueEnum::into_int_value) + .unwrap(); - ctx.builder.build_load(slot_addr, "rpc.result").unwrap() + // Parse metadata block(s) for ndarrays + gen_if_callback( + generator, + ctx, + |_, ctx| { + let i = ctx + .builder + .build_load(i_addr, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + + Ok(ctx + .builder + .build_int_compare(IntPredicate::EQ, i, llvm_usize.const_zero(), "") + // .build_int_compare(IntPredicate::ULT, i, llvm_usize.const_int(2, false), "") + .unwrap()) + }, + |generator, ctx| { + // let data_ptr = ctx + // .builder + // .build_bitcast( + // phi.as_basic_value().into_pointer_value(), + // llvm_ret_ty.element_type(), + // "", + // ) + // .unwrap() + // .into_pointer_value(); + // ndarray.store_data(ctx, data_ptr); + let data_addr = phi.as_basic_value().into_pointer_value(); + let data_addr = ctx + .builder + .build_load( + ctx.builder + .build_bitcast(data_addr, llvm_ppi8, "") + .unwrap() + .into_pointer_value(), + "", + ) + .unwrap() + .into_pointer_value(); + + call_memcpy_generic( + ctx, + ndarray.ptr_to_data(ctx), + data_addr, + llvm_pdata_sizeof, + llvm_i1.const_zero(), + ); + + let pbuffer_dims_begin = unsafe { + buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) + }; + call_memcpy_generic( + ctx, + ndarray.dim_sizes().base_ptr(ctx, generator), + pbuffer_dims_begin, + dims_buf_sz, + llvm_i1.const_zero(), + ); + + Ok(()) + }, + |_, _| Ok(()), + ) + .unwrap(); + + let is_done = ctx + .builder + .build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done") + .unwrap(); + + ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap(); + ctx.builder.position_at_end(alloc_bb); + + let alloc_ptr = + ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap(); + let alloc_ptr = ctx + .builder + .build_bitcast(alloc_ptr, llvm_pi8, "") + .map(BasicValueEnum::into_pointer_value) + .unwrap(); + phi.add_incoming(&[(&alloc_ptr, alloc_bb)]); + let i = + ctx.builder.build_load(i_addr, "i").map(BasicValueEnum::into_int_value).unwrap(); + let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, false), "").unwrap(); + ctx.builder.build_store(i_addr, i).unwrap(); + ctx.builder.build_unconditional_branch(head_bb).unwrap(); + + ctx.builder.position_at_end(tail_bb); + ndarray.as_base_value().into() } _ => { @@ -699,8 +752,10 @@ fn format_rpc_ret<'ctx>( ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap(); ctx.builder.position_at_end(alloc_bb); - let alloc_ptr = ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap(); - let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap(); + let alloc_ptr = + ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap(); + let alloc_ptr = + ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap(); phi.add_incoming(&[(&alloc_ptr, alloc_bb)]); ctx.builder.build_unconditional_branch(head_bb).unwrap(); @@ -850,7 +905,7 @@ fn rpc_codegen_callback_fn<'ctx>( let result = format_rpc_ret(generator, ctx, fun.0.ret); - if result.is_some_and(|res| !res.get_type().is_pointer_type()) { + if !result.is_some_and(|res| res.get_type().is_pointer_type()) { call_stackrestore(ctx, stackptr); } diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 52e9cca0..b6d3f2fa 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -1330,7 +1330,7 @@ impl<'ctx> NDArrayValue<'ctx> { } /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. - fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + pub fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let llvm_i32 = ctx.ctx.i32_type(); let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default(); @@ -1366,7 +1366,7 @@ impl<'ctx> NDArrayValue<'ctx> { /// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr` /// on the field. - fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + pub fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let llvm_i32 = ctx.ctx.i32_type(); let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default(); @@ -1404,7 +1404,7 @@ impl<'ctx> NDArrayValue<'ctx> { /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// on the field. - fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { let llvm_i32 = ctx.ctx.i32_type(); let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); @@ -1420,7 +1420,7 @@ impl<'ctx> NDArrayValue<'ctx> { } /// Stores the array of data elements `data` into this instance. - fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { + pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap(); } diff --git a/nac3core/src/codegen/llvm_intrinsics.rs b/nac3core/src/codegen/llvm_intrinsics.rs index 6e878715..14eb4f74 100644 --- a/nac3core/src/codegen/llvm_intrinsics.rs +++ b/nac3core/src/codegen/llvm_intrinsics.rs @@ -199,6 +199,40 @@ pub fn call_memcpy_generic<'ctx>( call_memcpy(ctx, dest, src, len, is_volatile); } +/// Invokes the [`llvm.memset`](https://llvm.org/docs/LangRef.html#llvm-memset-intrinsic) intrinsic. +pub fn call_memset<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + dest: PointerValue<'ctx>, + val: IntValue<'ctx>, + len: IntValue<'ctx>, + is_volatile: IntValue<'ctx>, +) { + const FN_NAME: &str = "llvm.memset"; + + debug_assert!(dest.get_type().get_element_type().is_int_type()); + debug_assert_eq!(dest.get_type().get_element_type().into_int_type().get_bit_width(), 8); + debug_assert_eq!(val.get_type().get_bit_width(), 8); + debug_assert!(matches!(len.get_type().get_bit_width(), 32 | 64)); + debug_assert_eq!(is_volatile.get_type().get_bit_width(), 1); + + let llvm_dest_t = dest.get_type(); + let llvm_src_t = val.get_type(); + let llvm_len_t = len.get_type(); + + let intrinsic_fn = Intrinsic::find(FN_NAME) + .and_then(|intrinsic| { + intrinsic.get_declaration( + &ctx.module, + &[llvm_dest_t.into(), llvm_len_t.into()], + ) + }) + .unwrap(); + + ctx.builder + .build_call(intrinsic_fn, &[dest.into(), val.into(), len.into(), is_volatile.into()], "") + .unwrap(); +} + /// Macro to find and generate build call for llvm intrinsic (body of llvm intrinsic function) /// /// Arguments: