forked from M-Labs/nac3
core: Preserve value of variable shadowed by for loop
Previously, the final value of the target expression would be one after the last element of the loop, which does not match Python's behavior. This commit fixes this problem while also preserving the last assigned value of the loop beyond the loop, matching Python's behavior.
This commit is contained in:
parent
6805253515
commit
e0de82993f
|
@ -77,7 +77,8 @@ pub struct CodeGenContext<'ctx, 'a> {
|
|||
pub const_strings: HashMap<String, BasicValueEnum<'ctx>>,
|
||||
// stores the alloca for variables
|
||||
pub init_bb: BasicBlock<'ctx>,
|
||||
// the first one is the test_bb, and the second one is bb after the loop
|
||||
/// The header and exit basic blocks of a loop in this context. See
|
||||
/// https://llvm.org/docs/LoopTerminology.html for explanation of these terminology.
|
||||
pub loop_target: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>,
|
||||
// unwind target bb
|
||||
pub unwind_target: Option<BasicBlock<'ctx>>,
|
||||
|
|
|
@ -13,8 +13,8 @@ use inkwell::{
|
|||
attributes::{Attribute, AttributeLoc},
|
||||
basic_block::BasicBlock,
|
||||
types::BasicTypeEnum,
|
||||
values::{BasicValue, BasicValueEnum, FunctionValue, PointerValue},
|
||||
IntPredicate::EQ,
|
||||
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
|
||||
IntPredicate,
|
||||
};
|
||||
use nac3parser::ast::{
|
||||
Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef,
|
||||
|
@ -107,7 +107,7 @@ pub fn gen_store_target<'ctx, 'a, G: CodeGenerator>(
|
|||
);
|
||||
// handle negative index
|
||||
let is_negative = ctx.builder.build_int_compare(
|
||||
inkwell::IntPredicate::SLT,
|
||||
IntPredicate::SLT,
|
||||
raw_index,
|
||||
generator.get_size_type(ctx.ctx).const_zero(),
|
||||
"is_neg",
|
||||
|
@ -120,7 +120,7 @@ pub fn gen_store_target<'ctx, 'a, G: CodeGenerator>(
|
|||
// unsigned less than is enough, because negative index after adjustment is
|
||||
// bigger than the length (for unsigned cmp)
|
||||
let bound_check = ctx.builder.build_int_compare(
|
||||
inkwell::IntPredicate::ULT,
|
||||
IntPredicate::ULT,
|
||||
index,
|
||||
len,
|
||||
"inbound",
|
||||
|
@ -214,6 +214,26 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Generates a sequence of IR which checks whether [value] does not exceed the upper bound of the
|
||||
/// range as defined by [stop] and [step].
|
||||
///
|
||||
/// Note that the generated IR will **not** check whether value is part of the range or whether
|
||||
/// value exceeds the lower bound of the range (as evident by the missing `start` argument).
|
||||
///
|
||||
/// Returns an [IntValue] representing the result of whether the [value] is in the range.
|
||||
fn gen_in_range_check<'ctx, 'a>(
|
||||
ctx: &CodeGenContext<'ctx, 'a>,
|
||||
value: IntValue<'ctx>,
|
||||
stop: IntValue<'ctx>,
|
||||
step: IntValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let sign = ctx.builder.build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), "");
|
||||
let lo = ctx.builder.build_select(sign, value, stop, "").into_int_value();
|
||||
let hi = ctx.builder.build_select(sign, stop, value, "").into_int_value();
|
||||
|
||||
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp")
|
||||
}
|
||||
|
||||
pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
|
@ -234,49 +254,62 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
|
|||
// if there is no orelse, we just go to cont_bb
|
||||
let orelse_bb =
|
||||
if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") };
|
||||
|
||||
// Whether the iterable is a range() expression
|
||||
let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range);
|
||||
|
||||
// store loop bb information and restore it later
|
||||
let loop_bb = ctx.loop_target.replace((test_bb, cont_bb));
|
||||
let loop_bb = if is_iterable_range_expr {
|
||||
ctx.loop_target.replace((body_bb, cont_bb))
|
||||
} else {
|
||||
ctx.loop_target.replace((test_bb, cont_bb))
|
||||
};
|
||||
|
||||
let iter_val = generator.gen_expr(ctx, iter)?.unwrap().to_basic_value_enum(
|
||||
ctx,
|
||||
generator,
|
||||
iter.custom.unwrap(),
|
||||
)?;
|
||||
if ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range) {
|
||||
// setup
|
||||
if is_iterable_range_expr {
|
||||
let iter_val = iter_val.into_pointer_value();
|
||||
// Internal variable for loop; Cannot be assigned
|
||||
let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
|
||||
let user_i = generator.gen_store_target(ctx, target, Some("for.user_i.addr"))?;
|
||||
let (start, end, step) = destructure_range(ctx, iter_val);
|
||||
ctx.builder.build_store(i, ctx.builder.build_int_sub(start, step, "start_init"));
|
||||
ctx.builder.build_unconditional_branch(test_bb);
|
||||
ctx.builder.position_at_end(test_bb);
|
||||
let sign = ctx.builder.build_int_compare(
|
||||
inkwell::IntPredicate::SGT,
|
||||
step,
|
||||
int32.const_zero(),
|
||||
"sign",
|
||||
);
|
||||
// add and test
|
||||
let tmp = ctx.builder.build_int_add(
|
||||
ctx.builder.build_load(i, "i").into_int_value(),
|
||||
step,
|
||||
"start_loop",
|
||||
);
|
||||
ctx.builder.build_store(i, tmp);
|
||||
ctx.builder.build_store(user_i, tmp);
|
||||
// // if step > 0, continue when i < end
|
||||
let cmp1 = ctx.builder.build_int_compare(inkwell::IntPredicate::SLT, tmp, end, "cmp1");
|
||||
// if step < 0, continue when i > end
|
||||
let cmp2 = ctx.builder.build_int_compare(inkwell::IntPredicate::SGT, tmp, end, "cmp2");
|
||||
let pos = ctx.builder.build_and(sign, cmp1, "pos");
|
||||
let neg = ctx.builder.build_and(ctx.builder.build_not(sign, "inv"), cmp2, "neg");
|
||||
// Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
|
||||
let target_i = generator.gen_store_target(ctx, target, Some("for.target.addr"))?;
|
||||
let (start, stop, step) = destructure_range(ctx, iter_val);
|
||||
|
||||
ctx.builder.build_store(i, start);
|
||||
|
||||
// Pre-Loop Checks:
|
||||
// - step == 0 -> ValueError
|
||||
// - start < stop for step > 0 || start > stop for step < 0
|
||||
// TODO: Generate step == 0 -> raise ValueError
|
||||
|
||||
ctx.builder.build_conditional_branch(
|
||||
ctx.builder.build_or(pos, neg, "or"),
|
||||
gen_in_range_check(ctx, start, stop, step),
|
||||
body_bb,
|
||||
orelse_bb,
|
||||
);
|
||||
|
||||
ctx.builder.position_at_end(body_bb);
|
||||
ctx.builder.build_store(target_i, ctx.builder.build_load(i, "").into_int_value());
|
||||
gen_block(generator, ctx, body.iter())?;
|
||||
|
||||
// Test if next element is still in range
|
||||
let next_i = ctx.builder.build_int_add(
|
||||
ctx.builder.build_load(i, "").into_int_value(),
|
||||
step,
|
||||
"next_i",
|
||||
);
|
||||
let cond_cont_bb = ctx.ctx.append_basic_block(current, "for.cond.cont");
|
||||
ctx.builder.build_conditional_branch(
|
||||
gen_in_range_check(ctx, next_i, stop, step),
|
||||
cond_cont_bb,
|
||||
orelse_bb,
|
||||
);
|
||||
|
||||
ctx.builder.position_at_end(cond_cont_bb);
|
||||
ctx.builder.build_store(i, next_i);
|
||||
} else {
|
||||
let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("for.counter.addr"))?;
|
||||
// counter = -1
|
||||
|
@ -288,30 +321,39 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
|
|||
)
|
||||
.into_int_value();
|
||||
ctx.builder.build_unconditional_branch(test_bb);
|
||||
|
||||
ctx.builder.position_at_end(test_bb);
|
||||
let tmp = ctx.builder.build_load(counter, "i").into_int_value();
|
||||
let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc");
|
||||
ctx.builder.build_store(counter, tmp);
|
||||
let cmp = ctx.builder.build_int_compare(inkwell::IntPredicate::SLT, tmp, len, "cmp");
|
||||
let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, len, "cmp");
|
||||
ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb);
|
||||
|
||||
ctx.builder.position_at_end(body_bb);
|
||||
let arr_ptr = ctx
|
||||
.build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero])
|
||||
.into_pointer_value();
|
||||
let val = ctx.build_gep_and_load(arr_ptr, &[tmp]);
|
||||
generator.gen_assign(ctx, target, val.into())?;
|
||||
|
||||
gen_block(generator, ctx, body.iter())?;
|
||||
}
|
||||
|
||||
gen_block(generator, ctx, body.iter())?;
|
||||
for (k, (_, _, counter)) in var_assignment.iter() {
|
||||
let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap();
|
||||
if counter != counter2 {
|
||||
*static_val = None;
|
||||
}
|
||||
}
|
||||
|
||||
if !ctx.is_terminated() {
|
||||
ctx.builder.build_unconditional_branch(test_bb);
|
||||
if is_iterable_range_expr {
|
||||
ctx.builder.build_unconditional_branch(body_bb);
|
||||
} else {
|
||||
ctx.builder.build_unconditional_branch(test_bb);
|
||||
}
|
||||
}
|
||||
|
||||
if !orelse.is_empty() {
|
||||
ctx.builder.position_at_end(orelse_bb);
|
||||
gen_block(generator, ctx, orelse.iter())?;
|
||||
|
@ -319,12 +361,20 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
|
|||
ctx.builder.build_unconditional_branch(cont_bb);
|
||||
}
|
||||
}
|
||||
|
||||
for (k, (_, _, counter)) in var_assignment.iter() {
|
||||
let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap();
|
||||
if counter != counter2 {
|
||||
*static_val = None;
|
||||
}
|
||||
}
|
||||
|
||||
// Clear test_bb if unused
|
||||
if is_iterable_range_expr {
|
||||
ctx.builder.position_at_end(test_bb);
|
||||
ctx.builder.build_unreachable();
|
||||
}
|
||||
|
||||
ctx.builder.position_at_end(cont_bb);
|
||||
ctx.loop_target = loop_bb;
|
||||
} else {
|
||||
|
@ -850,7 +900,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
|
|||
.builder
|
||||
.build_load(exn_type.into_pointer_value(), "expected_id")
|
||||
.into_int_value();
|
||||
let result = ctx.builder.build_int_compare(EQ, actual_id, expected_id, "exncheck");
|
||||
let result = ctx.builder.build_int_compare(IntPredicate::EQ, actual_id, expected_id, "exncheck");
|
||||
ctx.builder.build_conditional_branch(result, handler_bb, dispatcher_cont);
|
||||
dispatcher_end = dispatcher_cont;
|
||||
} else {
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
# For Loop using a decreasing range() expression as its iterable
|
||||
|
||||
@extern
|
||||
def output_int32(x: int32):
|
||||
...
|
||||
|
||||
def run() -> int32:
|
||||
i = 0
|
||||
for i in range(10, 0, -1):
|
||||
output_int32(i)
|
||||
output_int32(i)
|
||||
return 0
|
|
@ -0,0 +1,14 @@
|
|||
# For Loop using an range() expression as its iterable, additionally reassigning the target on each iteration
|
||||
|
||||
@extern
|
||||
def output_int32(x: int32):
|
||||
...
|
||||
|
||||
def run() -> int32:
|
||||
i = 0
|
||||
for i in range(10):
|
||||
output_int32(i)
|
||||
i = 0
|
||||
output_int32(i)
|
||||
output_int32(i)
|
||||
return 0
|
Loading…
Reference in New Issue