diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index e2add43..530ce0a 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -552,9 +552,9 @@ pub fn call_j0<'ctx>( /// /// * `num_dims` - An [IntValue] containing the number of dimensions. /// * `dims` - A [PointerValue] to an array containing the size of each dimensions. -pub fn call_ndarray_calc_size<'ctx, 'a>( - generator: &mut dyn CodeGenerator, - ctx: &mut CodeGenContext<'ctx, 'a>, +pub fn call_ndarray_calc_size<'ctx>( + generator: &dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, '_>, num_dims: IntValue<'ctx>, dims: PointerValue<'ctx>, ) -> IntValue<'ctx> { @@ -563,7 +563,7 @@ pub fn call_ndarray_calc_size<'ctx, 'a>( let llvm_pi64 = llvm_i64.ptr_type(AddressSpace::default()); - let ndarray_calc_size_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() { + let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { 32 => "__nac3_ndarray_calc_size", 64 => "__nac3_ndarray_calc_size64", bw => unreachable!("Unsupported size type bit width: {}", bw) @@ -600,9 +600,9 @@ pub fn call_ndarray_calc_size<'ctx, 'a>( /// `NDArray`. /// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM /// representation of a `list`. -pub fn call_ndarray_init_dims<'ctx, 'a>( - generator: &mut dyn CodeGenerator, - ctx: &mut CodeGenContext<'ctx, 'a>, +pub fn call_ndarray_init_dims<'ctx>( + generator: &dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, '_>, ndarray: PointerValue<'ctx>, shape: PointerValue<'ctx>, ) { @@ -616,7 +616,7 @@ pub fn call_ndarray_init_dims<'ctx, 'a>( let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let ndarray_init_dims_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() { + let ndarray_init_dims_fn_name = match llvm_usize.get_bit_width() { 32 => "__nac3_ndarray_init_dims", 64 => "__nac3_ndarray_init_dims64", bw => unreachable!("Unsupported size type bit width: {}", bw) @@ -661,9 +661,14 @@ pub fn call_ndarray_init_dims<'ctx, 'a>( ); } -pub fn call_ndarray_calc_nd_indices<'ctx, 'a>( - generator: &mut dyn CodeGenerator, - ctx: &mut CodeGenContext<'ctx, 'a>, +/// Generates a call to `__nac3_ndarray_calc_nd_indices`. +/// +/// * `index` - The index to compute the multidimensional index for. +/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an +/// `NDArray`. +pub fn call_ndarray_calc_nd_indices<'ctx>( + generator: &dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, '_>, index: IntValue<'ctx>, ndarray: PointerValue<'ctx>, ) -> Result, String> { @@ -675,12 +680,12 @@ pub fn call_ndarray_calc_nd_indices<'ctx, 'a>( let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - let ndarray_calc_nd_indices_dn_name = match generator.get_size_type(ctx.ctx).get_bit_width() { + let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { 32 => "__nac3_ndarray_calc_nd_indices", 64 => "__nac3_ndarray_calc_nd_indices64", bw => unreachable!("Unsupported size type bit width: {}", bw) }; - let ndarray_calc_nd_indices_fn = ctx.module.get_function(ndarray_calc_nd_indices_dn_name).unwrap_or_else(|| { + let ndarray_calc_nd_indices_fn = ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { let fn_type = llvm_void.fn_type( &[ llvm_usize.into(), @@ -691,7 +696,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, 'a>( false, ); - ctx.module.add_function(ndarray_calc_nd_indices_dn_name, fn_type, None) + ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) }); let ndarray_num_dims = ctx.build_gep_and_load(