1
0
forked from M-Labs/nac3

core: Fix index-based operations not returning i32

This commit is contained in:
David Mak 2024-03-22 16:10:42 +08:00
parent 4bb0e60981
commit 789bfb5a26
4 changed files with 13 additions and 7 deletions

View File

@ -1326,6 +1326,9 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
self.0.dim_sizes().get_typed_unchecked(ctx, generator, i, None), self.0.dim_sizes().get_typed_unchecked(ctx, generator, i, None),
) )
}; };
let dim_idx = ctx.builder
.build_int_z_extend_or_bit_cast(dim_idx, dim_sz.get_type(), "")
.unwrap();
let dim_lt = ctx.builder.build_int_compare( let dim_lt = ctx.builder.build_int_compare(
IntPredicate::SLT, IntPredicate::SLT,

View File

@ -243,13 +243,13 @@ void __nac3_ndarray_calc_nd_indices64(
uint64_t index, uint64_t index,
const uint64_t* dims, const uint64_t* dims,
uint64_t num_dims, uint64_t num_dims,
uint64_t* idxs uint32_t* idxs
) { ) {
uint64_t stride = 1; uint64_t stride = 1;
for (uint64_t dim = 0; dim < num_dims; dim++) { for (uint64_t dim = 0; dim < num_dims; dim++) {
uint64_t i = num_dims - dim - 1; uint64_t i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0); __builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i]; idxs[i] = (uint32_t) ((index / stride) % dims[i]);
stride *= dims[i]; stride *= dims[i];
} }
} }

View File

@ -619,7 +619,8 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>(
.unwrap() .unwrap()
} }
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. /// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
/// containing `i32` indices of the flattened index.
/// ///
/// * `index` - The index to compute the multidimensional index for. /// * `index` - The index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an /// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
@ -631,8 +632,9 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let llvm_void = ctx.ctx.void_type(); let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
@ -646,7 +648,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
llvm_usize.into(), llvm_usize.into(),
llvm_pusize.into(), llvm_pusize.into(),
llvm_usize.into(), llvm_usize.into(),
llvm_pusize.into(), llvm_pi32.into(),
], ],
false, false,
); );
@ -658,7 +660,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
let ndarray_dims = ndarray.dim_sizes(); let ndarray_dims = ndarray.dim_sizes();
let indices = ctx.builder.build_array_alloca( let indices = ctx.builder.build_array_alloca(
llvm_usize, llvm_i32,
ndarray_num_dims, ndarray_num_dims,
"", "",
).unwrap(); ).unwrap();

View File

@ -470,6 +470,7 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
ncols: IntValue<'ctx>, ncols: IntValue<'ctx>,
offset: IntValue<'ctx>, offset: IntValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> { ) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_usize_2 = llvm_usize.array_type(2); let llvm_usize_2 = llvm_usize.array_type(2);
@ -512,7 +513,7 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
let col_with_offset = ctx.builder let col_with_offset = ctx.builder
.build_int_add( .build_int_add(
col, col,
ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_usize, "").unwrap(), ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(),
"", "",
) )
.unwrap(); .unwrap();