forked from M-Labs/nac3
artiq: reimplement reformat_rpc_arg for ndarrays
This commit is contained in:
parent
31931b7b26
commit
ad5506bff1
|
@ -1,12 +1,11 @@
|
|||
use nac3core::{
|
||||
codegen::{
|
||||
classes::{
|
||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType,
|
||||
NDArrayValue, RangeValue, UntypedArrayLikeAccessor,
|
||||
},
|
||||
classes::{ListValue, NDArrayValue, RangeValue, UntypedArrayLikeAccessor},
|
||||
expr::{destructure_range, gen_call},
|
||||
irrt::call_ndarray_calc_size,
|
||||
llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave},
|
||||
llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave},
|
||||
model::*,
|
||||
object::{any::AnyObject, ndarray::NDArrayObject},
|
||||
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
|
@ -20,7 +19,7 @@ use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
|
|||
use inkwell::{
|
||||
context::Context,
|
||||
module::Linkage,
|
||||
types::{BasicType, IntType},
|
||||
types::IntType,
|
||||
values::{BasicValueEnum, PointerValue, StructValue},
|
||||
AddressSpace, IntPredicate,
|
||||
};
|
||||
|
@ -456,58 +455,41 @@ fn format_rpc_arg<'ctx>(
|
|||
// NAC3: NDArray = { usize, usize*, T* }
|
||||
// libproto_artiq: NDArray = [data[..], dim_sz[..]]
|
||||
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let ndarray = AnyObject { ty: arg_ty, value: arg };
|
||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
|
||||
let llvm_arg_ty =
|
||||
NDArrayType::new(generator, ctx.ctx, ctx.get_llvm_type(generator, elem_ty));
|
||||
let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
|
||||
let dtype = ctx.get_llvm_type(generator, ndarray.dtype);
|
||||
let ndims = ndarray.ndims_llvm(generator, ctx.ctx);
|
||||
|
||||
let llvm_usize_sizeof = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(llvm_arg_ty.size_type().size_of(), llvm_usize, "")
|
||||
.unwrap();
|
||||
let llvm_pdata_sizeof = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(
|
||||
llvm_arg_ty.element_type().ptr_type(AddressSpace::default()).size_of(),
|
||||
llvm_usize,
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
// `ndarray.data` is possibly not contiguous. We need to force it to be continuous,
|
||||
// and we might have to copy the whole ndarray.
|
||||
let carray = ndarray.make_contiguous_ndarray(generator, ctx, Any(dtype));
|
||||
|
||||
let dims_buf_sz =
|
||||
ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
|
||||
let sizeof_sizet = Int(SizeT).sizeof(generator, ctx.ctx);
|
||||
let sizeof_sizet = Int(SizeT).truncate_or_bit_cast(generator, ctx, sizeof_sizet);
|
||||
|
||||
let buffer_size =
|
||||
ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap();
|
||||
let sizeof_pdata = Ptr(Any(dtype)).sizeof(generator, ctx.ctx);
|
||||
let sizeof_pdata = Int(SizeT).truncate_or_bit_cast(generator, ctx, sizeof_pdata);
|
||||
|
||||
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap();
|
||||
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg"));
|
||||
let sizeof_buf_shape = sizeof_sizet.mul(ctx, ndims);
|
||||
let sizeof_buf = sizeof_buf_shape.add(ctx, sizeof_pdata);
|
||||
|
||||
let ppdata = generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap();
|
||||
ctx.builder.build_store(ppdata, llvm_arg.data().base_ptr(ctx, generator)).unwrap();
|
||||
// buf = { data: void*, shape: [size_t; ndims]; }
|
||||
let buf = Int(Byte).array_alloca(generator, ctx, sizeof_buf.value);
|
||||
let buf_data = buf;
|
||||
let buf_shape = buf_data.offset(ctx, sizeof_pdata.value);
|
||||
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
buffer.base_ptr(ctx, generator),
|
||||
ppdata,
|
||||
llvm_pdata_sizeof,
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
// Write to `buf->data`
|
||||
let carray_data = carray.get(generator, ctx, |f| f.data); // has type Ptr<Any>
|
||||
let carray_data = carray_data.pointer_cast(generator, ctx, Int(Byte));
|
||||
buf_data.copy_from(generator, ctx, carray_data, sizeof_pdata.value);
|
||||
|
||||
let pbuffer_dims_begin =
|
||||
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
pbuffer_dims_begin,
|
||||
llvm_arg.dim_sizes().base_ptr(ctx, generator),
|
||||
dims_buf_sz,
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
// Write to `buf->shape`
|
||||
let carray_shape = ndarray.instance.get(generator, ctx, |f| f.shape);
|
||||
let carray_shape_i8 = carray_shape.pointer_cast(generator, ctx, Int(Byte));
|
||||
buf_shape.copy_from(generator, ctx, carray_shape_i8, sizeof_buf_shape.value);
|
||||
|
||||
buffer.base_ptr(ctx, generator)
|
||||
buf.value
|
||||
}
|
||||
|
||||
_ => {
|
||||
|
|
Loading…
Reference in New Issue