This commit is contained in:
David Mak 2024-08-13 17:30:18 +08:00
parent 9f0df625da
commit 3e666af206
1 changed files with 52 additions and 18 deletions

View File

@ -29,7 +29,10 @@ use pyo3::{
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
use inkwell::types::BasicType;
use itertools::Itertools; use itertools::Itertools;
use nac3core::codegen::classes::{ArrayLikeIndexer, ArrayLikeValue, NDArrayType};
use nac3core::codegen::llvm_intrinsics;
use std::{ use std::{
collections::{hash_map::DefaultHasher, HashMap}, collections::{hash_map::DefaultHasher, HashMap},
hash::{Hash, Hasher}, hash::{Hash, Hasher},
@ -444,10 +447,11 @@ fn rpc_codegen_callback_fn<'ctx>(
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>, args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
) -> Result<Option<BasicValueEnum<'ctx>>, String> { ) -> Result<Option<BasicValueEnum<'ctx>>, String> {
let ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); let int1 = ctx.ctx.bool_type();
let size_type = generator.get_size_type(ctx.ctx);
let int8 = ctx.ctx.i8_type(); let int8 = ctx.ctx.i8_type();
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let size_type = generator.get_size_type(ctx.ctx);
let ptr_type = int8.ptr_type(AddressSpace::default());
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
let service_id = int32.const_int(fun.1 .0 as u64, false); let service_id = int32.const_int(fun.1 .0 as u64, false);
@ -541,26 +545,56 @@ fn rpc_codegen_callback_fn<'ctx>(
let arg_slot = let arg_slot =
generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap(); generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap();
ctx.builder.build_store(arg_slot, *arg).unwrap(); ctx.builder.build_store(arg_slot, *arg).unwrap();
let arg_slot = ctx
.builder
.build_bitcast(arg_slot, ptr_type, "rpc.arg")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let arg_slot = if matches!(&*ctx.unifier.get_ty_immutable(*arg_ty), TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id()) let arg_slot = if matches!(&*ctx.unifier.get_ty_immutable(*arg_ty), TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id())
{ {
debug_assert_eq!(u64::from(size_type.get_bit_width() / 8), 4); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, *arg_ty);
let llvm_arg_ty =
NDArrayType::new(generator, ctx.ctx, ctx.get_llvm_type(generator, elem_ty));
let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), size_type, None);
unsafe { let llvm_usize_sizeof = llvm_arg_ty.size_type().size_of();
ctx.builder let llvm_elem_sizeof = llvm_arg_ty.element_type().size_of().unwrap();
.build_in_bounds_gep(
arg_slot, let dims_buf_sz =
&[size_type.const_int(4, false)], // should be 4 ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
"", let data_buf_sz = ctx
) .builder
.unwrap() .build_int_mul(
} call_ndarray_calc_size(generator, ctx, &llvm_arg.dim_sizes(), (None, None)),
llvm_elem_sizeof,
"",
)
.unwrap();
let buffer_size = ctx.builder.build_int_add(dims_buf_sz, data_buf_sz, "").unwrap();
let buffer =
generator.gen_array_var_alloc(ctx, int8.into(), buffer_size, None).unwrap();
llvm_intrinsics::call_memcpy_generic(
ctx,
buffer.base_ptr(ctx, generator),
llvm_arg.dim_sizes().base_ptr(ctx, generator),
dims_buf_sz,
int1.const_zero(),
);
let pbuffer_data_begin =
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &dims_buf_sz, None) };
llvm_intrinsics::call_memcpy_generic(
ctx,
pbuffer_data_begin,
llvm_arg.data().base_ptr(ctx, generator),
data_buf_sz,
int1.const_zero(),
);
buffer.base_ptr(ctx, generator)
} else { } else {
arg_slot ctx.builder
.build_bitcast(arg_slot, ptr_type, "rpc.arg")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}; };
let arg_ptr = unsafe { let arg_ptr = unsafe {
ctx.builder.build_gep( ctx.builder.build_gep(