This commit is contained in:
David Mak 2024-08-13 17:01:12 +08:00
parent 33792bea0d
commit 5b1995dc0b
1 changed files with 34 additions and 7 deletions

View File

@ -422,7 +422,10 @@ fn gen_rpc_tag(
} else {
unreachable!()
};
assert!((0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims));
assert!(
(0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims),
"Only NDArrays of sizes between 0 and 255 can be RPCed"
);
buffer.push(b'a');
buffer.push((ndarray_ndims & 0xFF) as u8);
@ -517,22 +520,46 @@ fn rpc_codegen_callback_fn<'ctx>(
.0
.args
.iter()
.map(|arg| mapping.remove(&arg.name).unwrap().to_basic_value_enum(ctx, generator, arg.ty))
.collect::<Result<Vec<_>, _>>()?;
.map(|arg| {
mapping
.remove(&arg.name)
.unwrap()
.to_basic_value_enum(ctx, generator, arg.ty)
.map(|llvm_val| (llvm_val, arg.ty))
})
.collect::<Result<Vec<(_, _)>, _>>()?;
if let Some(obj) = obj {
if let ValueEnum::Static(obj) = obj.1 {
real_params.insert(0, obj.get_const_obj(ctx, generator));
if let ValueEnum::Static(obj_val) = obj.1 {
real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0));
} else {
// should be an error here...
panic!("only host object is allowed");
}
}
for (i, arg) in real_params.iter().enumerate() {
for (i, (arg, arg_ty)) in real_params.iter().enumerate() {
let arg_slot =
generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap();
ctx.builder.build_store(arg_slot, *arg).unwrap();
let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg").unwrap();
let arg_slot = ctx
.builder
.build_bitcast(arg_slot, ptr_type, "rpc.arg")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let arg_slot = if matches!(&*ctx.unifier.get_ty_immutable(*arg_ty), TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id())
{
unsafe {
ctx.builder
.build_in_bounds_gep(
arg_slot,
&[size_type.const_int(u64::from(size_type.get_bit_width() / 8), false)], // should be 4
"",
)
.unwrap()
}
} else {
arg_slot
};
let arg_ptr = unsafe {
ctx.builder.build_gep(
args_ptr,