From 789bfb5a264841fbffa9cc0ec7d9bca6eef43819 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 22 Mar 2024 16:10:42 +0800 Subject: [PATCH] core: Fix index-based operations not returning i32 --- nac3core/src/codegen/classes.rs | 3 +++ nac3core/src/codegen/irrt/irrt.c | 4 ++-- nac3core/src/codegen/irrt/mod.rs | 10 ++++++---- nac3core/src/codegen/numpy.rs | 3 ++- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 3b23999a..371e0b01 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -1326,6 +1326,9 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> self.0.dim_sizes().get_typed_unchecked(ctx, generator, i, None), ) }; + let dim_idx = ctx.builder + .build_int_z_extend_or_bit_cast(dim_idx, dim_sz.get_type(), "") + .unwrap(); let dim_lt = ctx.builder.build_int_compare( IntPredicate::SLT, diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c index bbe27ce8..fda93121 100644 --- a/nac3core/src/codegen/irrt/irrt.c +++ b/nac3core/src/codegen/irrt/irrt.c @@ -243,13 +243,13 @@ void __nac3_ndarray_calc_nd_indices64( uint64_t index, const uint64_t* dims, uint64_t num_dims, - uint64_t* idxs + uint32_t* idxs ) { uint64_t stride = 1; for (uint64_t dim = 0; dim < num_dims; dim++) { uint64_t i = num_dims - dim - 1; __builtin_assume(dims[i] > 0); - idxs[i] = (index / stride) % dims[i]; + idxs[i] = (uint32_t) ((index / stride) % dims[i]); stride *= dims[i]; } } diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 0008d7b9..bd5a62f4 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -619,7 +619,8 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>( .unwrap() } -/// Generates a call to `__nac3_ndarray_calc_nd_indices`. +/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`] +/// containing `i32` indices of the flattened index. /// /// * `index` - The index to compute the multidimensional index for. /// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an @@ -631,8 +632,9 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( ndarray: NDArrayValue<'ctx>, ) -> PointerValue<'ctx> { let llvm_void = ctx.ctx.void_type(); + let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); - + let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { @@ -646,7 +648,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), - llvm_pusize.into(), + llvm_pi32.into(), ], false, ); @@ -658,7 +660,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( let ndarray_dims = ndarray.dim_sizes(); let indices = ctx.builder.build_array_alloca( - llvm_usize, + llvm_i32, ndarray_num_dims, "", ).unwrap(); diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 02d5d5b0..ff5f5307 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -470,6 +470,7 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>( ncols: IntValue<'ctx>, offset: IntValue<'ctx>, ) -> Result, String> { + let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize_2 = llvm_usize.array_type(2); @@ -512,7 +513,7 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>( let col_with_offset = ctx.builder .build_int_add( col, - ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_usize, "").unwrap(), + ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(), "", ) .unwrap();