artiq: reimplement reformat_rpc_arg for ndarray

This commit is contained in:
lyken 2024-08-22 13:05:03 +08:00
parent 9b2e933405
commit 121f45279e
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8

View File

@ -1,12 +1,11 @@
use nac3core::{ use nac3core::{
codegen::{ codegen::{
classes::{ classes::{ListValue, NDArrayValue, RangeValue, UntypedArrayLikeAccessor},
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType,
NDArrayValue, RangeValue, UntypedArrayLikeAccessor,
},
expr::{destructure_range, gen_call}, expr::{destructure_range, gen_call},
irrt::call_ndarray_calc_size, irrt::call_ndarray_calc_size,
llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave}, llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave},
model::*,
object::{any::AnyObject, ndarray::NDArrayObject},
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
@ -20,7 +19,7 @@ use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
use inkwell::{ use inkwell::{
context::Context, context::Context,
module::Linkage, module::Linkage,
types::{BasicType, IntType}, types::IntType,
values::{BasicValueEnum, PointerValue, StructValue}, values::{BasicValueEnum, PointerValue, StructValue},
AddressSpace, IntPredicate, AddressSpace, IntPredicate,
}; };
@ -456,58 +455,41 @@ fn format_rpc_arg<'ctx>(
// NAC3: NDArray = { usize, usize*, T* } // NAC3: NDArray = { usize, usize*, T* }
// libproto_artiq: NDArray = [data[..], dim_sz[..]] // libproto_artiq: NDArray = [data[..], dim_sz[..]]
let llvm_i1 = ctx.ctx.bool_type(); let ndarray = AnyObject { ty: arg_ty, value: arg };
let llvm_usize = generator.get_size_type(ctx.ctx); let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let dtype = ctx.get_llvm_type(generator, ndarray.dtype);
let llvm_arg_ty = let ndims = ndarray.ndims_llvm(generator, ctx.ctx);
NDArrayType::new(generator, ctx.ctx, ctx.get_llvm_type(generator, elem_ty));
let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
let llvm_usize_sizeof = ctx // `ndarray.data` is possibly not contiguous, and we need it to be contiguous for
.builder // the reader.
.build_int_truncate_or_bit_cast(llvm_arg_ty.size_type().size_of(), llvm_usize, "") let carray = ndarray.make_contiguous_ndarray(generator, ctx, Any(dtype));
.unwrap();
let llvm_pdata_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(
llvm_arg_ty.element_type().ptr_type(AddressSpace::default()).size_of(),
llvm_usize,
"",
)
.unwrap();
let dims_buf_sz = let sizeof_sizet = Int(SizeT).sizeof(generator, ctx.ctx);
ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); let sizeof_sizet = Int(SizeT).truncate_or_bit_cast(generator, ctx, sizeof_sizet);
let buffer_size = let sizeof_pdata = Ptr(Any(dtype)).sizeof(generator, ctx.ctx);
ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap(); let sizeof_pdata = Int(SizeT).truncate_or_bit_cast(generator, ctx, sizeof_pdata);
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap(); let sizeof_buf_shape = sizeof_sizet.mul(ctx, ndims);
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg")); let sizeof_buf = sizeof_buf_shape.add(ctx, sizeof_pdata);
let ppdata = generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap(); // buf = { data: void*, shape: [size_t; ndims]; }
ctx.builder.build_store(ppdata, llvm_arg.data().base_ptr(ctx, generator)).unwrap(); let buf = Int(Byte).array_alloca(generator, ctx, sizeof_buf.value);
let buf_data = buf;
let buf_shape = buf_data.offset(ctx, sizeof_pdata.value);
call_memcpy_generic( // Write to `buf->data`
ctx, let carray_data = carray.get(generator, ctx, |f| f.data); // has type Ptr<Any>
buffer.base_ptr(ctx, generator), let carray_data = carray_data.pointer_cast(generator, ctx, Int(Byte));
ppdata, buf_data.copy_from(generator, ctx, carray_data, sizeof_pdata.value);
llvm_pdata_sizeof,
llvm_i1.const_zero(),
);
let pbuffer_dims_begin = // Write to `buf->shape`
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; let carray_shape = ndarray.instance.get(generator, ctx, |f| f.shape);
call_memcpy_generic( let carray_shape_i8 = carray_shape.pointer_cast(generator, ctx, Int(Byte));
ctx, buf_shape.copy_from(generator, ctx, carray_shape_i8, sizeof_buf_shape.value);
pbuffer_dims_begin,
llvm_arg.dim_sizes().base_ptr(ctx, generator),
dims_buf_sz,
llvm_i1.const_zero(),
);
buffer.base_ptr(ctx, generator) buf.value
} }
_ => { _ => {