core: fix new irrt ndarray issues

This commit is contained in:
lyken 2024-07-14 23:40:19 +08:00
parent d92cccb85e
commit 0cc7e41c6f
3 changed files with 26 additions and 10 deletions

View File

@ -204,7 +204,7 @@ fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
FunctionBuilder::begin(
ctx,
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_util_assert_shape_no_negative"),
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_set_strides_by_shape"),
)
.arg("ndarray", &PointerModel(StructModel(NpArray { sizet })), ndarray_ptr)
.returning_void();
@ -217,12 +217,9 @@ fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
) -> Int<'ctx> {
let sizet = IntModel(generator.get_size_type(ctx.ctx));
FunctionBuilder::begin(
ctx,
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_util_assert_shape_no_negative"),
)
.arg("ndarray", &PointerModel(StructModel(NpArray { sizet })), ndarray_ptr)
.returning("nbytes", &sizet)
FunctionBuilder::begin(ctx, &get_sized_dependent_function_name(sizet, "__nac3_ndarray_nbytes"))
.arg("ndarray", &PointerModel(StructModel(NpArray { sizet })), ndarray_ptr)
.returning("nbytes", &sizet)
}
pub fn call_nac3_ndarray_fill_generic<'ctx, G: CodeGenerator + ?Sized>(

View File

@ -77,4 +77,16 @@ impl<'ctx> OpaquePointer<'ctx> {
pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, value: BasicValueEnum<'ctx>) {
ctx.builder.build_store(self.0, value).unwrap();
}
pub fn from_ptr(ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>) -> Self {
let ptr = ctx
.builder
.build_pointer_cast(
ptr,
ctx.ctx.i8_type().ptr_type(AddressSpace::default()),
"opaque.from_ptr",
)
.unwrap();
OpaquePointer(ptr)
}
}

View File

@ -55,9 +55,16 @@ where
let ndarray_ptr = call_ndarray_empty_impl(generator, ctx, elem_ty, shape, shape_ty, name)?;
// NOTE: fill_value's type is not checked!! so be careful with logics
let fill_value_ptr =
OpaquePointer(ctx.builder.build_alloca(fill_value.get_type(), "fill_value_ptr").unwrap());
fill_value_ptr.store(ctx, fill_value);
// Allocate fill_value on the stack and give the corresponding stack pointer
// to call_nac3_ndarray_fill_generic
let fill_value_ptr = ctx.builder.build_alloca(fill_value.get_type(), "fill_value_ptr").unwrap();
ctx.builder.build_store(fill_value_ptr, fill_value).unwrap();
// Opaque-ize fill_value_ptr (turning it into `i8*`) before passing
// to call_nac3_ndarray_fill_generic
let fill_value_ptr = OpaquePointer::from_ptr(ctx, fill_value_ptr);
call_nac3_ndarray_fill_generic(generator, ctx, &ndarray_ptr, &fill_value_ptr);
Ok(ndarray_ptr)