From 075536d7bddc2ac3c6879b2056e10e90f32a68b9 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 2 Jul 2024 19:05:00 +0800 Subject: [PATCH] core: Add BreakContinueHooks for gen_for_callback --- nac3core/src/codegen/builtin_fns.rs | 4 ++-- nac3core/src/codegen/classes.rs | 2 +- nac3core/src/codegen/irrt/mod.rs | 2 +- nac3core/src/codegen/numpy.rs | 10 +++++----- nac3core/src/codegen/stmt.rs | 29 +++++++++++++++++++++++------ 5 files changed, 32 insertions(+), 15 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 5efbe4a5..272a61da 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -725,7 +725,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>( ctx, llvm_usize.const_int(1, false), (n_sz, false), - |generator, ctx, idx| { + |generator, ctx, _, idx| { let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); @@ -941,7 +941,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( ctx, llvm_usize.const_int(1, false), (n_sz, false), - |generator, ctx, idx| { + |generator, ctx, _, idx| { let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index cdb1841e..d39b55ca 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -1719,7 +1719,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> ctx, llvm_usize.const_zero(), (len, false), - |generator, ctx, i| { + |generator, ctx, _, i| { let (dim_idx, dim_sz) = unsafe { ( indices.get_unchecked(ctx, generator, &i, None).into_int_value(), diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 8bab9854..755bcf57 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -800,7 +800,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( ctx, llvm_usize.const_zero(), (min_ndims, false), - |generator, ctx, idx| { + |generator, ctx, _, idx| { let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); let (lhs_dim_sz, rhs_dim_sz) = unsafe { ( diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 4db8a81f..fee016af 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -86,7 +86,7 @@ where ctx, llvm_usize.const_zero(), (shape_len, false), - |generator, ctx, i| { + |generator, ctx, _, i| { let shape_dim = shape_data_fn(generator, ctx, shape, i)?; debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); @@ -131,7 +131,7 @@ where ctx, llvm_usize.const_zero(), (shape_len, false), - |generator, ctx, i| { + |generator, ctx, _, i| { let shape_dim = shape_data_fn(generator, ctx, shape, i)?; debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); @@ -382,7 +382,7 @@ where ctx, llvm_usize.const_zero(), (ndarray_num_elems, false), - |generator, ctx, i| { + |generator, ctx, _, i| { let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) }; let value = value_fn(generator, ctx, i)?; @@ -1243,7 +1243,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( ctx, llvm_usize.const_int(slices.len() as u64, false), (this.load_ndims(ctx), false), - |generator, ctx, idx| { + |generator, ctx, _, idx| { unsafe { let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None); ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz); @@ -1647,7 +1647,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( ctx, llvm_i32.const_zero(), (common_dim, false), - |generator, ctx, i| { + |generator, ctx, _, i| { let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap(); let ab_idx = generator.gen_array_var_alloc( diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 94266ca2..23ba4352 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -464,6 +464,16 @@ pub fn gen_for( Ok(()) } +#[derive(PartialEq, Eq, Debug, Clone, Copy, Hash)] +pub struct BreakContinueHooks<'ctx> { + /// The [exit block][`BasicBlock`] to branch to when `break`-ing out of a loop. + pub exit_bb: BasicBlock<'ctx>, + + /// The [latch basic block][`BasicBlock`] to branch to for `continue`-ing to the next iteration + /// of the loop. + pub latch_bb: BasicBlock<'ctx>, +} + /// Generates a C-style `for` construct using lambdas, similar to the following C code: /// /// ```c @@ -491,7 +501,8 @@ where I: Clone, InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result, String>, - BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, + BodyFn: + FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, BreakContinueHooks, I) -> Result<(), String>, UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, { let current_bb = ctx.builder.get_insert_block().unwrap(); @@ -522,7 +533,8 @@ where } ctx.builder.position_at_end(body_bb); - body(generator, ctx, loop_var.clone())?; + let hooks = BreakContinueHooks { exit_bb: cont_bb, latch_bb: update_bb }; + body(generator, ctx, hooks, loop_var.clone())?; if !ctx.is_terminated() { ctx.builder.build_unconditional_branch(update_bb).unwrap(); } @@ -564,7 +576,12 @@ pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>( ) -> Result<(), String> where G: CodeGenerator + ?Sized, - BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>, + BodyFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BreakContinueHooks, + IntValue<'ctx>, + ) -> Result<(), String>, { let init_val_t = init_val.get_type(); @@ -586,10 +603,10 @@ where Ok(ctx.builder.build_int_compare(cmp_op, i, max_val, "").unwrap()) }, - |generator, ctx, i_addr| { + |generator, ctx, hooks, i_addr| { let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); - body(generator, ctx, i) + body(generator, ctx, hooks, i) }, |_, ctx, i_addr| { let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); @@ -700,7 +717,7 @@ where Ok(cond) }, - |generator, ctx, (i_addr, _)| { + |generator, ctx, _, (i_addr, _)| { let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); body_fn(generator, ctx, i)