Fix IR generation of for loop containing break/continue #345

Merged
sb10q merged 1 commits from fix/loop-break-continue into master 2023-11-01 13:21:28 +08:00
2 changed files with 63 additions and 38 deletions

View File

@ -253,21 +253,22 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
let body_bb = ctx.ctx.append_basic_block(current, "for.body"); let body_bb = ctx.ctx.append_basic_block(current, "for.body");
let cont_bb = ctx.ctx.append_basic_block(current, "for.end"); let cont_bb = ctx.ctx.append_basic_block(current, "for.end");
// if there is no orelse, we just go to cont_bb // if there is no orelse, we just go to cont_bb
let orelse_bb = let orelse_bb = if orelse.is_empty() {
if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") }; cont_bb
} else {
ctx.ctx.append_basic_block(current, "for.orelse")
};
// Whether the iterable is a range() expression // Whether the iterable is a range() expression
let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range); let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range);
// The target BB of the loop backedge // The BB containing the increment expression
let backedge_bb_target = if is_iterable_range_expr { let incr_bb = ctx.ctx.append_basic_block(current, "for.incr");
body_bb // The BB containing the loop condition check
} else { let cond_bb = ctx.ctx.append_basic_block(current, "for.cond");
ctx.ctx.append_basic_block(current, "for.cond")
};
// store loop bb information and restore it later // store loop bb information and restore it later
let loop_bb = ctx.loop_target.replace((backedge_bb_target, cont_bb)); let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb));
let iter_val = generator.gen_expr(ctx, iter)?.unwrap().to_basic_value_enum( let iter_val = generator.gen_expr(ctx, iter)?.unwrap().to_basic_value_enum(
ctx, ctx,
@ -294,35 +295,35 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
[None, None, None], [None, None, None],
ctx.current_loc ctx.current_loc
); );
ctx.builder.build_unconditional_branch(cond_bb);
ctx.builder.build_conditional_branch( {
gen_in_range_check(ctx, start, stop, step), ctx.builder.position_at_end(cond_bb);
body_bb, ctx.builder.build_conditional_branch(
orelse_bb, gen_in_range_check(
ctx,
ctx.builder.build_load(i, "").into_int_value(),
stop,
step
),
body_bb,
orelse_bb,
);
}
ctx.builder.position_at_end(incr_bb);
let next_i = ctx.builder.build_int_add(
ctx.builder.build_load(i, "").into_int_value(),
step,
"inc",
); );
ctx.builder.build_store(i, next_i);
ctx.builder.build_unconditional_branch(cond_bb);
ctx.builder.position_at_end(body_bb); ctx.builder.position_at_end(body_bb);
ctx.builder.build_store(target_i, ctx.builder.build_load(i, "").into_int_value()); ctx.builder.build_store(target_i, ctx.builder.build_load(i, "").into_int_value());
gen_block(generator, ctx, body.iter())?; gen_block(generator, ctx, body.iter())?;
// Test if next element is still in range
let next_i = ctx.builder.build_int_add(
ctx.builder.build_load(i, "").into_int_value(),
step,
"next_i",
);
let cond_cont_bb = ctx.ctx.append_basic_block(current, "for.cond.cont");
ctx.builder.build_conditional_branch(
gen_in_range_check(ctx, next_i, stop, step),
cond_cont_bb,
orelse_bb,
);
ctx.builder.position_at_end(cond_cont_bb);
ctx.builder.build_store(i, next_i);
} else { } else {
let test_bb = backedge_bb_target;
let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?; let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?;
ctx.builder.build_store(index_addr, size_t.const_zero()); ctx.builder.build_store(index_addr, size_t.const_zero());
let len = ctx let len = ctx
@ -332,24 +333,27 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
Some("len") Some("len")
) )
.into_int_value(); .into_int_value();
ctx.builder.build_unconditional_branch(test_bb); ctx.builder.build_unconditional_branch(cond_bb);
ctx.builder.position_at_end(test_bb); ctx.builder.position_at_end(cond_bb);
let index = ctx.builder.build_load(index_addr, "for.index").into_int_value(); let index = ctx.builder.build_load(index_addr, "for.index").into_int_value();
let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, index, len, "cond"); let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, index, len, "cond");
ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb); ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb);
ctx.builder.position_at_end(incr_bb);
let index = ctx.builder.build_load(index_addr, "").into_int_value();
let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "inc");
ctx.builder.build_store(index_addr, inc);
ctx.builder.build_unconditional_branch(cond_bb);
ctx.builder.position_at_end(body_bb); ctx.builder.position_at_end(body_bb);
let arr_ptr = ctx let arr_ptr = ctx
.build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr")) .build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr"))
.into_pointer_value(); .into_pointer_value();
let index = ctx.builder.build_load(index_addr, "for.index").into_int_value();
let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val")); let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val"));
generator.gen_assign(ctx, target, val.into())?; generator.gen_assign(ctx, target, val.into())?;
gen_block(generator, ctx, body.iter())?; gen_block(generator, ctx, body.iter())?;
let index = ctx.builder.build_load(index_addr, "for.index").into_int_value();
let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "");
ctx.builder.build_store(index_addr, inc);
} }
for (k, (_, _, counter)) in var_assignment.iter() { for (k, (_, _, counter)) in var_assignment.iter() {
@ -360,7 +364,7 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
} }
if !ctx.is_terminated() { if !ctx.is_terminated() {
ctx.builder.build_unconditional_branch(backedge_bb_target); ctx.builder.build_unconditional_branch(incr_bb);
} }
if !orelse.is_empty() { if !orelse.is_empty() {

View File

@ -0,0 +1,21 @@
@extern
def output_int32(x: int32):
...
def run() -> int32:
for i in range(4):
output_int32(i)
if i < 2:
continue
else:
break
n = [0, 1, 2, 3]
for i in n:
output_int32(i)
if i < 2:
continue
else:
break
return 0