forked from M-Labs/nac3
core: Fix index-based operations not returning i32
This commit is contained in:
parent
4bb0e60981
commit
789bfb5a26
|
@ -1326,6 +1326,9 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
|
||||||
self.0.dim_sizes().get_typed_unchecked(ctx, generator, i, None),
|
self.0.dim_sizes().get_typed_unchecked(ctx, generator, i, None),
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
let dim_idx = ctx.builder
|
||||||
|
.build_int_z_extend_or_bit_cast(dim_idx, dim_sz.get_type(), "")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let dim_lt = ctx.builder.build_int_compare(
|
let dim_lt = ctx.builder.build_int_compare(
|
||||||
IntPredicate::SLT,
|
IntPredicate::SLT,
|
||||||
|
|
|
@ -243,13 +243,13 @@ void __nac3_ndarray_calc_nd_indices64(
|
||||||
uint64_t index,
|
uint64_t index,
|
||||||
const uint64_t* dims,
|
const uint64_t* dims,
|
||||||
uint64_t num_dims,
|
uint64_t num_dims,
|
||||||
uint64_t* idxs
|
uint32_t* idxs
|
||||||
) {
|
) {
|
||||||
uint64_t stride = 1;
|
uint64_t stride = 1;
|
||||||
for (uint64_t dim = 0; dim < num_dims; dim++) {
|
for (uint64_t dim = 0; dim < num_dims; dim++) {
|
||||||
uint64_t i = num_dims - dim - 1;
|
uint64_t i = num_dims - dim - 1;
|
||||||
__builtin_assume(dims[i] > 0);
|
__builtin_assume(dims[i] > 0);
|
||||||
idxs[i] = (index / stride) % dims[i];
|
idxs[i] = (uint32_t) ((index / stride) % dims[i]);
|
||||||
stride *= dims[i];
|
stride *= dims[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -619,7 +619,8 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>(
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_calc_nd_indices`.
|
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
|
||||||
|
/// containing `i32` indices of the flattened index.
|
||||||
///
|
///
|
||||||
/// * `index` - The index to compute the multidimensional index for.
|
/// * `index` - The index to compute the multidimensional index for.
|
||||||
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||||
|
@ -631,8 +632,9 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ndarray: NDArrayValue<'ctx>,
|
ndarray: NDArrayValue<'ctx>,
|
||||||
) -> PointerValue<'ctx> {
|
) -> PointerValue<'ctx> {
|
||||||
let llvm_void = ctx.ctx.void_type();
|
let llvm_void = ctx.ctx.void_type();
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
|
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
|
||||||
|
@ -646,7 +648,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
llvm_usize.into(),
|
llvm_usize.into(),
|
||||||
llvm_pusize.into(),
|
llvm_pusize.into(),
|
||||||
llvm_usize.into(),
|
llvm_usize.into(),
|
||||||
llvm_pusize.into(),
|
llvm_pi32.into(),
|
||||||
],
|
],
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
|
@ -658,7 +660,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let ndarray_dims = ndarray.dim_sizes();
|
let ndarray_dims = ndarray.dim_sizes();
|
||||||
|
|
||||||
let indices = ctx.builder.build_array_alloca(
|
let indices = ctx.builder.build_array_alloca(
|
||||||
llvm_usize,
|
llvm_i32,
|
||||||
ndarray_num_dims,
|
ndarray_num_dims,
|
||||||
"",
|
"",
|
||||||
).unwrap();
|
).unwrap();
|
||||||
|
|
|
@ -470,6 +470,7 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ncols: IntValue<'ctx>,
|
ncols: IntValue<'ctx>,
|
||||||
offset: IntValue<'ctx>,
|
offset: IntValue<'ctx>,
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
let llvm_usize_2 = llvm_usize.array_type(2);
|
let llvm_usize_2 = llvm_usize.array_type(2);
|
||||||
|
|
||||||
|
@ -512,7 +513,7 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let col_with_offset = ctx.builder
|
let col_with_offset = ctx.builder
|
||||||
.build_int_add(
|
.build_int_add(
|
||||||
col,
|
col,
|
||||||
ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_usize, "").unwrap(),
|
ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(),
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
Loading…
Reference in New Issue