forked from M-Labs/nac3
core: add BreakContinueHooks for gen_for_callback
This commit is contained in:
parent
144a3fc426
commit
42c5f906fb
@ -725,7 +725,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx,
|
ctx,
|
||||||
llvm_usize.const_int(1, false),
|
llvm_usize.const_int(1, false),
|
||||||
(n_sz, false),
|
(n_sz, false),
|
||||||
|generator, ctx, idx| {
|
|generator, ctx, _, idx| {
|
||||||
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
|
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
|
||||||
|
|
||||||
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
||||||
@ -941,7 +941,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx,
|
ctx,
|
||||||
llvm_usize.const_int(1, false),
|
llvm_usize.const_int(1, false),
|
||||||
(n_sz, false),
|
(n_sz, false),
|
||||||
|generator, ctx, idx| {
|
|generator, ctx, _, idx| {
|
||||||
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
|
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
|
||||||
|
|
||||||
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
||||||
|
@ -1706,7 +1706,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
|
|||||||
ctx,
|
ctx,
|
||||||
llvm_usize.const_zero(),
|
llvm_usize.const_zero(),
|
||||||
(len, false),
|
(len, false),
|
||||||
|generator, ctx, i| {
|
|generator, ctx, _, i| {
|
||||||
let (dim_idx, dim_sz) = unsafe {
|
let (dim_idx, dim_sz) = unsafe {
|
||||||
(
|
(
|
||||||
indices.get_unchecked(ctx, generator, &i, None).into_int_value(),
|
indices.get_unchecked(ctx, generator, &i, None).into_int_value(),
|
||||||
|
@ -802,7 +802,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx,
|
ctx,
|
||||||
llvm_usize.const_zero(),
|
llvm_usize.const_zero(),
|
||||||
(min_ndims, false),
|
(min_ndims, false),
|
||||||
|generator, ctx, idx| {
|
|generator, ctx, _, idx| {
|
||||||
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
|
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
|
||||||
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
|
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
|
||||||
(
|
(
|
||||||
|
@ -86,7 +86,7 @@ where
|
|||||||
ctx,
|
ctx,
|
||||||
llvm_usize.const_zero(),
|
llvm_usize.const_zero(),
|
||||||
(shape_len, false),
|
(shape_len, false),
|
||||||
|generator, ctx, i| {
|
|generator, ctx, _, i| {
|
||||||
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
||||||
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||||
|
|
||||||
@ -131,7 +131,7 @@ where
|
|||||||
ctx,
|
ctx,
|
||||||
llvm_usize.const_zero(),
|
llvm_usize.const_zero(),
|
||||||
(shape_len, false),
|
(shape_len, false),
|
||||||
|generator, ctx, i| {
|
|generator, ctx, _, i| {
|
||||||
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
|
||||||
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
|
||||||
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
|
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
|
||||||
@ -334,7 +334,7 @@ where
|
|||||||
ctx,
|
ctx,
|
||||||
llvm_usize.const_zero(),
|
llvm_usize.const_zero(),
|
||||||
(ndarray_num_elems, false),
|
(ndarray_num_elems, false),
|
||||||
|generator, ctx, i| {
|
|generator, ctx, _, i| {
|
||||||
let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) };
|
let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) };
|
||||||
|
|
||||||
let value = value_fn(generator, ctx, i)?;
|
let value = value_fn(generator, ctx, i)?;
|
||||||
@ -1193,7 +1193,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx,
|
ctx,
|
||||||
llvm_usize.const_int(slices.len() as u64, false),
|
llvm_usize.const_int(slices.len() as u64, false),
|
||||||
(this.load_ndims(ctx), false),
|
(this.load_ndims(ctx), false),
|
||||||
|generator, ctx, idx| {
|
|generator, ctx, _, idx| {
|
||||||
unsafe {
|
unsafe {
|
||||||
let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None);
|
let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None);
|
||||||
ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz);
|
ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz);
|
||||||
@ -1597,7 +1597,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
|||||||
ctx,
|
ctx,
|
||||||
llvm_i32.const_zero(),
|
llvm_i32.const_zero(),
|
||||||
(common_dim, false),
|
(common_dim, false),
|
||||||
|generator, ctx, i| {
|
|generator, ctx, _, i| {
|
||||||
let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap();
|
let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap();
|
||||||
|
|
||||||
let ab_idx = generator.gen_array_var_alloc(
|
let ab_idx = generator.gen_array_var_alloc(
|
||||||
|
@ -462,6 +462,14 @@ pub fn gen_for<G: CodeGenerator>(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct BreakContinueHooks<'ctx> {
|
||||||
|
/// [`BasicBlock`] to branch to for `break`-ing out of the loop.
|
||||||
|
pub break_bb: BasicBlock<'ctx>,
|
||||||
|
/// [`BasicBlock`] to branch to for `continue`-ing to the next iteration in the loop.
|
||||||
|
pub continue_bb: BasicBlock<'ctx>,
|
||||||
|
}
|
||||||
|
|
||||||
/// Generates a C-style `for` construct using lambdas, similar to the following C code:
|
/// Generates a C-style `for` construct using lambdas, similar to the following C code:
|
||||||
///
|
///
|
||||||
/// ```c
|
/// ```c
|
||||||
@ -489,7 +497,8 @@ where
|
|||||||
I: Clone,
|
I: Clone,
|
||||||
InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
|
InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
|
||||||
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
|
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
|
||||||
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
BodyFn:
|
||||||
|
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, BreakContinueHooks, I) -> Result<(), String>,
|
||||||
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
||||||
{
|
{
|
||||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||||
@ -520,7 +529,8 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx.builder.position_at_end(body_bb);
|
ctx.builder.position_at_end(body_bb);
|
||||||
body(generator, ctx, loop_var.clone())?;
|
let hooks = BreakContinueHooks { break_bb: cont_bb, continue_bb: update_bb };
|
||||||
|
body(generator, ctx, hooks, loop_var.clone())?;
|
||||||
if !ctx.is_terminated() {
|
if !ctx.is_terminated() {
|
||||||
ctx.builder.build_unconditional_branch(update_bb).unwrap();
|
ctx.builder.build_unconditional_branch(update_bb).unwrap();
|
||||||
}
|
}
|
||||||
@ -562,7 +572,12 @@ pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>(
|
|||||||
) -> Result<(), String>
|
) -> Result<(), String>
|
||||||
where
|
where
|
||||||
G: CodeGenerator + ?Sized,
|
G: CodeGenerator + ?Sized,
|
||||||
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>,
|
BodyFn: FnOnce(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
|
BreakContinueHooks,
|
||||||
|
IntValue<'ctx>,
|
||||||
|
) -> Result<(), String>,
|
||||||
{
|
{
|
||||||
let init_val_t = init_val.get_type();
|
let init_val_t = init_val.get_type();
|
||||||
|
|
||||||
@ -584,10 +599,10 @@ where
|
|||||||
|
|
||||||
Ok(ctx.builder.build_int_compare(cmp_op, i, max_val, "").unwrap())
|
Ok(ctx.builder.build_int_compare(cmp_op, i, max_val, "").unwrap())
|
||||||
},
|
},
|
||||||
|generator, ctx, i_addr| {
|
|generator, ctx, hooks, i_addr| {
|
||||||
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
|
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
|
||||||
|
|
||||||
body(generator, ctx, i)
|
body(generator, ctx, hooks, i)
|
||||||
},
|
},
|
||||||
|_, ctx, i_addr| {
|
|_, ctx, i_addr| {
|
||||||
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
|
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
|
||||||
@ -698,7 +713,7 @@ where
|
|||||||
|
|
||||||
Ok(cond)
|
Ok(cond)
|
||||||
},
|
},
|
||||||
|generator, ctx, (i_addr, _)| {
|
|generator, ctx, _, (i_addr, _)| {
|
||||||
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
|
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
|
||||||
|
|
||||||
body_fn(generator, ctx, i)
|
body_fn(generator, ctx, i)
|
||||||
|
Loading…
Reference in New Issue
Block a user