1
0
forked from M-Labs/nac3

core: Add BreakContinueHooks for gen_for_callback

This commit is contained in:
David Mak 2024-07-02 19:05:00 +08:00
parent 13beeaa2bf
commit 075536d7bd
5 changed files with 32 additions and 15 deletions

View File

@ -725,7 +725,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
llvm_usize.const_int(1, false), llvm_usize.const_int(1, false),
(n_sz, false), (n_sz, false),
|generator, ctx, idx| { |generator, ctx, _, idx| {
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
@ -941,7 +941,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
llvm_usize.const_int(1, false), llvm_usize.const_int(1, false),
(n_sz, false), (n_sz, false),
|generator, ctx, idx| { |generator, ctx, _, idx| {
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();

View File

@ -1719,7 +1719,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
ctx, ctx,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(len, false), (len, false),
|generator, ctx, i| { |generator, ctx, _, i| {
let (dim_idx, dim_sz) = unsafe { let (dim_idx, dim_sz) = unsafe {
( (
indices.get_unchecked(ctx, generator, &i, None).into_int_value(), indices.get_unchecked(ctx, generator, &i, None).into_int_value(),

View File

@ -800,7 +800,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(min_ndims, false), (min_ndims, false),
|generator, ctx, idx| { |generator, ctx, _, idx| {
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe { let (lhs_dim_sz, rhs_dim_sz) = unsafe {
( (

View File

@ -86,7 +86,7 @@ where
ctx, ctx,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(shape_len, false), (shape_len, false),
|generator, ctx, i| { |generator, ctx, _, i| {
let shape_dim = shape_data_fn(generator, ctx, shape, i)?; let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
@ -131,7 +131,7 @@ where
ctx, ctx,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(shape_len, false), (shape_len, false),
|generator, ctx, i| { |generator, ctx, _, i| {
let shape_dim = shape_data_fn(generator, ctx, shape, i)?; let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
@ -382,7 +382,7 @@ where
ctx, ctx,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(ndarray_num_elems, false), (ndarray_num_elems, false),
|generator, ctx, i| { |generator, ctx, _, i| {
let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) }; let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) };
let value = value_fn(generator, ctx, i)?; let value = value_fn(generator, ctx, i)?;
@ -1243,7 +1243,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
llvm_usize.const_int(slices.len() as u64, false), llvm_usize.const_int(slices.len() as u64, false),
(this.load_ndims(ctx), false), (this.load_ndims(ctx), false),
|generator, ctx, idx| { |generator, ctx, _, idx| {
unsafe { unsafe {
let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None); let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None);
ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz); ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz);
@ -1647,7 +1647,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
ctx, ctx,
llvm_i32.const_zero(), llvm_i32.const_zero(),
(common_dim, false), (common_dim, false),
|generator, ctx, i| { |generator, ctx, _, i| {
let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap(); let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap();
let ab_idx = generator.gen_array_var_alloc( let ab_idx = generator.gen_array_var_alloc(

View File

@ -464,6 +464,16 @@ pub fn gen_for<G: CodeGenerator>(
Ok(()) Ok(())
} }
#[derive(PartialEq, Eq, Debug, Clone, Copy, Hash)]
pub struct BreakContinueHooks<'ctx> {
/// The [exit block][`BasicBlock`] to branch to when `break`-ing out of a loop.
pub exit_bb: BasicBlock<'ctx>,
/// The [latch basic block][`BasicBlock`] to branch to for `continue`-ing to the next iteration
/// of the loop.
pub latch_bb: BasicBlock<'ctx>,
}
/// Generates a C-style `for` construct using lambdas, similar to the following C code: /// Generates a C-style `for` construct using lambdas, similar to the following C code:
/// ///
/// ```c /// ```c
@ -491,7 +501,8 @@ where
I: Clone, I: Clone,
InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>, InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>, CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, BodyFn:
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, BreakContinueHooks, I) -> Result<(), String>,
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
{ {
let current_bb = ctx.builder.get_insert_block().unwrap(); let current_bb = ctx.builder.get_insert_block().unwrap();
@ -522,7 +533,8 @@ where
} }
ctx.builder.position_at_end(body_bb); ctx.builder.position_at_end(body_bb);
body(generator, ctx, loop_var.clone())?; let hooks = BreakContinueHooks { exit_bb: cont_bb, latch_bb: update_bb };
body(generator, ctx, hooks, loop_var.clone())?;
if !ctx.is_terminated() { if !ctx.is_terminated() {
ctx.builder.build_unconditional_branch(update_bb).unwrap(); ctx.builder.build_unconditional_branch(update_bb).unwrap();
} }
@ -564,7 +576,12 @@ pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>(
) -> Result<(), String> ) -> Result<(), String>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>, BodyFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks,
IntValue<'ctx>,
) -> Result<(), String>,
{ {
let init_val_t = init_val.get_type(); let init_val_t = init_val.get_type();
@ -586,10 +603,10 @@ where
Ok(ctx.builder.build_int_compare(cmp_op, i, max_val, "").unwrap()) Ok(ctx.builder.build_int_compare(cmp_op, i, max_val, "").unwrap())
}, },
|generator, ctx, i_addr| { |generator, ctx, hooks, i_addr| {
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
body(generator, ctx, i) body(generator, ctx, hooks, i)
}, },
|_, ctx, i_addr| { |_, ctx, i_addr| {
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
@ -700,7 +717,7 @@ where
Ok(cond) Ok(cond)
}, },
|generator, ctx, (i_addr, _)| { |generator, ctx, _, (i_addr, _)| {
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
body_fn(generator, ctx, i) body_fn(generator, ctx, i)