diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 489115d1..9db031d2 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -29,7 +29,10 @@ use pyo3::{ use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; +use inkwell::types::BasicType; use itertools::Itertools; +use nac3core::codegen::classes::{ArrayLikeIndexer, ArrayLikeValue, NDArrayType}; +use nac3core::codegen::llvm_intrinsics; use std::{ collections::{hash_map::DefaultHasher, HashMap}, hash::{Hash, Hasher}, @@ -444,10 +447,11 @@ fn rpc_codegen_callback_fn<'ctx>( args: Vec<(Option, ValueEnum<'ctx>)>, generator: &mut dyn CodeGenerator, ) -> Result>, String> { - let ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let size_type = generator.get_size_type(ctx.ctx); + let int1 = ctx.ctx.bool_type(); let int8 = ctx.ctx.i8_type(); let int32 = ctx.ctx.i32_type(); + let size_type = generator.get_size_type(ctx.ctx); + let ptr_type = int8.ptr_type(AddressSpace::default()); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); let service_id = int32.const_int(fun.1 .0 as u64, false); @@ -541,26 +545,56 @@ fn rpc_codegen_callback_fn<'ctx>( let arg_slot = generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap(); ctx.builder.build_store(arg_slot, *arg).unwrap(); - let arg_slot = ctx - .builder - .build_bitcast(arg_slot, ptr_type, "rpc.arg") - .map(BasicValueEnum::into_pointer_value) - .unwrap(); let arg_slot = if matches!(&*ctx.unifier.get_ty_immutable(*arg_ty), TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id()) { - debug_assert_eq!(u64::from(size_type.get_bit_width() / 8), 4); + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, *arg_ty); + let llvm_arg_ty = + NDArrayType::new(generator, ctx.ctx, ctx.get_llvm_type(generator, elem_ty)); + let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), size_type, None); - unsafe { - ctx.builder - .build_in_bounds_gep( - arg_slot, - &[size_type.const_int(4, false)], // should be 4 - "", - ) - .unwrap() - } + let llvm_usize_sizeof = llvm_arg_ty.size_type().size_of(); + let llvm_elem_sizeof = llvm_arg_ty.element_type().size_of().unwrap(); + + let dims_buf_sz = + ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); + let data_buf_sz = ctx + .builder + .build_int_mul( + call_ndarray_calc_size(generator, ctx, &llvm_arg.dim_sizes(), (None, None)), + llvm_elem_sizeof, + "", + ) + .unwrap(); + + let buffer_size = ctx.builder.build_int_add(dims_buf_sz, data_buf_sz, "").unwrap(); + + let buffer = + generator.gen_array_var_alloc(ctx, int8.into(), buffer_size, None).unwrap(); + + llvm_intrinsics::call_memcpy_generic( + ctx, + buffer.base_ptr(ctx, generator), + llvm_arg.dim_sizes().base_ptr(ctx, generator), + dims_buf_sz, + int1.const_zero(), + ); + + let pbuffer_data_begin = + unsafe { buffer.ptr_offset_unchecked(ctx, generator, &dims_buf_sz, None) }; + llvm_intrinsics::call_memcpy_generic( + ctx, + pbuffer_data_begin, + llvm_arg.data().base_ptr(ctx, generator), + data_buf_sz, + int1.const_zero(), + ); + + buffer.base_ptr(ctx, generator) } else { - arg_slot + ctx.builder + .build_bitcast(arg_slot, ptr_type, "rpc.arg") + .map(BasicValueEnum::into_pointer_value) + .unwrap() }; let arg_ptr = unsafe { ctx.builder.build_gep(