WIP
This commit is contained in:
parent
33792bea0d
commit
5b1995dc0b
|
@ -422,7 +422,10 @@ fn gen_rpc_tag(
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
assert!((0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims));
|
assert!(
|
||||||
|
(0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims),
|
||||||
|
"Only NDArrays of sizes between 0 and 255 can be RPCed"
|
||||||
|
);
|
||||||
|
|
||||||
buffer.push(b'a');
|
buffer.push(b'a');
|
||||||
buffer.push((ndarray_ndims & 0xFF) as u8);
|
buffer.push((ndarray_ndims & 0xFF) as u8);
|
||||||
|
@ -517,22 +520,46 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||||
.0
|
.0
|
||||||
.args
|
.args
|
||||||
.iter()
|
.iter()
|
||||||
.map(|arg| mapping.remove(&arg.name).unwrap().to_basic_value_enum(ctx, generator, arg.ty))
|
.map(|arg| {
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
mapping
|
||||||
|
.remove(&arg.name)
|
||||||
|
.unwrap()
|
||||||
|
.to_basic_value_enum(ctx, generator, arg.ty)
|
||||||
|
.map(|llvm_val| (llvm_val, arg.ty))
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<(_, _)>, _>>()?;
|
||||||
if let Some(obj) = obj {
|
if let Some(obj) = obj {
|
||||||
if let ValueEnum::Static(obj) = obj.1 {
|
if let ValueEnum::Static(obj_val) = obj.1 {
|
||||||
real_params.insert(0, obj.get_const_obj(ctx, generator));
|
real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0));
|
||||||
} else {
|
} else {
|
||||||
// should be an error here...
|
// should be an error here...
|
||||||
panic!("only host object is allowed");
|
panic!("only host object is allowed");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (i, arg) in real_params.iter().enumerate() {
|
for (i, (arg, arg_ty)) in real_params.iter().enumerate() {
|
||||||
let arg_slot =
|
let arg_slot =
|
||||||
generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap();
|
generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap();
|
||||||
ctx.builder.build_store(arg_slot, *arg).unwrap();
|
ctx.builder.build_store(arg_slot, *arg).unwrap();
|
||||||
let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg").unwrap();
|
let arg_slot = ctx
|
||||||
|
.builder
|
||||||
|
.build_bitcast(arg_slot, ptr_type, "rpc.arg")
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap();
|
||||||
|
let arg_slot = if matches!(&*ctx.unifier.get_ty_immutable(*arg_ty), TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id())
|
||||||
|
{
|
||||||
|
unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(
|
||||||
|
arg_slot,
|
||||||
|
&[size_type.const_int(u64::from(size_type.get_bit_width() / 8), false)], // should be 4
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
arg_slot
|
||||||
|
};
|
||||||
let arg_ptr = unsafe {
|
let arg_ptr = unsafe {
|
||||||
ctx.builder.build_gep(
|
ctx.builder.build_gep(
|
||||||
args_ptr,
|
args_ptr,
|
||||||
|
|
Loading…
Reference in New Issue