1
0
forked from M-Labs/nac3

core: Preserve value of variable shadowed by for loop

Previously, the final value of the target expression would be one after
the last element of the loop, which does not match Python's behavior.

This commit fixes this problem while also preserving the last assigned
value of the loop beyond the loop, matching Python's behavior.
This commit is contained in:
David Mak 2023-09-05 12:10:52 +08:00
parent 6805253515
commit e0de82993f
4 changed files with 115 additions and 38 deletions

View File

@ -77,7 +77,8 @@ pub struct CodeGenContext<'ctx, 'a> {
pub const_strings: HashMap<String, BasicValueEnum<'ctx>>, pub const_strings: HashMap<String, BasicValueEnum<'ctx>>,
// stores the alloca for variables // stores the alloca for variables
pub init_bb: BasicBlock<'ctx>, pub init_bb: BasicBlock<'ctx>,
// the first one is the test_bb, and the second one is bb after the loop /// The header and exit basic blocks of a loop in this context. See
/// https://llvm.org/docs/LoopTerminology.html for explanation of these terminology.
pub loop_target: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>, pub loop_target: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>,
// unwind target bb // unwind target bb
pub unwind_target: Option<BasicBlock<'ctx>>, pub unwind_target: Option<BasicBlock<'ctx>>,

View File

@ -13,8 +13,8 @@ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock, basic_block::BasicBlock,
types::BasicTypeEnum, types::BasicTypeEnum,
values::{BasicValue, BasicValueEnum, FunctionValue, PointerValue}, values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
IntPredicate::EQ, IntPredicate,
}; };
use nac3parser::ast::{ use nac3parser::ast::{
Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef, Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef,
@ -107,7 +107,7 @@ pub fn gen_store_target<'ctx, 'a, G: CodeGenerator>(
); );
// handle negative index // handle negative index
let is_negative = ctx.builder.build_int_compare( let is_negative = ctx.builder.build_int_compare(
inkwell::IntPredicate::SLT, IntPredicate::SLT,
raw_index, raw_index,
generator.get_size_type(ctx.ctx).const_zero(), generator.get_size_type(ctx.ctx).const_zero(),
"is_neg", "is_neg",
@ -120,7 +120,7 @@ pub fn gen_store_target<'ctx, 'a, G: CodeGenerator>(
// unsigned less than is enough, because negative index after adjustment is // unsigned less than is enough, because negative index after adjustment is
// bigger than the length (for unsigned cmp) // bigger than the length (for unsigned cmp)
let bound_check = ctx.builder.build_int_compare( let bound_check = ctx.builder.build_int_compare(
inkwell::IntPredicate::ULT, IntPredicate::ULT,
index, index,
len, len,
"inbound", "inbound",
@ -214,6 +214,26 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>(
Ok(()) Ok(())
} }
/// Generates a sequence of IR which checks whether [value] does not exceed the upper bound of the
/// range as defined by [stop] and [step].
///
/// Note that the generated IR will **not** check whether value is part of the range or whether
/// value exceeds the lower bound of the range (as evident by the missing `start` argument).
///
/// Returns an [IntValue] representing the result of whether the [value] is in the range.
fn gen_in_range_check<'ctx, 'a>(
ctx: &CodeGenContext<'ctx, 'a>,
value: IntValue<'ctx>,
stop: IntValue<'ctx>,
step: IntValue<'ctx>,
) -> IntValue<'ctx> {
let sign = ctx.builder.build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), "");
let lo = ctx.builder.build_select(sign, value, stop, "").into_int_value();
let hi = ctx.builder.build_select(sign, stop, value, "").into_int_value();
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp")
}
pub fn gen_for<'ctx, 'a, G: CodeGenerator>( pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
@ -234,49 +254,62 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
// if there is no orelse, we just go to cont_bb // if there is no orelse, we just go to cont_bb
let orelse_bb = let orelse_bb =
if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") }; if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") };
// Whether the iterable is a range() expression
let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range);
// store loop bb information and restore it later // store loop bb information and restore it later
let loop_bb = ctx.loop_target.replace((test_bb, cont_bb)); let loop_bb = if is_iterable_range_expr {
ctx.loop_target.replace((body_bb, cont_bb))
} else {
ctx.loop_target.replace((test_bb, cont_bb))
};
let iter_val = generator.gen_expr(ctx, iter)?.unwrap().to_basic_value_enum( let iter_val = generator.gen_expr(ctx, iter)?.unwrap().to_basic_value_enum(
ctx, ctx,
generator, generator,
iter.custom.unwrap(), iter.custom.unwrap(),
)?; )?;
if ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range) { if is_iterable_range_expr {
// setup
let iter_val = iter_val.into_pointer_value(); let iter_val = iter_val.into_pointer_value();
// Internal variable for loop; Cannot be assigned
let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
let user_i = generator.gen_store_target(ctx, target, Some("for.user_i.addr"))?; // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
let (start, end, step) = destructure_range(ctx, iter_val); let target_i = generator.gen_store_target(ctx, target, Some("for.target.addr"))?;
ctx.builder.build_store(i, ctx.builder.build_int_sub(start, step, "start_init")); let (start, stop, step) = destructure_range(ctx, iter_val);
ctx.builder.build_unconditional_branch(test_bb);
ctx.builder.position_at_end(test_bb); ctx.builder.build_store(i, start);
let sign = ctx.builder.build_int_compare(
inkwell::IntPredicate::SGT, // Pre-Loop Checks:
step, // - step == 0 -> ValueError
int32.const_zero(), // - start < stop for step > 0 || start > stop for step < 0
"sign", // TODO: Generate step == 0 -> raise ValueError
);
// add and test
let tmp = ctx.builder.build_int_add(
ctx.builder.build_load(i, "i").into_int_value(),
step,
"start_loop",
);
ctx.builder.build_store(i, tmp);
ctx.builder.build_store(user_i, tmp);
// // if step > 0, continue when i < end
let cmp1 = ctx.builder.build_int_compare(inkwell::IntPredicate::SLT, tmp, end, "cmp1");
// if step < 0, continue when i > end
let cmp2 = ctx.builder.build_int_compare(inkwell::IntPredicate::SGT, tmp, end, "cmp2");
let pos = ctx.builder.build_and(sign, cmp1, "pos");
let neg = ctx.builder.build_and(ctx.builder.build_not(sign, "inv"), cmp2, "neg");
ctx.builder.build_conditional_branch( ctx.builder.build_conditional_branch(
ctx.builder.build_or(pos, neg, "or"), gen_in_range_check(ctx, start, stop, step),
body_bb, body_bb,
orelse_bb, orelse_bb,
); );
ctx.builder.position_at_end(body_bb); ctx.builder.position_at_end(body_bb);
ctx.builder.build_store(target_i, ctx.builder.build_load(i, "").into_int_value());
gen_block(generator, ctx, body.iter())?;
// Test if next element is still in range
let next_i = ctx.builder.build_int_add(
ctx.builder.build_load(i, "").into_int_value(),
step,
"next_i",
);
let cond_cont_bb = ctx.ctx.append_basic_block(current, "for.cond.cont");
ctx.builder.build_conditional_branch(
gen_in_range_check(ctx, next_i, stop, step),
cond_cont_bb,
orelse_bb,
);
ctx.builder.position_at_end(cond_cont_bb);
ctx.builder.build_store(i, next_i);
} else { } else {
let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("for.counter.addr"))?; let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("for.counter.addr"))?;
// counter = -1 // counter = -1
@ -288,30 +321,39 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
) )
.into_int_value(); .into_int_value();
ctx.builder.build_unconditional_branch(test_bb); ctx.builder.build_unconditional_branch(test_bb);
ctx.builder.position_at_end(test_bb); ctx.builder.position_at_end(test_bb);
let tmp = ctx.builder.build_load(counter, "i").into_int_value(); let tmp = ctx.builder.build_load(counter, "i").into_int_value();
let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc"); let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc");
ctx.builder.build_store(counter, tmp); ctx.builder.build_store(counter, tmp);
let cmp = ctx.builder.build_int_compare(inkwell::IntPredicate::SLT, tmp, len, "cmp"); let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, len, "cmp");
ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb); ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb);
ctx.builder.position_at_end(body_bb); ctx.builder.position_at_end(body_bb);
let arr_ptr = ctx let arr_ptr = ctx
.build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero]) .build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero])
.into_pointer_value(); .into_pointer_value();
let val = ctx.build_gep_and_load(arr_ptr, &[tmp]); let val = ctx.build_gep_and_load(arr_ptr, &[tmp]);
generator.gen_assign(ctx, target, val.into())?; generator.gen_assign(ctx, target, val.into())?;
gen_block(generator, ctx, body.iter())?;
} }
gen_block(generator, ctx, body.iter())?;
for (k, (_, _, counter)) in var_assignment.iter() { for (k, (_, _, counter)) in var_assignment.iter() {
let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap();
if counter != counter2 { if counter != counter2 {
*static_val = None; *static_val = None;
} }
} }
if !ctx.is_terminated() { if !ctx.is_terminated() {
ctx.builder.build_unconditional_branch(test_bb); if is_iterable_range_expr {
ctx.builder.build_unconditional_branch(body_bb);
} else {
ctx.builder.build_unconditional_branch(test_bb);
}
} }
if !orelse.is_empty() { if !orelse.is_empty() {
ctx.builder.position_at_end(orelse_bb); ctx.builder.position_at_end(orelse_bb);
gen_block(generator, ctx, orelse.iter())?; gen_block(generator, ctx, orelse.iter())?;
@ -319,12 +361,20 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
ctx.builder.build_unconditional_branch(cont_bb); ctx.builder.build_unconditional_branch(cont_bb);
} }
} }
for (k, (_, _, counter)) in var_assignment.iter() { for (k, (_, _, counter)) in var_assignment.iter() {
let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap();
if counter != counter2 { if counter != counter2 {
*static_val = None; *static_val = None;
} }
} }
// Clear test_bb if unused
if is_iterable_range_expr {
ctx.builder.position_at_end(test_bb);
ctx.builder.build_unreachable();
}
ctx.builder.position_at_end(cont_bb); ctx.builder.position_at_end(cont_bb);
ctx.loop_target = loop_bb; ctx.loop_target = loop_bb;
} else { } else {
@ -850,7 +900,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
.builder .builder
.build_load(exn_type.into_pointer_value(), "expected_id") .build_load(exn_type.into_pointer_value(), "expected_id")
.into_int_value(); .into_int_value();
let result = ctx.builder.build_int_compare(EQ, actual_id, expected_id, "exncheck"); let result = ctx.builder.build_int_compare(IntPredicate::EQ, actual_id, expected_id, "exncheck");
ctx.builder.build_conditional_branch(result, handler_bb, dispatcher_cont); ctx.builder.build_conditional_branch(result, handler_bb, dispatcher_cont);
dispatcher_end = dispatcher_cont; dispatcher_end = dispatcher_cont;
} else { } else {

View File

@ -0,0 +1,12 @@
# For Loop using a decreasing range() expression as its iterable
@extern
def output_int32(x: int32):
...
def run() -> int32:
i = 0
for i in range(10, 0, -1):
output_int32(i)
output_int32(i)
return 0

View File

@ -0,0 +1,14 @@
# For Loop using an range() expression as its iterable, additionally reassigning the target on each iteration
@extern
def output_int32(x: int32):
...
def run() -> int32:
i = 0
for i in range(10):
output_int32(i)
i = 0
output_int32(i)
output_int32(i)
return 0