From 6d171ec28443244193674c8f5d71c4dfc0b6c748 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 25 Jul 2024 15:54:39 +0800 Subject: [PATCH] core: Add label name and hooks to gen_for functions --- nac3core/src/codegen/builtin_fns.rs | 1 + nac3core/src/codegen/classes.rs | 1 + nac3core/src/codegen/expr.rs | 2 ++ nac3core/src/codegen/irrt/mod.rs | 1 + nac3core/src/codegen/numpy.rs | 14 +++++++++++--- nac3core/src/codegen/stmt.rs | 29 +++++++++++++++++++++-------- 6 files changed, 37 insertions(+), 11 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 63078107b..abe920566 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -863,6 +863,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( gen_for_callback_incrementing( generator, ctx, + None, llvm_int64.const_int(1, false), (n_sz, false), |generator, ctx, _, idx| { diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index d39b55ca6..52e9cca01 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -1717,6 +1717,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> gen_for_callback_incrementing( generator, ctx, + None, llvm_usize.const_zero(), (len, false), |generator, ctx, _, i| { diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 3c3f898f7..c00e944a8 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1334,6 +1334,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( gen_for_callback_incrementing( generator, ctx, + None, llvm_usize.const_zero(), (int_val, false), |generator, ctx, _, i| { @@ -1944,6 +1945,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( gen_for_callback_incrementing( generator, ctx, + None, llvm_usize.const_zero(), (left_val.load_size(ctx, None), false), |generator, ctx, hooks, i| { diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 755bcf57d..5866f7ee3 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -798,6 +798,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( gen_for_callback_incrementing( generator, ctx, + None, llvm_usize.const_zero(), (min_ndims, false), |generator, ctx, _, idx| { diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 3c307ec1c..92e070599 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -86,6 +86,7 @@ where gen_for_callback_incrementing( generator, ctx, + None, llvm_usize.const_zero(), (shape_len, false), |generator, ctx, _, i| { @@ -131,6 +132,7 @@ where gen_for_callback_incrementing( generator, ctx, + None, llvm_usize.const_zero(), (shape_len, false), |generator, ctx, _, i| { @@ -382,6 +384,7 @@ where gen_for_callback_incrementing( generator, ctx, + None, llvm_usize.const_zero(), (ndarray_num_elems, false), |generator, ctx, _, i| { @@ -703,11 +706,12 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( gen_for_range_callback( generator, ctx, + None, true, |_, _| Ok(llvm_usize.const_zero()), (|_, ctx| Ok(src_lst.load_size(ctx, None)), false), |_, _| Ok(llvm_usize.const_int(1, false)), - |generator, ctx, i| { + |generator, ctx, _, i| { let offset = ctx.builder.build_int_mul(stride, i, "").unwrap(); let dst_ptr = @@ -943,11 +947,12 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( gen_for_range_callback( generator, ctx, + None, true, |_, _| Ok(llvm_usize.const_zero()), (|_, _| Ok(stop), false), |_, _| Ok(llvm_usize.const_int(1, false)), - |generator, ctx, _| { + |generator, ctx, _, _| { let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into()) .ptr_type(AddressSpace::default()); @@ -1130,11 +1135,12 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( gen_for_range_callback( generator, ctx, + None, false, |_, _| Ok(start), (|_, _| Ok(stop), true), |_, _| Ok(step), - |generator, ctx, src_i| { + |generator, ctx, _, src_i| { // Calculate the offset of the active slice let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap(); let dst_i = @@ -1247,6 +1253,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( gen_for_callback_incrementing( generator, ctx, + None, llvm_usize.const_int(slices.len() as u64, false), (this.load_ndims(ctx), false), |generator, ctx, _, idx| { @@ -1651,6 +1658,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( gen_for_callback_incrementing( generator, ctx, + None, llvm_i32.const_zero(), (common_dim, false), |generator, ctx, _, i| { diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index cb013d852..ee5819597 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -494,6 +494,7 @@ pub struct BreakContinueHooks<'ctx> { pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, + label: Option<&str>, init: InitFn, cond: CondFn, body: BodyFn, @@ -508,14 +509,16 @@ where FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, BreakContinueHooks, I) -> Result<(), String>, UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, { + let label = label.unwrap_or("for"); + let current_bb = ctx.builder.get_insert_block().unwrap(); - let init_bb = ctx.ctx.insert_basic_block_after(current_bb, "for.init"); + let init_bb = ctx.ctx.insert_basic_block_after(current_bb, &format!("{label}.init")); // The BB containing the loop condition check - let cond_bb = ctx.ctx.insert_basic_block_after(init_bb, "for.cond"); - let body_bb = ctx.ctx.insert_basic_block_after(cond_bb, "for.body"); + let cond_bb = ctx.ctx.insert_basic_block_after(init_bb, &format!("{label}.cond")); + let body_bb = ctx.ctx.insert_basic_block_after(cond_bb, &format!("{label}.body")); // The BB containing the increment expression - let update_bb = ctx.ctx.insert_basic_block_after(body_bb, "for.update"); - let cont_bb = ctx.ctx.insert_basic_block_after(update_bb, "for.end"); + let update_bb = ctx.ctx.insert_basic_block_after(body_bb, &format!("{label}.update")); + let cont_bb = ctx.ctx.insert_basic_block_after(update_bb, &format!("{label}.end")); // store loop bb information and restore it later let loop_bb = ctx.loop_target.replace((update_bb, cont_bb)); @@ -572,6 +575,7 @@ where pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, + label: Option<&str>, init_val: IntValue<'ctx>, max_val: (IntValue<'ctx>, bool), body: BodyFn, @@ -591,6 +595,7 @@ where gen_for_callback( generator, ctx, + label, |generator, ctx| { let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?; ctx.builder.build_store(i_addr, init_val).unwrap(); @@ -642,9 +647,11 @@ where /// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like /// iterable. This value will be extended to the size of `start`. /// - `body_fn`: A lambda of IR statements within the loop body. +#[allow(clippy::too_many_arguments)] pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, + label: Option<&str>, is_unsigned: bool, start_fn: StartFn, (stop_fn, stop_inclusive): (StopFn, bool), @@ -656,13 +663,19 @@ where StartFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, StopFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, StepFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, String>, - BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>, + BodyFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BreakContinueHooks, + IntValue<'ctx>, + ) -> Result<(), String>, { let init_val_t = start_fn(generator, ctx).map(IntValue::get_type).unwrap(); gen_for_callback( generator, ctx, + label, |generator, ctx| { let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?; @@ -720,10 +733,10 @@ where Ok(cond) }, - |generator, ctx, _, (i_addr, _)| { + |generator, ctx, hooks, (i_addr, _)| { let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); - body_fn(generator, ctx, i) + body_fn(generator, ctx, hooks, i) }, |generator, ctx, (i_addr, _)| { let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();