forked from M-Labs/nac3
1
0
Fork 0

[artiq] Fix RPC of ndarrays to host

This commit is contained in:
David Mak 2024-08-13 17:01:12 +08:00 committed by David Mak
parent 69320a6cf1
commit b1c5c2e1d4
1 changed files with 114 additions and 16 deletions

View File

@ -1,9 +1,12 @@
use nac3core::{
codegen::{
classes::{ListValue, NDArrayValue, RangeValue, UntypedArrayLikeAccessor},
classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType,
NDArrayValue, RangeValue, UntypedArrayLikeAccessor,
},
expr::{destructure_range, gen_call},
irrt::call_ndarray_calc_size,
llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave},
llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave},
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
CodeGenContext, CodeGenerator,
},
@ -17,8 +20,8 @@ use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
use inkwell::{
context::Context,
module::Linkage,
types::IntType,
values::{BasicValueEnum, StructValue},
types::{BasicType, IntType},
values::{BasicValueEnum, PointerValue, StructValue},
AddressSpace, IntPredicate,
};
@ -422,7 +425,10 @@ fn gen_rpc_tag(
} else {
unreachable!()
};
assert!((0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims));
assert!(
(0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims),
"Only NDArrays of sizes between 0 and 255 can be RPCed"
);
buffer.push(b'a');
buffer.push((ndarray_ndims & 0xFF) as u8);
@ -434,6 +440,95 @@ fn gen_rpc_tag(
Ok(())
}
/// Formats an RPC argument to conform to the expected format required by `send_value`.
///
/// See `artiq/firmware/libproto_artiq/rpc_proto.rs` for the expected format.
fn format_rpc_arg<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
(arg, arg_ty, arg_idx): (BasicValueEnum<'ctx>, Type, usize),
) -> PointerValue<'ctx> {
let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let arg_slot = match &*ctx.unifier.get_ty_immutable(arg_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
// NAC3: NDArray = { usize, usize*, T* }
// libproto_artiq: NDArray = [data[..], dim_sz[..]]
let llvm_i1 = ctx.ctx.bool_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
let llvm_arg_ty =
NDArrayType::new(generator, ctx.ctx, ctx.get_llvm_type(generator, elem_ty));
let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
let llvm_usize_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(llvm_arg_ty.size_type().size_of(), llvm_usize, "")
.unwrap();
let llvm_pdata_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(
llvm_arg_ty.element_type().ptr_type(AddressSpace::default()).size_of(),
llvm_usize,
"",
)
.unwrap();
let dims_buf_sz =
ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
let buffer_size =
ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap();
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap();
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg"));
let ppdata =
generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap();
ctx.builder.build_store(ppdata, llvm_arg.data().base_ptr(ctx, generator)).unwrap();
call_memcpy_generic(
ctx,
buffer.base_ptr(ctx, generator),
ppdata,
llvm_pdata_sizeof,
llvm_i1.const_zero(),
);
let pbuffer_dims_begin =
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
call_memcpy_generic(
ctx,
pbuffer_dims_begin,
llvm_arg.dim_sizes().base_ptr(ctx, generator),
dims_buf_sz,
llvm_i1.const_zero(),
);
buffer.base_ptr(ctx, generator)
}
_ => {
let arg_slot = generator
.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{arg_idx}")))
.unwrap();
ctx.builder.build_store(arg_slot, arg).unwrap();
ctx.builder
.build_bitcast(arg_slot, llvm_pi8, "rpc.arg")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
};
debug_assert_eq!(arg_slot.get_type(), llvm_pi8);
arg_slot
}
fn rpc_codegen_callback_fn<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>,
@ -441,10 +536,10 @@ fn rpc_codegen_callback_fn<'ctx>(
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator,
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
let ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let size_type = generator.get_size_type(ctx.ctx);
let int8 = ctx.ctx.i8_type();
let int32 = ctx.ctx.i32_type();
let size_type = generator.get_size_type(ctx.ctx);
let ptr_type = int8.ptr_type(AddressSpace::default());
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
let service_id = int32.const_int(fun.1 .0 as u64, false);
@ -517,22 +612,25 @@ fn rpc_codegen_callback_fn<'ctx>(
.0
.args
.iter()
.map(|arg| mapping.remove(&arg.name).unwrap().to_basic_value_enum(ctx, generator, arg.ty))
.collect::<Result<Vec<_>, _>>()?;
.map(|arg| {
mapping
.remove(&arg.name)
.unwrap()
.to_basic_value_enum(ctx, generator, arg.ty)
.map(|llvm_val| (llvm_val, arg.ty))
})
.collect::<Result<Vec<(_, _)>, _>>()?;
if let Some(obj) = obj {
if let ValueEnum::Static(obj) = obj.1 {
real_params.insert(0, obj.get_const_obj(ctx, generator));
if let ValueEnum::Static(obj_val) = obj.1 {
real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0));
} else {
// should be an error here...
panic!("only host object is allowed");
}
}
for (i, arg) in real_params.iter().enumerate() {
let arg_slot =
generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap();
ctx.builder.build_store(arg_slot, *arg).unwrap();
let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg").unwrap();
for (i, (arg, arg_ty)) in real_params.iter().enumerate() {
let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i));
let arg_ptr = unsafe {
ctx.builder.build_gep(
args_ptr,