[artiq] Fix RPC of ndarrays from host

This commit is contained in:
David Mak 2024-08-14 18:12:13 +08:00 committed by David Mak
parent 2b2cc06776
commit 99f7a8b4f3
2 changed files with 151 additions and 97 deletions

View File

@ -33,6 +33,7 @@ use pyo3::{
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
use itertools::Itertools; use itertools::Itertools;
use nac3core::codegen::classes::{ProxyType, ProxyValue};
use std::{ use std::{
collections::{hash_map::DefaultHasher, HashMap}, collections::{hash_map::DefaultHasher, HashMap},
hash::{Hash, Hasher}, hash::{Hash, Hasher},
@ -40,7 +41,6 @@ use std::{
mem, mem,
sync::Arc, sync::Arc,
}; };
use nac3core::codegen::classes::{ProxyType, TypedArrayLikeMutator};
/// The parallelism mode within a block. /// The parallelism mode within a block.
#[derive(Copy, Clone, Eq, PartialEq)] #[derive(Copy, Clone, Eq, PartialEq)]
@ -487,13 +487,13 @@ fn format_rpc_arg<'ctx>(
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap(); let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap();
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg")); let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg"));
let ppdata = generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap(); // let ppdata = generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap();
ctx.builder.build_store(ppdata, llvm_arg.data().base_ptr(ctx, generator)).unwrap(); // ctx.builder.build_store(ppdata, llvm_arg.data().base_ptr(ctx, generator)).unwrap();
call_memcpy_generic( call_memcpy_generic(
ctx, ctx,
buffer.base_ptr(ctx, generator), buffer.base_ptr(ctx, generator),
ppdata, llvm_arg.ptr_to_data(ctx),
llvm_pdata_sizeof, llvm_pdata_sizeof,
llvm_i1.const_zero(), llvm_i1.const_zero(),
); );
@ -548,6 +548,7 @@ fn format_rpc_ret<'ctx>(
let llvm_i8 = ctx.ctx.i8_type(); let llvm_i8 = ctx.ctx.i8_type();
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_ppi8 = llvm_pi8.ptr_type(AddressSpace::default());
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| { let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None) ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None)
@ -572,9 +573,22 @@ fn format_rpc_ret<'ctx>(
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
let llvm_ret_ty = let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
NDArrayType::new(generator, ctx.ctx, ctx.get_llvm_type(generator, elem_ty)); let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
let ndarray = llvm_ret_ty.new_value(generator, ctx, Some("rpc.result"));
let ndims =
if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) {
assert_eq!(values.len(), 1);
u64::try_from(values[0].clone()).unwrap()
} else {
unreachable!();
};
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx));
// TODO: Add alignment
let llvm_usize_sizeof = ctx let llvm_usize_sizeof = ctx
.builder .builder
.build_int_truncate_or_bit_cast(llvm_ret_ty.size_type().size_of(), llvm_usize, "") .build_int_truncate_or_bit_cast(llvm_ret_ty.size_type().size_of(), llvm_usize, "")
@ -582,101 +596,139 @@ fn format_rpc_ret<'ctx>(
let llvm_pdata_sizeof = ctx let llvm_pdata_sizeof = ctx
.builder .builder
.build_int_truncate_or_bit_cast( .build_int_truncate_or_bit_cast(
llvm_ret_ty.element_type().ptr_type(AddressSpace::default()).size_of(), llvm_ret_ty.element_type().size_of().unwrap(),
llvm_usize, llvm_usize,
"", "",
) )
.unwrap(); .unwrap();
let llvm_elem_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "")
.unwrap();
let slot_addr = ctx.builder.build_alloca(llvm_ret_ty.as_base_type(), "rpc.ret.slot.addr").unwrap(); let dims_buf_sz =
let slot = ctx.builder.build_load(slot_addr, "rpc.ret.slot").map(BasicValueEnum::into_pointer_value).unwrap(); ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
let llvm_ndarray = NDArrayValue::from_ptr_val(slot, llvm_usize, None);
let ndims = if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) { let buffer_size =
assert_eq!(values.len(), 1); ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap();
u64::try_from(values[0].clone()).unwrap() let buffer =
} else { ctx.builder.build_array_alloca(llvm_pi8, buffer_size, "rpc.buffer").unwrap();
unreachable!(); let buffer = ctx
}; .builder
llvm_ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false)); .build_bitcast(buffer, llvm_pi8, "")
llvm_ndarray.create_dim_sizes(ctx, llvm_usize, llvm_ndarray.load_ndims(ctx)); .map(BasicValueEnum::into_pointer_value)
.unwrap();
let buffer = ArraySliceValue::from_ptr_val(
buffer,
ctx.builder
.build_left_shift(buffer_size, llvm_usize.const_int(2, false), "")
.unwrap(),
Some("rpc.buffer.ptr"),
);
unsafe { let i_addr = ctx.builder.build_alloca(llvm_usize, "i.addr").unwrap();
llvm_ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), llvm_usize.const_int(1, false)); ctx.builder.build_store(i_addr, llvm_usize.const_zero()).unwrap();
} ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.build_unconditional_branch(tail_bb).unwrap();
// let dims_buf_sz =
// ctx.builder.build_int_mul(llvm_ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
//
// let buffer_size =
// ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap();
//
// let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.buffer").unwrap();
// let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.buffer"));
//
// // recv [*data, dim_sz[..]]
// let alloc_size = ctx
// .build_call_or_invoke(rpc_recv, &[buffer.base_ptr(ctx, generator).into()], "rpc.size.next")
// .map(BasicValueEnum::into_int_value)
// .unwrap();
//
// let ppdata = generator.gen_var_alloc(ctx, llvm_ret_ty.element_type(), None).unwrap();
// ctx.builder.build_store(ppdata, llvm_ndarray.data().base_ptr(ctx, generator)).unwrap();
// call_memcpy_generic(
// ctx,
// ppdata,
// buffer.base_ptr(ctx, generator),
// llvm_pdata_sizeof,
// llvm_i1.const_zero(),
// );
//
// let pbuffer_dims_begin =
// unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
// call_memcpy_generic(
// ctx,
// llvm_ndarray.dim_sizes().base_ptr(ctx, generator),
// pbuffer_dims_begin,
// dims_buf_sz,
// llvm_i1.const_zero(),
// );
//
// let is_done = ctx
// .builder
// .build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
// .unwrap();
//
// let alloc_ptr = ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
// let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").map(BasicValueEnum::into_pointer_value).unwrap();
//
// ctx.builder.build_conditional_branch(is_done, tail_bb, head_bb).unwrap();
//
ctx.builder.position_at_end(head_bb); ctx.builder.position_at_end(head_bb);
ctx.builder.build_unreachable().unwrap(); let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
// let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap(); phi.add_incoming(&[(&buffer.base_ptr(ctx, generator), prehead_bb)]);
// phi.add_incoming(&[(&alloc_ptr, prehead_bb)]); let alloc_size = ctx
// let alloc_size = ctx .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
// .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") .map(BasicValueEnum::into_int_value)
// .unwrap() .unwrap();
// .into_int_value();
// let is_done = ctx
// .builder
// .build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
// .unwrap();
// ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
//
ctx.builder.position_at_end(alloc_bb);
ctx.builder.build_unreachable().unwrap();
// let alloc_ptr = ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
// let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
// phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
// ctx.builder.build_unconditional_branch(head_bb).unwrap();
//
ctx.builder.position_at_end(tail_bb);
ctx.builder.build_load(slot_addr, "rpc.result").unwrap() // Parse metadata block(s) for ndarrays
gen_if_callback(
generator,
ctx,
|_, ctx| {
let i = ctx
.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
Ok(ctx
.builder
.build_int_compare(IntPredicate::EQ, i, llvm_usize.const_zero(), "")
// .build_int_compare(IntPredicate::ULT, i, llvm_usize.const_int(2, false), "")
.unwrap())
},
|generator, ctx| {
// let data_ptr = ctx
// .builder
// .build_bitcast(
// phi.as_basic_value().into_pointer_value(),
// llvm_ret_ty.element_type(),
// "",
// )
// .unwrap()
// .into_pointer_value();
// ndarray.store_data(ctx, data_ptr);
let data_addr = phi.as_basic_value().into_pointer_value();
let data_addr = ctx
.builder
.build_load(
ctx.builder
.build_bitcast(data_addr, llvm_ppi8, "")
.unwrap()
.into_pointer_value(),
"",
)
.unwrap()
.into_pointer_value();
call_memcpy_generic(
ctx,
ndarray.ptr_to_data(ctx),
data_addr,
llvm_pdata_sizeof,
llvm_i1.const_zero(),
);
let pbuffer_dims_begin = unsafe {
buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None)
};
call_memcpy_generic(
ctx,
ndarray.dim_sizes().base_ptr(ctx, generator),
pbuffer_dims_begin,
dims_buf_sz,
llvm_i1.const_zero(),
);
Ok(())
},
|_, _| Ok(()),
)
.unwrap();
let is_done = ctx
.builder
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
.unwrap();
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
ctx.builder.position_at_end(alloc_bb);
let alloc_ptr =
ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
let alloc_ptr = ctx
.builder
.build_bitcast(alloc_ptr, llvm_pi8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
let i =
ctx.builder.build_load(i_addr, "i").map(BasicValueEnum::into_int_value).unwrap();
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, false), "").unwrap();
ctx.builder.build_store(i_addr, i).unwrap();
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(tail_bb);
ndarray.as_base_value().into()
} }
_ => { _ => {
@ -699,8 +751,10 @@ fn format_rpc_ret<'ctx>(
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap(); ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
ctx.builder.position_at_end(alloc_bb); ctx.builder.position_at_end(alloc_bb);
let alloc_ptr = ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap(); let alloc_ptr =
let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap(); ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
let alloc_ptr =
ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]); phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
ctx.builder.build_unconditional_branch(head_bb).unwrap(); ctx.builder.build_unconditional_branch(head_bb).unwrap();
@ -850,7 +904,7 @@ fn rpc_codegen_callback_fn<'ctx>(
let result = format_rpc_ret(generator, ctx, fun.0.ret); let result = format_rpc_ret(generator, ctx, fun.0.ret);
if result.is_some_and(|res| !res.get_type().is_pointer_type()) { if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
call_stackrestore(ctx, stackptr); call_stackrestore(ctx, stackptr);
} }

View File

@ -1330,7 +1330,7 @@ impl<'ctx> NDArrayValue<'ctx> {
} }
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`. /// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { pub fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default(); let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default();
@ -1366,7 +1366,7 @@ impl<'ctx> NDArrayValue<'ctx> {
/// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr` /// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr`
/// on the field. /// on the field.
fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { pub fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default(); let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default();
@ -1404,7 +1404,7 @@ impl<'ctx> NDArrayValue<'ctx> {
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field. /// on the field.
fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
@ -1420,7 +1420,7 @@ impl<'ctx> NDArrayValue<'ctx> {
} }
/// Stores the array of data elements `data` into this instance. /// Stores the array of data elements `data` into this instance.
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap(); ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap();
} }