From b1c5c2e1d4c7bc03a8622275be51a9cee7e04bc9 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 13 Aug 2024 17:01:12 +0800 Subject: [PATCH] [artiq] Fix RPC of ndarrays to host --- nac3artiq/src/codegen.rs | 130 ++++++++++++++++++++++++++++++++++----- 1 file changed, 114 insertions(+), 16 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 925da41c1..6a83d30a1 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -1,9 +1,12 @@ use nac3core::{ codegen::{ - classes::{ListValue, NDArrayValue, RangeValue, UntypedArrayLikeAccessor}, + classes::{ + ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType, + NDArrayValue, RangeValue, UntypedArrayLikeAccessor, + }, expr::{destructure_range, gen_call}, irrt::call_ndarray_calc_size, - llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave}, + llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with}, CodeGenContext, CodeGenerator, }, @@ -17,8 +20,8 @@ use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}; use inkwell::{ context::Context, module::Linkage, - types::IntType, - values::{BasicValueEnum, StructValue}, + types::{BasicType, IntType}, + values::{BasicValueEnum, PointerValue, StructValue}, AddressSpace, IntPredicate, }; @@ -422,7 +425,10 @@ fn gen_rpc_tag( } else { unreachable!() }; - assert!((0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims)); + assert!( + (0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims), + "Only NDArrays of sizes between 0 and 255 can be RPCed" + ); buffer.push(b'a'); buffer.push((ndarray_ndims & 0xFF) as u8); @@ -434,6 +440,95 @@ fn gen_rpc_tag( Ok(()) } +/// Formats an RPC argument to conform to the expected format required by `send_value`. +/// +/// See `artiq/firmware/libproto_artiq/rpc_proto.rs` for the expected format. +fn format_rpc_arg<'ctx>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, '_>, + (arg, arg_ty, arg_idx): (BasicValueEnum<'ctx>, Type, usize), +) -> PointerValue<'ctx> { + let llvm_i8 = ctx.ctx.i8_type(); + let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); + + let arg_slot = match &*ctx.unifier.get_ty_immutable(arg_ty) { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { + // NAC3: NDArray = { usize, usize*, T* } + // libproto_artiq: NDArray = [data[..], dim_sz[..]] + + let llvm_i1 = ctx.ctx.bool_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); + let llvm_arg_ty = + NDArrayType::new(generator, ctx.ctx, ctx.get_llvm_type(generator, elem_ty)); + let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None); + + let llvm_usize_sizeof = ctx + .builder + .build_int_truncate_or_bit_cast(llvm_arg_ty.size_type().size_of(), llvm_usize, "") + .unwrap(); + let llvm_pdata_sizeof = ctx + .builder + .build_int_truncate_or_bit_cast( + llvm_arg_ty.element_type().ptr_type(AddressSpace::default()).size_of(), + llvm_usize, + "", + ) + .unwrap(); + + let dims_buf_sz = + ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); + + let buffer_size = + ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap(); + + let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap(); + let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg")); + + let ppdata = + generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap(); + ctx.builder.build_store(ppdata, llvm_arg.data().base_ptr(ctx, generator)).unwrap(); + + call_memcpy_generic( + ctx, + buffer.base_ptr(ctx, generator), + ppdata, + llvm_pdata_sizeof, + llvm_i1.const_zero(), + ); + + let pbuffer_dims_begin = + unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; + call_memcpy_generic( + ctx, + pbuffer_dims_begin, + llvm_arg.dim_sizes().base_ptr(ctx, generator), + dims_buf_sz, + llvm_i1.const_zero(), + ); + + buffer.base_ptr(ctx, generator) + } + + _ => { + let arg_slot = generator + .gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{arg_idx}"))) + .unwrap(); + ctx.builder.build_store(arg_slot, arg).unwrap(); + + ctx.builder + .build_bitcast(arg_slot, llvm_pi8, "rpc.arg") + .map(BasicValueEnum::into_pointer_value) + .unwrap() + } + }; + + debug_assert_eq!(arg_slot.get_type(), llvm_pi8); + + arg_slot +} + fn rpc_codegen_callback_fn<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, obj: Option<(Type, ValueEnum<'ctx>)>, @@ -441,10 +536,10 @@ fn rpc_codegen_callback_fn<'ctx>( args: Vec<(Option, ValueEnum<'ctx>)>, generator: &mut dyn CodeGenerator, ) -> Result>, String> { - let ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let size_type = generator.get_size_type(ctx.ctx); let int8 = ctx.ctx.i8_type(); let int32 = ctx.ctx.i32_type(); + let size_type = generator.get_size_type(ctx.ctx); + let ptr_type = int8.ptr_type(AddressSpace::default()); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); let service_id = int32.const_int(fun.1 .0 as u64, false); @@ -517,22 +612,25 @@ fn rpc_codegen_callback_fn<'ctx>( .0 .args .iter() - .map(|arg| mapping.remove(&arg.name).unwrap().to_basic_value_enum(ctx, generator, arg.ty)) - .collect::, _>>()?; + .map(|arg| { + mapping + .remove(&arg.name) + .unwrap() + .to_basic_value_enum(ctx, generator, arg.ty) + .map(|llvm_val| (llvm_val, arg.ty)) + }) + .collect::, _>>()?; if let Some(obj) = obj { - if let ValueEnum::Static(obj) = obj.1 { - real_params.insert(0, obj.get_const_obj(ctx, generator)); + if let ValueEnum::Static(obj_val) = obj.1 { + real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0)); } else { // should be an error here... panic!("only host object is allowed"); } } - for (i, arg) in real_params.iter().enumerate() { - let arg_slot = - generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap(); - ctx.builder.build_store(arg_slot, *arg).unwrap(); - let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg").unwrap(); + for (i, (arg, arg_ty)) in real_params.iter().enumerate() { + let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i)); let arg_ptr = unsafe { ctx.builder.build_gep( args_ptr,