[artiq] Fix RPC of ndarrays from host
This commit is contained in:
parent
ae3b9bfd79
commit
c3b7aa6386
|
@ -2,7 +2,7 @@ use nac3core::{
|
||||||
codegen::{
|
codegen::{
|
||||||
classes::{
|
classes::{
|
||||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType,
|
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType,
|
||||||
NDArrayValue, RangeValue, UntypedArrayLikeAccessor,
|
NDArrayValue, ProxyType, ProxyValue, RangeValue, UntypedArrayLikeAccessor,
|
||||||
},
|
},
|
||||||
expr::{destructure_range, gen_call},
|
expr::{destructure_range, gen_call},
|
||||||
irrt::call_ndarray_calc_size,
|
irrt::call_ndarray_calc_size,
|
||||||
|
@ -40,7 +40,6 @@ use std::{
|
||||||
mem,
|
mem,
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
};
|
};
|
||||||
use nac3core::codegen::classes::ProxyType;
|
|
||||||
|
|
||||||
/// The parallelism mode within a block.
|
/// The parallelism mode within a block.
|
||||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||||
|
@ -487,13 +486,10 @@ fn format_rpc_arg<'ctx>(
|
||||||
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap();
|
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap();
|
||||||
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg"));
|
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg"));
|
||||||
|
|
||||||
let ppdata = generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap();
|
|
||||||
ctx.builder.build_store(ppdata, llvm_arg.data().base_ptr(ctx, generator)).unwrap();
|
|
||||||
|
|
||||||
call_memcpy_generic(
|
call_memcpy_generic(
|
||||||
ctx,
|
ctx,
|
||||||
buffer.base_ptr(ctx, generator),
|
buffer.base_ptr(ctx, generator),
|
||||||
ppdata,
|
llvm_arg.ptr_to_data(ctx),
|
||||||
llvm_pdata_sizeof,
|
llvm_pdata_sizeof,
|
||||||
llvm_i1.const_zero(),
|
llvm_i1.const_zero(),
|
||||||
);
|
);
|
||||||
|
@ -548,6 +544,7 @@ fn format_rpc_ret<'ctx>(
|
||||||
let llvm_i8 = ctx.ctx.i8_type();
|
let llvm_i8 = ctx.ctx.i8_type();
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||||
|
let llvm_ppi8 = llvm_pi8.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
|
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
|
||||||
ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None)
|
ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None)
|
||||||
|
@ -572,8 +569,20 @@ fn format_rpc_ret<'ctx>(
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
|
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
|
||||||
let llvm_ret_ty =
|
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
NDArrayType::new(generator, ctx.ctx, ctx.get_llvm_type(generator, elem_ty));
|
let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
|
||||||
|
let ndarray = llvm_ret_ty.new_value(generator, ctx, Some("rpc.result"));
|
||||||
|
|
||||||
|
let ndims =
|
||||||
|
if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) {
|
||||||
|
assert_eq!(values.len(), 1);
|
||||||
|
|
||||||
|
u64::try_from(values[0].clone()).unwrap()
|
||||||
|
} else {
|
||||||
|
unreachable!();
|
||||||
|
};
|
||||||
|
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
|
||||||
|
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx));
|
||||||
|
|
||||||
let llvm_usize_sizeof = ctx
|
let llvm_usize_sizeof = ctx
|
||||||
.builder
|
.builder
|
||||||
|
@ -582,93 +591,124 @@ fn format_rpc_ret<'ctx>(
|
||||||
let llvm_pdata_sizeof = ctx
|
let llvm_pdata_sizeof = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_int_truncate_or_bit_cast(
|
.build_int_truncate_or_bit_cast(
|
||||||
llvm_ret_ty.element_type().ptr_type(AddressSpace::default()).size_of(),
|
llvm_ret_ty.element_type().size_of().unwrap(),
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let slot_addr = ctx.builder.build_alloca(llvm_ret_ty.as_base_type(), "rpc.ret.slot.addr").unwrap();
|
|
||||||
let slot = ctx.builder.build_load(slot_addr, "rpc.ret.slot").map(BasicValueEnum::into_pointer_value).unwrap();
|
|
||||||
let llvm_ndarray = NDArrayValue::from_ptr_val(slot, llvm_usize, None);
|
|
||||||
|
|
||||||
let ndims = if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) {
|
|
||||||
assert_eq!(values.len(), 1);
|
|
||||||
|
|
||||||
u64::try_from(values[0].clone()).unwrap()
|
|
||||||
} else {
|
|
||||||
unreachable!();
|
|
||||||
};
|
|
||||||
llvm_ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
|
|
||||||
llvm_ndarray.create_dim_sizes(ctx, llvm_usize, llvm_ndarray.load_ndims(ctx));
|
|
||||||
|
|
||||||
let dims_buf_sz =
|
let dims_buf_sz =
|
||||||
ctx.builder.build_int_mul(llvm_ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
|
ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
|
||||||
|
|
||||||
let buffer_size =
|
let buffer_size =
|
||||||
ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap();
|
ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap();
|
||||||
|
|
||||||
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.buffer").unwrap();
|
let buffer =
|
||||||
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.buffer"));
|
ctx.builder.build_array_alloca(llvm_pi8, buffer_size, "rpc.buffer").unwrap();
|
||||||
|
let buffer = ctx
|
||||||
// recv [*data, dim_sz[..]]
|
|
||||||
let alloc_size = ctx
|
|
||||||
.build_call_or_invoke(rpc_recv, &[buffer.base_ptr(ctx, generator).into()], "rpc.size.next")
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let ppdata = generator.gen_var_alloc(ctx, llvm_ret_ty.element_type(), None).unwrap();
|
|
||||||
ctx.builder.build_store(ppdata, llvm_ndarray.data().base_ptr(ctx, generator)).unwrap();
|
|
||||||
call_memcpy_generic(
|
|
||||||
ctx,
|
|
||||||
ppdata,
|
|
||||||
buffer.base_ptr(ctx, generator),
|
|
||||||
llvm_pdata_sizeof,
|
|
||||||
llvm_i1.const_zero(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let pbuffer_dims_begin =
|
|
||||||
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
|
|
||||||
call_memcpy_generic(
|
|
||||||
ctx,
|
|
||||||
llvm_ndarray.dim_sizes().base_ptr(ctx, generator),
|
|
||||||
pbuffer_dims_begin,
|
|
||||||
dims_buf_sz,
|
|
||||||
llvm_i1.const_zero(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let is_done = ctx
|
|
||||||
.builder
|
.builder
|
||||||
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
|
.build_bitcast(buffer, llvm_pi8, "")
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
let buffer = ArraySliceValue::from_ptr_val(
|
||||||
|
buffer,
|
||||||
|
ctx.builder
|
||||||
|
.build_left_shift(buffer_size, llvm_usize.const_int(2, false), "")
|
||||||
|
.unwrap(),
|
||||||
|
Some("rpc.buffer.ptr"),
|
||||||
|
);
|
||||||
|
|
||||||
let alloc_ptr = ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
|
let i_addr = ctx.builder.build_alloca(llvm_usize, "i.addr").unwrap();
|
||||||
let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").map(BasicValueEnum::into_pointer_value).unwrap();
|
ctx.builder.build_store(i_addr, llvm_usize.const_zero()).unwrap();
|
||||||
|
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||||
ctx.builder.build_conditional_branch(is_done, tail_bb, head_bb).unwrap();
|
|
||||||
|
|
||||||
ctx.builder.position_at_end(head_bb);
|
ctx.builder.position_at_end(head_bb);
|
||||||
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
|
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
|
||||||
phi.add_incoming(&[(&alloc_ptr, prehead_bb)]);
|
phi.add_incoming(&[(&buffer.base_ptr(ctx, generator), prehead_bb)]);
|
||||||
let alloc_size = ctx
|
let alloc_size = ctx
|
||||||
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
||||||
.unwrap()
|
.map(BasicValueEnum::into_int_value)
|
||||||
.into_int_value();
|
.unwrap();
|
||||||
|
|
||||||
|
// Parse metadata block for ndarrays
|
||||||
|
gen_if_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
let i = ctx
|
||||||
|
.builder
|
||||||
|
.build_load(i_addr, "")
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
Ok(ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::EQ, i, llvm_usize.const_zero(), "")
|
||||||
|
.unwrap())
|
||||||
|
},
|
||||||
|
|generator, ctx| {
|
||||||
|
let data_addr = phi.as_basic_value().into_pointer_value();
|
||||||
|
let data_addr = ctx
|
||||||
|
.builder
|
||||||
|
.build_load(
|
||||||
|
ctx.builder
|
||||||
|
.build_bitcast(data_addr, llvm_ppi8, "")
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
call_memcpy_generic(
|
||||||
|
ctx,
|
||||||
|
ndarray.ptr_to_data(ctx),
|
||||||
|
data_addr,
|
||||||
|
llvm_pdata_sizeof,
|
||||||
|
llvm_i1.const_zero(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let pbuffer_dims_begin = unsafe {
|
||||||
|
buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None)
|
||||||
|
};
|
||||||
|
call_memcpy_generic(
|
||||||
|
ctx,
|
||||||
|
ndarray.dim_sizes().base_ptr(ctx, generator),
|
||||||
|
pbuffer_dims_begin,
|
||||||
|
dims_buf_sz,
|
||||||
|
llvm_i1.const_zero(),
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
|_, _| Ok(()),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let is_done = ctx
|
let is_done = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
|
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
|
|
||||||
|
|
||||||
|
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
|
||||||
ctx.builder.position_at_end(alloc_bb);
|
ctx.builder.position_at_end(alloc_bb);
|
||||||
let alloc_ptr = ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
|
|
||||||
let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
|
let alloc_ptr =
|
||||||
|
ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
|
||||||
|
let alloc_ptr = ctx
|
||||||
|
.builder
|
||||||
|
.build_bitcast(alloc_ptr, llvm_pi8, "")
|
||||||
|
.map(BasicValueEnum::into_pointer_value)
|
||||||
|
.unwrap();
|
||||||
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
||||||
|
let i =
|
||||||
|
ctx.builder.build_load(i_addr, "i").map(BasicValueEnum::into_int_value).unwrap();
|
||||||
|
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, false), "").unwrap();
|
||||||
|
ctx.builder.build_store(i_addr, i).unwrap();
|
||||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||||
|
|
||||||
ctx.builder.position_at_end(tail_bb);
|
ctx.builder.position_at_end(tail_bb);
|
||||||
|
ndarray.as_base_value().into()
|
||||||
ctx.builder.build_load(slot_addr, "rpc.result").unwrap()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_ => {
|
_ => {
|
||||||
|
@ -691,8 +731,10 @@ fn format_rpc_ret<'ctx>(
|
||||||
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
|
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
|
||||||
ctx.builder.position_at_end(alloc_bb);
|
ctx.builder.position_at_end(alloc_bb);
|
||||||
|
|
||||||
let alloc_ptr = ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
|
let alloc_ptr =
|
||||||
let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
|
ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
|
||||||
|
let alloc_ptr =
|
||||||
|
ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
|
||||||
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
||||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||||
|
|
||||||
|
@ -842,7 +884,7 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||||
|
|
||||||
let result = format_rpc_ret(generator, ctx, fun.0.ret);
|
let result = format_rpc_ret(generator, ctx, fun.0.ret);
|
||||||
|
|
||||||
if result.is_some_and(|res| !res.get_type().is_pointer_type()) {
|
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
|
||||||
call_stackrestore(ctx, stackptr);
|
call_stackrestore(ctx, stackptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1330,7 +1330,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
||||||
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
pub fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default();
|
let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default();
|
||||||
|
|
||||||
|
@ -1366,7 +1366,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr`
|
/// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr`
|
||||||
/// on the field.
|
/// on the field.
|
||||||
fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
pub fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default();
|
let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default();
|
||||||
|
|
||||||
|
@ -1404,7 +1404,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
|
|
||||||
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||||
/// on the field.
|
/// on the field.
|
||||||
fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
|
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
|
||||||
|
|
||||||
|
@ -1420,7 +1420,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stores the array of data elements `data` into this instance.
|
/// Stores the array of data elements `data` into this instance.
|
||||||
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
|
pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
|
||||||
ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap();
|
ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue