core: Preserve value of variable shadowed by for loop

Previously, the final value of the target expression would be one after
the last element of the loop, which does not match Python's behavior.

This commit fixes this problem while also preserving the last assigned
value of the loop beyond the loop, matching Python's behavior.
David Mak 2023-09-05 12:10:52 +08:00
parent be5775bbd5
commit 7b9f8e8aaa
4 changed files with 107 additions and 39 deletions

View File

@ -77,7 +77,8 @@ pub struct CodeGenContext<'ctx, 'a> {
pub const_strings: HashMap<String, BasicValueEnum<'ctx>>,
// stores the alloca for variables
pub init_bb: BasicBlock<'ctx>,
// the first one is the test_bb, and the second one is bb after the loop
/// The header and exit basic blocks of a loop in this context. See
/// https://llvm.org/docs/LoopTerminology.html for explanation of these terminology.
pub loop_target: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>,
// unwind target bb
pub unwind_target: Option<BasicBlock<'ctx>>,

View File

@ -13,8 +13,8 @@ use inkwell::{
attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock,
types::BasicTypeEnum,
values::{BasicValue, BasicValueEnum, FunctionValue, PointerValue},
IntPredicate::EQ,
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
IntPredicate,
};
use nac3parser::ast::{
Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef,
@ -107,7 +107,7 @@ pub fn gen_store_target<'ctx, 'a, G: CodeGenerator>(
);
// handle negative index
let is_negative = ctx.builder.build_int_compare(
inkwell::IntPredicate::SLT,
IntPredicate::SLT,
raw_index,
generator.get_size_type(ctx.ctx).const_zero(),
"is_neg",
@ -120,7 +120,7 @@ pub fn gen_store_target<'ctx, 'a, G: CodeGenerator>(
// unsigned less than is enough, because negative index after adjustment is
// bigger than the length (for unsigned cmp)
let bound_check = ctx.builder.build_int_compare(
inkwell::IntPredicate::ULT,
IntPredicate::ULT,
index,
len,
"inbound",
@ -213,6 +213,26 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>(
Ok(())
}
/// Generates a sequence of IR which checks whether [value] does not exceed the upper bound of the
/// range as defined by [stop] and [step].
///
/// Note that the generated IR will **not** check whether value is part of the range or whether
/// value exceeds the lower bound of the range (as evident by the missing `start` argument).
///
/// Returns an [IntValue] representing the result of whether the [value] is in the range.
fn gen_in_range_check<'ctx, 'a>(
ctx: &CodeGenContext<'ctx, 'a>,
value: IntValue<'ctx>,
stop: IntValue<'ctx>,
step: IntValue<'ctx>,
) -> IntValue<'ctx> {
let sign = ctx.builder.build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), "");
let lo = ctx.builder.build_select(sign, value, stop, "").into_int_value();
let hi = ctx.builder.build_select(sign, stop, value, "").into_int_value();
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp")
}
pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
@ -233,51 +253,62 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
// if there is no orelse, we just go to cont_bb
let orelse_bb =
if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") };
// Whether the iterable is a range() expression
let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range);
// store loop bb information and restore it later
let loop_bb = ctx.loop_target.replace((test_bb, cont_bb));
let loop_bb = if is_iterable_range_expr {
ctx.loop_target.replace((body_bb, cont_bb))
} else {
ctx.loop_target.replace((test_bb, cont_bb))
};
let iter_val = generator.gen_expr(ctx, iter)?.unwrap().to_basic_value_enum(
ctx,
generator,
iter.custom.unwrap(),
)?;
if ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range) {
// setup
if is_iterable_range_expr {
let iter_val = iter_val.into_pointer_value();
let i = generator.gen_var_alloc(ctx, int32.into())?;
let user_i = generator.gen_store_target(ctx, target)?;
// Internal variable for loop; Cannot be assigned
let i = generator.gen_var_alloc_named(ctx, int32.into(), "for.i.addr")?;
// Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
let target_i = generator.gen_store_target_named(ctx, target, "for.target.addr")?;
let (start, stop, step) = destructure_range(ctx, iter_val);
ctx.builder.build_store(i, ctx.builder.build_int_sub(start, step, "start_init"));
ctx.builder.build_unconditional_branch(test_bb);
ctx.builder.position_at_end(test_bb);
let sign = ctx.builder.build_int_compare(
inkwell::IntPredicate::SGT,
step,
int32.const_zero(),
"sign",
);
// add and test
let tmp = ctx.builder.build_int_add(
ctx.builder.build_load(i, "i").into_int_value(),
step,
"start_loop",
);
ctx.builder.build_store(i, tmp);
ctx.builder.build_store(user_i, tmp);
// // if step > 0, continue when i < end
let cmp1 = ctx.builder.build_int_compare(inkwell::IntPredicate::SLT, tmp, stop, "cmp1");
// if step < 0, continue when i > end
let cmp2 = ctx.builder.build_int_compare(inkwell::IntPredicate::SGT, tmp, stop, "cmp2");
let pos = ctx.builder.build_and(sign, cmp1, "pos");
let neg = ctx.builder.build_and(ctx.builder.build_not(sign, "inv"), cmp2, "neg");
ctx.builder.build_store(i, start);
// Pre-Loop Checks:
// - step == 0 -> ValueError
// - start < stop for step > 0 || start > stop for step < 0
// TODO: Generate step == 0 -> raise ValueError
ctx.builder.build_conditional_branch(
ctx.builder.build_or(pos, neg, "or"),
gen_in_range_check(ctx, start, stop, step),
body_bb,
orelse_bb,
);
ctx.builder.position_at_end(body_bb);
ctx.builder.build_store(target_i, ctx.builder.build_load(i, "").into_int_value());
gen_block(generator, ctx, body.iter())?;
// Test if next element is still in range
let next_i = ctx.builder.build_int_add(
ctx.builder.build_load(i, "").into_int_value(),
step,
"next_i",
);
let cond_cont_bb = ctx.ctx.append_basic_block(current, "for.cond.cont");
ctx.builder.build_conditional_branch(
gen_in_range_check(ctx, next_i, stop, step),
cond_cont_bb,
orelse_bb,
);
ctx.builder.position_at_end(cond_cont_bb);
ctx.builder.build_store(i, next_i);
} else {
let counter = generator.gen_var_alloc(ctx, size_t.into())?;
// counter = -1
@ -294,7 +325,7 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
let tmp = ctx.builder.build_load(counter, "i").into_int_value();
let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc");
ctx.builder.build_store(counter, tmp);
let cmp = ctx.builder.build_int_compare(inkwell::IntPredicate::SLT, tmp, len, "cmp");
let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, len, "cmp");
ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb);
ctx.builder.position_at_end(body_bb);
@ -303,9 +334,9 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
.into_pointer_value();
let val = ctx.build_gep_and_load(arr_ptr, &[tmp]);
generator.gen_assign(ctx, target, val.into())?;
}
gen_block(generator, ctx, body.iter())?;
gen_block(generator, ctx, body.iter())?;
}
for (k, (_, _, counter)) in var_assignment.iter() {
let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap();
@ -315,7 +346,11 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
}
if !ctx.is_terminated() {
ctx.builder.build_unconditional_branch(test_bb);
if is_iterable_range_expr {
ctx.builder.build_unconditional_branch(body_bb);
} else {
ctx.builder.build_unconditional_branch(test_bb);
}
}
if !orelse.is_empty() {
@ -333,6 +368,12 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
}
}
// Clear test_bb if unused
if is_iterable_range_expr {
ctx.builder.position_at_end(test_bb);
ctx.builder.build_unreachable();
}
ctx.builder.position_at_end(cont_bb);
ctx.loop_target = loop_bb;
} else {
@ -858,7 +899,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
.builder
.build_load(exn_type.into_pointer_value(), "expected_id")
.into_int_value();
let result = ctx.builder.build_int_compare(EQ, actual_id, expected_id, "exncheck");
let result = ctx.builder.build_int_compare(IntPredicate::EQ, actual_id, expected_id, "exncheck");
ctx.builder.build_conditional_branch(result, handler_bb, dispatcher_cont);
dispatcher_end = dispatcher_cont;
} else {

View File

@ -0,0 +1,12 @@
# For Loop using a decreasing range() expression as its iterable
@extern
def output_int32(x: int32):
...
def run() -> int32:
i = 0
for i in range(10, 0, -1):
output_int32(i)
output_int32(i)
return 0

View File

@ -0,0 +1,14 @@
# For Loop using an range() expression as its iterable, additionally reassigning the target on each iteration
@extern
def output_int32(x: int32):
...
def run() -> int32:
i = 0
for i in range(10):
output_int32(i)
i = 0
output_int32(i)
output_int32(i)
return 0