1
0
forked from M-Labs/nac3

artiq: reimplement reformat_rpc_arg for ndarray

This commit is contained in:
lyken 2024-08-22 13:05:03 +08:00 committed by David Mak
parent d44e226e43
commit 2e75d5a730

View File

@ -14,26 +14,29 @@ use pyo3::{
use nac3core::{ use nac3core::{
codegen::{ codegen::{
classes::{ classes::{ListValue, NDArrayValue, RangeValue, UntypedArrayLikeAccessor},
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType,
NDArrayValue, ProxyType, ProxyValue, RangeValue, UntypedArrayLikeAccessor,
},
expr::{destructure_range, gen_call}, expr::{destructure_range, gen_call},
irrt::call_ndarray_calc_size, irrt::call_ndarray_calc_size,
llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave}, llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave},
model::*,
object::{any::AnyObject, ndarray::NDArrayObject},
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
inkwell::{ inkwell::{
context::Context, context::Context,
module::Linkage, module::Linkage,
types::{BasicType, IntType}, types::IntType,
values::{BasicValueEnum, IntValue, PointerValue, StructValue}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}, },
nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}, nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef},
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, GenCall}, toplevel::{
helper::{extract_ndims, PrimDef},
numpy::unpack_ndarray_var_tys,
DefinitionId, GenCall,
},
typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap}, typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap},
}; };
@ -454,55 +457,42 @@ fn format_rpc_arg<'ctx>(
// NAC3: NDArray = { usize, usize*, T* } // NAC3: NDArray = { usize, usize*, T* }
// libproto_artiq: NDArray = [data[..], dim_sz[..]] // libproto_artiq: NDArray = [data[..], dim_sz[..]]
let llvm_i1 = ctx.ctx.bool_type(); let ndarray = AnyObject { ty: arg_ty, value: arg };
let llvm_usize = generator.get_size_type(ctx.ctx); let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let dtype = ctx.get_llvm_type(generator, ndarray.dtype);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndims = ndarray.ndims_llvm(generator, ctx.ctx);
let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
let llvm_usize_sizeof = ctx // `ndarray.data` is possibly not contiguous, and we need it to be contiguous for
.builder // the reader.
.build_int_truncate_or_bit_cast(llvm_arg_ty.size_type().size_of(), llvm_usize, "") // Turning it into a ContiguousNDArray to get a `data` that is contiguous.
.unwrap(); let carray = ndarray.make_contiguous_ndarray(generator, ctx, Any(dtype));
let llvm_pdata_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
llvm_usize,
"",
)
.unwrap();
let dims_buf_sz = let sizeof_sizet = Int(SizeT).size_of(generator, ctx.ctx);
ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); let sizeof_sizet = Int(SizeT).truncate_or_bit_cast(generator, ctx, sizeof_sizet);
let buffer_size = let sizeof_pdata = Ptr(Any(dtype)).size_of(generator, ctx.ctx);
ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap(); let sizeof_pdata = Int(SizeT).truncate_or_bit_cast(generator, ctx, sizeof_pdata);
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap(); let sizeof_buf_shape = sizeof_sizet.mul(ctx, ndims);
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg")); let sizeof_buf = sizeof_buf_shape.add(ctx, sizeof_pdata);
call_memcpy_generic( // buf = { data: void*, shape: [size_t; ndims]; }
ctx, let buf = Int(Byte).array_alloca(generator, ctx, sizeof_buf.value);
buffer.base_ptr(ctx, generator), let buf_data = buf;
llvm_arg.ptr_to_data(ctx), let buf_shape = buf_data.offset(ctx, sizeof_pdata.value);
llvm_pdata_sizeof,
llvm_i1.const_zero(),
);
let pbuffer_dims_begin = // Write to `buf->data`
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; let carray_data = carray.get(generator, ctx, |f| f.data); // has type Ptr<Any>
call_memcpy_generic( let carray_data = carray_data.pointer_cast(generator, ctx, Int(Byte));
ctx, buf_data.copy_from(generator, ctx, carray_data, sizeof_pdata.value);
pbuffer_dims_begin,
llvm_arg.dim_sizes().base_ptr(ctx, generator),
dims_buf_sz,
llvm_i1.const_zero(),
);
buffer.base_ptr(ctx, generator) // Write to `buf->shape`
let carray_shape = ndarray.instance.get(generator, ctx, |f| f.shape);
let carray_shape_i8 = carray_shape.pointer_cast(generator, ctx, Int(Byte));
buf_shape.copy_from(generator, ctx, carray_shape_i8, sizeof_buf_shape.value);
buf.value
} }
_ => { _ => {
@ -563,8 +553,10 @@ fn format_rpc_ret<'ctx>(
let result = match &*ctx.unifier.get_ty_immutable(ret_ty) { let result = match &*ctx.unifier.get_ty_immutable(ret_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_i1 = ctx.ctx.bool_type(); // FIXME: It is possible to rewrite everything more neatly with `Model<'ctx>`, but this is not too important.
let llvm_usize = generator.get_size_type(ctx.ctx);
let num_0 = Int(SizeT).const_0(generator, ctx.ctx);
let num_8 = Int(SizeT).const_int(generator, ctx.ctx, 8, false);
// Round `val` up to its modulo `power_of_two` // Round `val` up to its modulo `power_of_two`
let round_up = |ctx: &mut CodeGenContext<'ctx, '_>, let round_up = |ctx: &mut CodeGenContext<'ctx, '_>,
@ -590,60 +582,36 @@ fn format_rpc_ret<'ctx>(
.unwrap() .unwrap()
}; };
// Setup types
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
// Allocate the resulting ndarray // Allocate the resulting ndarray
// A condition after format_rpc_ret ensures this will not be popped this off. // A condition after format_rpc_ret ensures this will not be popped this off.
let ndarray = llvm_ret_ty.new_value(generator, ctx, Some("rpc.result")); let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims);
// Setup ndims // NOTE: Current content of `ndarray`:
let ndims = // - * `data` - **NOT YET** allocated.
if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) { // - * `itemsize` - initialized to be size_of(dtype).
assert_eq!(values.len(), 1); // - * `ndims` - initialized.
// - * `shape` - allocated; has uninitialized values.
// - * `strides` - allocated; has uninitialized values.
u64::try_from(values[0].clone()).unwrap() let itemsize = ndarray.instance.get(generator, ctx, |f| f.itemsize); // Same as doing a `ctx.get_llvm_type` on `dtype` and get its `size_of()`.
} else { let dtype_llvm = ctx.get_llvm_type(generator, dtype);
unreachable!();
};
// Set `ndarray.ndims`
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
// Allocate `ndarray.shape` [size_t; ndims]
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx));
/*
ndarray now:
- .ndims: initialized
- .shape: allocated but uninitialized .shape
- .data: uninitialized
*/
let llvm_usize_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(llvm_usize.size_of(), llvm_usize, "")
.unwrap();
let llvm_pdata_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
llvm_usize,
"",
)
.unwrap();
let llvm_elem_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "")
.unwrap();
// Allocates a buffer for the initial RPC'ed object, which is guaranteed to be // Allocates a buffer for the initial RPC'ed object, which is guaranteed to be
// (4 + 4 * ndims) bytes with 8-byte alignment // (4 + 4 * ndims) bytes with 8-byte alignment
let sizeof_dims = let sizeof_size_t = Int(SizeT).size_of(generator, ctx.ctx);
ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); let sizeof_size_t = Int(SizeT).z_extend_or_truncate(generator, ctx, sizeof_size_t); // sizeof(size_t)
let unaligned_buffer_size =
ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap(); let sizeof_ptr = Ptr(Int(Byte)).size_of(generator, ctx.ctx);
let buffer_size = round_up(ctx, unaligned_buffer_size, llvm_usize.const_int(8, false)); let sizeof_ptr = Int(SizeT).z_extend_or_truncate(generator, ctx, sizeof_ptr); // sizeof(uint8_t*)
let sizeof_shape = ndarray.ndims_llvm(generator, ctx.ctx).mul(ctx, sizeof_size_t); // sizeof([size_t; ndims]); same as the # of bytes of `ndarray.shape`.
// Size of the buffer for the initial `rpc_recv()`.
let unaligned_buffer_size = sizeof_ptr.add(ctx, sizeof_shape); // sizeof(uint8_t*) + sizeof([size_t; ndims]).
let buffer_size = round_up(ctx, unaligned_buffer_size.value, num_8.value);
let buffer_size = unsafe { Int(SizeT).believe_value(buffer_size) };
let stackptr = call_stacksave(ctx, None); let stackptr = call_stacksave(ctx, None);
// Just to be absolutely sure, alloca in [i8 x 8] slices to force 8-byte alignment // Just to be absolutely sure, alloca in [i8 x 8] slices to force 8-byte alignment
@ -651,9 +619,7 @@ fn format_rpc_ret<'ctx>(
.builder .builder
.build_array_alloca( .build_array_alloca(
llvm_i8_8, llvm_i8_8,
ctx.builder ctx.builder.build_int_unsigned_div(buffer_size.value, num_8.value, "").unwrap(),
.build_int_unsigned_div(buffer_size, llvm_usize.const_int(8, false), "")
.unwrap(),
"rpc.buffer", "rpc.buffer",
) )
.unwrap(); .unwrap();
@ -662,7 +628,7 @@ fn format_rpc_ret<'ctx>(
.build_bit_cast(buffer, llvm_pi8, "") .build_bit_cast(buffer, llvm_pi8, "")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None); let buffer = unsafe { Ptr(Int(Byte)).believe_value(buffer) };
// The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape] // The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape]
// //
@ -670,24 +636,20 @@ fn format_rpc_ret<'ctx>(
let ndarray_nbytes = ctx let ndarray_nbytes = ctx
.build_call_or_invoke( .build_call_or_invoke(
rpc_recv, rpc_recv,
&[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]. NOTE: We are allocated [size_t; ndims]. &[buffer.value.into()], // Reads [usize; ndims]
"rpc.size.next", "rpc.size.next",
) )
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();
let ndarray_nbytes = unsafe { Int(SizeT).believe_value(ndarray_nbytes) };
// debug_assert(ndarray_nbytes > 0) // debug_assert(ndarray_nbytes > 0)
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let cmp = ndarray_nbytes.compare(ctx, IntPredicate::UGT, num_0);
ctx.make_assert( ctx.make_assert(
generator, generator,
ctx.builder cmp.value,
.build_int_compare(
IntPredicate::UGT,
ndarray_nbytes,
ndarray_nbytes.get_type().const_zero(),
"",
)
.unwrap(),
"0:AssertionError", "0:AssertionError",
"Unexpected RPC termination for ndarray - Expected data buffer next", "Unexpected RPC termination for ndarray - Expected data buffer next",
[None, None, None], [None, None, None],
@ -696,49 +658,39 @@ fn format_rpc_ret<'ctx>(
} }
// Copy shape from the buffer to `ndarray.shape`. // Copy shape from the buffer to `ndarray.shape`.
let pbuffer_dims = // We need to skip the first `sizeof(uint8_t*)` bytes to skip the `pdata` in `[pdata, shape]`.
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; let pbuffer_shape = buffer.offset(ctx, sizeof_ptr.value);
let pbuffer_shape = pbuffer_shape.pointer_cast(generator, ctx, Int(SizeT));
// Copy shape from buffer to `ndarray.shape`
ndarray.copy_shape_from_array(generator, ctx, pbuffer_shape);
call_memcpy_generic(
ctx,
ndarray.dim_sizes().base_ptr(ctx, generator),
pbuffer_dims,
sizeof_dims,
llvm_i1.const_zero(),
);
// Restore stack from before allocation of buffer // Restore stack from before allocation of buffer
call_stackrestore(ctx, stackptr); call_stackrestore(ctx, stackptr);
// Allocate `ndarray.data`. // Allocate `ndarray.data`.
// `ndarray.shape` must be initialized beforehand in this implementation // `ndarray.shape` must be initialized beforehand in this implementation
// (for ndarray.create_data() to know how many elements to allocate) // (for ndarray.create_data() to know how many elements to allocate)
let num_elements = ndarray.create_data(generator, ctx); // NOTE: the strides of `ndarray` has also been set to contiguous in `::create_data()`.
call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None));
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes) // debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let sizeof_data = let num_elements = ndarray.size(generator, ctx);
ctx.builder.build_int_mul(num_elements, llvm_elem_sizeof, "").unwrap();
let expected_ndarray_nbytes = num_elements.mul(ctx, itemsize);
let cmp = expected_ndarray_nbytes.compare(ctx, IntPredicate::UGE, ndarray_nbytes);
ctx.make_assert( ctx.make_assert(
generator, generator,
ctx.builder.build_int_compare(IntPredicate::UGE, cmp.value,
sizeof_data,
ndarray_nbytes,
"",
).unwrap(),
"0:AssertionError", "0:AssertionError",
"Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes", "Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes",
[Some(sizeof_data), Some(ndarray_nbytes), None], [Some(expected_ndarray_nbytes.value), Some(ndarray_nbytes.value), None],
ctx.current_loc, ctx.current_loc,
); );
} }
ndarray.create_data(ctx, llvm_elem_ty, num_elements); let ndarray_data = ndarray.instance.get(generator, ctx, |f| f.data);
let ndarray_data = ndarray.data().base_ptr(ctx, generator);
let ndarray_data_i8 =
ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap();
// NOTE: Currently on `prehead_bb` // NOTE: Currently on `prehead_bb`
ctx.builder.build_unconditional_branch(head_bb).unwrap(); ctx.builder.build_unconditional_branch(head_bb).unwrap();
@ -747,7 +699,7 @@ fn format_rpc_ret<'ctx>(
ctx.builder.position_at_end(head_bb); ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap(); let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
phi.add_incoming(&[(&ndarray_data_i8, prehead_bb)]); phi.add_incoming(&[(&ndarray_data.value, prehead_bb)]);
let alloc_size = ctx let alloc_size = ctx
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
@ -762,12 +714,12 @@ fn format_rpc_ret<'ctx>(
ctx.builder.position_at_end(alloc_bb); ctx.builder.position_at_end(alloc_bb);
// Align the allocation to sizeof(T) // Align the allocation to sizeof(T)
let alloc_size = round_up(ctx, alloc_size, llvm_elem_sizeof); let alloc_size = round_up(ctx, alloc_size, itemsize.value);
let alloc_ptr = ctx let alloc_ptr = ctx
.builder .builder
.build_array_alloca( .build_array_alloca(
llvm_elem_ty, dtype_llvm,
ctx.builder.build_int_unsigned_div(alloc_size, llvm_elem_sizeof, "").unwrap(), ctx.builder.build_int_unsigned_div(alloc_size, itemsize.value, "").unwrap(),
"rpc.alloc", "rpc.alloc",
) )
.unwrap(); .unwrap();
@ -777,7 +729,7 @@ fn format_rpc_ret<'ctx>(
ctx.builder.build_unconditional_branch(head_bb).unwrap(); ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(tail_bb); ctx.builder.position_at_end(tail_bb);
ndarray.as_base_value().into() ndarray.instance.value.as_basic_value_enum()
} }
_ => { _ => {