From 121f45279edcf3639223974bd59880f3e6114975 Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 22 Aug 2024 13:05:03 +0800 Subject: [PATCH] artiq: reimplement reformat_rpc_arg for ndarray --- nac3artiq/src/codegen.rs | 80 ++++++++++++++++------------------------ 1 file changed, 31 insertions(+), 49 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index d5a53b16..0e6f6b8c 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -1,12 +1,11 @@ use nac3core::{ codegen::{ - classes::{ - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType, - NDArrayValue, RangeValue, UntypedArrayLikeAccessor, - }, + classes::{ListValue, NDArrayValue, RangeValue, UntypedArrayLikeAccessor}, expr::{destructure_range, gen_call}, irrt::call_ndarray_calc_size, - llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave}, + llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave}, + model::*, + object::{any::AnyObject, ndarray::NDArrayObject}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with}, CodeGenContext, CodeGenerator, }, @@ -20,7 +19,7 @@ use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}; use inkwell::{ context::Context, module::Linkage, - types::{BasicType, IntType}, + types::IntType, values::{BasicValueEnum, PointerValue, StructValue}, AddressSpace, IntPredicate, }; @@ -456,58 +455,41 @@ fn format_rpc_arg<'ctx>( // NAC3: NDArray = { usize, usize*, T* } // libproto_artiq: NDArray = [data[..], dim_sz[..]] - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); + let ndarray = AnyObject { ty: arg_ty, value: arg }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); - let llvm_arg_ty = - NDArrayType::new(generator, ctx.ctx, ctx.get_llvm_type(generator, elem_ty)); - let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None); + let dtype = ctx.get_llvm_type(generator, ndarray.dtype); + let ndims = ndarray.ndims_llvm(generator, ctx.ctx); - let llvm_usize_sizeof = ctx - .builder - .build_int_truncate_or_bit_cast(llvm_arg_ty.size_type().size_of(), llvm_usize, "") - .unwrap(); - let llvm_pdata_sizeof = ctx - .builder - .build_int_truncate_or_bit_cast( - llvm_arg_ty.element_type().ptr_type(AddressSpace::default()).size_of(), - llvm_usize, - "", - ) - .unwrap(); + // `ndarray.data` is possibly not contiguous, and we need it to be contiguous for + // the reader. + let carray = ndarray.make_contiguous_ndarray(generator, ctx, Any(dtype)); - let dims_buf_sz = - ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); + let sizeof_sizet = Int(SizeT).sizeof(generator, ctx.ctx); + let sizeof_sizet = Int(SizeT).truncate_or_bit_cast(generator, ctx, sizeof_sizet); - let buffer_size = - ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap(); + let sizeof_pdata = Ptr(Any(dtype)).sizeof(generator, ctx.ctx); + let sizeof_pdata = Int(SizeT).truncate_or_bit_cast(generator, ctx, sizeof_pdata); - let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap(); - let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg")); + let sizeof_buf_shape = sizeof_sizet.mul(ctx, ndims); + let sizeof_buf = sizeof_buf_shape.add(ctx, sizeof_pdata); - let ppdata = generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap(); - ctx.builder.build_store(ppdata, llvm_arg.data().base_ptr(ctx, generator)).unwrap(); + // buf = { data: void*, shape: [size_t; ndims]; } + let buf = Int(Byte).array_alloca(generator, ctx, sizeof_buf.value); + let buf_data = buf; + let buf_shape = buf_data.offset(ctx, sizeof_pdata.value); - call_memcpy_generic( - ctx, - buffer.base_ptr(ctx, generator), - ppdata, - llvm_pdata_sizeof, - llvm_i1.const_zero(), - ); + // Write to `buf->data` + let carray_data = carray.get(generator, ctx, |f| f.data); // has type Ptr + let carray_data = carray_data.pointer_cast(generator, ctx, Int(Byte)); + buf_data.copy_from(generator, ctx, carray_data, sizeof_pdata.value); - let pbuffer_dims_begin = - unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; - call_memcpy_generic( - ctx, - pbuffer_dims_begin, - llvm_arg.dim_sizes().base_ptr(ctx, generator), - dims_buf_sz, - llvm_i1.const_zero(), - ); + // Write to `buf->shape` + let carray_shape = ndarray.instance.get(generator, ctx, |f| f.shape); + let carray_shape_i8 = carray_shape.pointer_cast(generator, ctx, Int(Byte)); + buf_shape.copy_from(generator, ctx, carray_shape_i8, sizeof_buf_shape.value); - buffer.base_ptr(ctx, generator) + buf.value } _ => {