Compare commits

...

2 Commits

Author SHA1 Message Date
David Mak fa7062ec13 [artiq] WIP 2024-08-15 15:11:47 +08:00
David Mak 4fafe32563 [artiq] WIP 2024-08-15 15:11:08 +08:00
1 changed files with 45 additions and 26 deletions

View File

@ -32,6 +32,7 @@ use pyo3::{
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
use inkwell::values::IntValue;
use itertools::Itertools;
use std::{
collections::{hash_map::DefaultHasher, HashMap},
@ -544,7 +545,6 @@ fn format_rpc_ret<'ctx>(
let llvm_i8 = ctx.ctx.i8_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_ppi8 = llvm_pi8.ptr_type(AddressSpace::default());
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None)
@ -568,6 +568,24 @@ fn format_rpc_ret<'ctx>(
let llvm_i1 = ctx.ctx.bool_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let round_up = |ctx: &mut CodeGenContext<'ctx, '_>,
val: IntValue<'ctx>,
power_of_two: IntType<'ctx>| {
debug_assert!((power_of_two.get_bit_width() / 8).is_power_of_two());
let llvm_val_t = val.get_type();
let max_rem = (power_of_two.get_bit_width() / 8) - 1;
let max_rem = llvm_val_t.const_int(max_rem as u64, false);
ctx.builder
.build_and(
ctx.builder.build_int_add(val, max_rem, "").unwrap(),
ctx.builder.build_not(max_rem, "").unwrap(),
"",
)
.unwrap()
};
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
@ -602,6 +620,7 @@ fn format_rpc_ret<'ctx>(
let buffer_size =
ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap();
let buffer_size = round_up(ctx, buffer_size, ctx.ctx.i64_type());
let buffer =
ctx.builder.build_array_alloca(llvm_pi8, buffer_size, "rpc.buffer").unwrap();
@ -642,17 +661,17 @@ fn format_rpc_ret<'ctx>(
},
|generator, ctx| {
let phi = phi.as_basic_value().into_pointer_value();
let pbuffer_data_begin = unsafe {
ctx.builder.build_in_bounds_gep(phi, &[llvm_usize.const_int(8, false)], "")
}
.unwrap();
call_memcpy_generic(
ctx,
ndarray.ptr_to_data(ctx),
pbuffer_data_begin,
llvm_pdata_sizeof,
llvm_i1.const_zero(),
);
// let pbuffer_data_begin = unsafe {
// ctx.builder.build_in_bounds_gep(phi, &[llvm_usize.const_int(8, false)], "")
// }
// .unwrap();
// call_memcpy_generic(
// ctx,
// ndarray.ptr_to_data(ctx),
// pbuffer_data_begin,
// llvm_pdata_sizeof,
// llvm_i1.const_zero(),
// );
let pbuffer_dims_begin =
unsafe { ctx.builder.build_in_bounds_gep(phi, &[llvm_pdata_sizeof], "") }
@ -665,20 +684,20 @@ fn format_rpc_ret<'ctx>(
llvm_i1.const_zero(),
);
// // TODO: Testing for buffer
// ndarray.create_data(
// ctx,
// llvm_elem_ty,
// call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None)),
// );
//
// call_memcpy_generic(
// ctx,
// ndarray.data().base_ptr(ctx, generator),
// buffer.base_ptr(ctx, generator),
// llvm_usize.const_int(8, false),
// llvm_i1.const_zero(),
// );
// TODO: Testing for buffer
ndarray.create_data(
ctx,
llvm_elem_ty,
call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None)),
);
call_memcpy_generic(
ctx,
ndarray.data().base_ptr(ctx, generator),
buffer.base_ptr(ctx, generator),
llvm_usize.const_int(8, false),
llvm_i1.const_zero(),
);
Ok(())
},