From 9587e633472142438266899f06956d056c5dd8fd Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 20 Jun 2024 14:08:10 +0800 Subject: [PATCH] core: add BreakContinueHooks for gen_for_callback --- nac3core/src/codegen/builtin_fns.rs | 4 ++-- nac3core/src/codegen/classes.rs | 2 +- nac3core/src/codegen/irrt/mod.rs | 2 +- nac3core/src/codegen/numpy.rs | 10 +++++----- nac3core/src/codegen/stmt.rs | 27 +++++++++++++++++++++------ 5 files changed, 30 insertions(+), 15 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 5efbe4a5..272a61da 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -725,7 +725,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>( ctx, llvm_usize.const_int(1, false), (n_sz, false), - |generator, ctx, idx| { + |generator, ctx, _, idx| { let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); @@ -941,7 +941,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( ctx, llvm_usize.const_int(1, false), (n_sz, false), - |generator, ctx, idx| { + |generator, ctx, _, idx| { let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 24c53feb..18bf1bb8 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -1706,7 +1706,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> ctx, llvm_usize.const_zero(), (len, false), - |generator, ctx, i| { + |generator, ctx, _, i| { let (dim_idx, dim_sz) = unsafe { ( indices.get_unchecked(ctx, generator, &i, None).into_int_value(), diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 8677f085..3ec852aa 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -802,7 +802,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( ctx, llvm_usize.const_zero(), (min_ndims, false), - |generator, ctx, idx| { + |generator, ctx, _, idx| { let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); let (lhs_dim_sz, rhs_dim_sz) = unsafe { ( diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 7f19f4ed..b082d1ec 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -86,7 +86,7 @@ where ctx, llvm_usize.const_zero(), (shape_len, false), - |generator, ctx, i| { + |generator, ctx, _, i| { let shape_dim = shape_data_fn(generator, ctx, shape, i)?; debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); @@ -131,7 +131,7 @@ where ctx, llvm_usize.const_zero(), (shape_len, false), - |generator, ctx, i| { + |generator, ctx, _, i| { let shape_dim = shape_data_fn(generator, ctx, shape, i)?; debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); @@ -334,7 +334,7 @@ where ctx, llvm_usize.const_zero(), (ndarray_num_elems, false), - |generator, ctx, i| { + |generator, ctx, _, i| { let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) }; let value = value_fn(generator, ctx, i)?; @@ -1193,7 +1193,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( ctx, llvm_usize.const_int(slices.len() as u64, false), (this.load_ndims(ctx), false), - |generator, ctx, idx| { + |generator, ctx, _, idx| { unsafe { let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None); ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz); @@ -1597,7 +1597,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( ctx, llvm_i32.const_zero(), (common_dim, false), - |generator, ctx, i| { + |generator, ctx, _, i| { let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap(); let ab_idx = generator.gen_array_var_alloc( diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index a670f117..8ffc3621 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -462,6 +462,14 @@ pub fn gen_for( Ok(()) } +#[derive(Debug, Clone, Copy)] +pub struct BreakContinueHooks<'ctx> { + /// [`BasicBlock`] to branch to for `break`-ing out of the loop. + pub break_bb: BasicBlock<'ctx>, + /// [`BasicBlock`] to branch to for `continue`-ing to the next iteration in the loop. + pub continue_bb: BasicBlock<'ctx>, +} + /// Generates a C-style `for` construct using lambdas, similar to the following C code: /// /// ```c @@ -489,7 +497,8 @@ where I: Clone, InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result, CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result, String>, - BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, + BodyFn: + FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, BreakContinueHooks, I) -> Result<(), String>, UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, { let current_bb = ctx.builder.get_insert_block().unwrap(); @@ -520,7 +529,8 @@ where } ctx.builder.position_at_end(body_bb); - body(generator, ctx, loop_var.clone())?; + let hooks = BreakContinueHooks { break_bb: cont_bb, continue_bb: update_bb }; + body(generator, ctx, hooks, loop_var.clone())?; if !ctx.is_terminated() { ctx.builder.build_unconditional_branch(update_bb).unwrap(); } @@ -562,7 +572,12 @@ pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>( ) -> Result<(), String> where G: CodeGenerator + ?Sized, - BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>, + BodyFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BreakContinueHooks, + IntValue<'ctx>, + ) -> Result<(), String>, { let init_val_t = init_val.get_type(); @@ -584,10 +599,10 @@ where Ok(ctx.builder.build_int_compare(cmp_op, i, max_val, "").unwrap()) }, - |generator, ctx, i_addr| { + |generator, ctx, hooks, i_addr| { let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); - body(generator, ctx, i) + body(generator, ctx, hooks, i) }, |_, ctx, i_addr| { let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); @@ -698,7 +713,7 @@ where Ok(cond) }, - |generator, ctx, (i_addr, _)| { + |generator, ctx, _, (i_addr, _)| { let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); body_fn(generator, ctx, i)