forked from M-Labs/nac3
1
0
Fork 0

core: Fix index-based operations not returning i32

This commit is contained in:
David Mak 2024-03-22 16:10:42 +08:00
parent 4bb0e60981
commit 789bfb5a26
4 changed files with 13 additions and 7 deletions

View File

@ -1326,6 +1326,9 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
self.0.dim_sizes().get_typed_unchecked(ctx, generator, i, None),
)
};
let dim_idx = ctx.builder
.build_int_z_extend_or_bit_cast(dim_idx, dim_sz.get_type(), "")
.unwrap();
let dim_lt = ctx.builder.build_int_compare(
IntPredicate::SLT,

View File

@ -243,13 +243,13 @@ void __nac3_ndarray_calc_nd_indices64(
uint64_t index,
const uint64_t* dims,
uint64_t num_dims,
uint64_t* idxs
uint32_t* idxs
) {
uint64_t stride = 1;
for (uint64_t dim = 0; dim < num_dims; dim++) {
uint64_t i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
idxs[i] = (uint32_t) ((index / stride) % dims[i]);
stride *= dims[i];
}
}

View File

@ -619,7 +619,8 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>(
.unwrap()
}
/// Generates a call to `__nac3_ndarray_calc_nd_indices`.
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdapter`]
/// containing `i32` indices of the flattened index.
///
/// * `index` - The index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
@ -631,8 +632,9 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
ndarray: NDArrayValue<'ctx>,
) -> PointerValue<'ctx> {
let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
@ -646,7 +648,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_pi32.into(),
],
false,
);
@ -658,7 +660,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
let ndarray_dims = ndarray.dim_sizes();
let indices = ctx.builder.build_array_alloca(
llvm_usize,
llvm_i32,
ndarray_num_dims,
"",
).unwrap();

View File

@ -470,6 +470,7 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
ncols: IntValue<'ctx>,
offset: IntValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_usize_2 = llvm_usize.array_type(2);
@ -512,7 +513,7 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
let col_with_offset = ctx.builder
.build_int_add(
col,
ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_usize, "").unwrap(),
ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(),
"",
)
.unwrap();