From e0de82993f8366911130b302294482c535134bed Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 5 Sep 2023 12:10:52 +0800 Subject: [PATCH] core: Preserve value of variable shadowed by for loop Previously, the final value of the target expression would be one after the last element of the loop, which does not match Python's behavior. This commit fixes this problem while also preserving the last assigned value of the loop beyond the loop, matching Python's behavior. --- nac3core/src/codegen/mod.rs | 3 +- nac3core/src/codegen/stmt.rs | 124 +++++++++++++++------ nac3standalone/demo/src/loop_decr.py | 12 ++ nac3standalone/demo/src/loop_mutate_var.py | 14 +++ 4 files changed, 115 insertions(+), 38 deletions(-) create mode 100644 nac3standalone/demo/src/loop_decr.py create mode 100644 nac3standalone/demo/src/loop_mutate_var.py diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 72ba85ef7..3c213b348 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -77,7 +77,8 @@ pub struct CodeGenContext<'ctx, 'a> { pub const_strings: HashMap>, // stores the alloca for variables pub init_bb: BasicBlock<'ctx>, - // the first one is the test_bb, and the second one is bb after the loop + /// The header and exit basic blocks of a loop in this context. See + /// https://llvm.org/docs/LoopTerminology.html for explanation of these terminology. pub loop_target: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>, // unwind target bb pub unwind_target: Option>, diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 47ba743e3..5b9910deb 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -13,8 +13,8 @@ use inkwell::{ attributes::{Attribute, AttributeLoc}, basic_block::BasicBlock, types::BasicTypeEnum, - values::{BasicValue, BasicValueEnum, FunctionValue, PointerValue}, - IntPredicate::EQ, + values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue}, + IntPredicate, }; use nac3parser::ast::{ Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef, @@ -107,7 +107,7 @@ pub fn gen_store_target<'ctx, 'a, G: CodeGenerator>( ); // handle negative index let is_negative = ctx.builder.build_int_compare( - inkwell::IntPredicate::SLT, + IntPredicate::SLT, raw_index, generator.get_size_type(ctx.ctx).const_zero(), "is_neg", @@ -120,7 +120,7 @@ pub fn gen_store_target<'ctx, 'a, G: CodeGenerator>( // unsigned less than is enough, because negative index after adjustment is // bigger than the length (for unsigned cmp) let bound_check = ctx.builder.build_int_compare( - inkwell::IntPredicate::ULT, + IntPredicate::ULT, index, len, "inbound", @@ -214,6 +214,26 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>( Ok(()) } +/// Generates a sequence of IR which checks whether [value] does not exceed the upper bound of the +/// range as defined by [stop] and [step]. +/// +/// Note that the generated IR will **not** check whether value is part of the range or whether +/// value exceeds the lower bound of the range (as evident by the missing `start` argument). +/// +/// Returns an [IntValue] representing the result of whether the [value] is in the range. +fn gen_in_range_check<'ctx, 'a>( + ctx: &CodeGenContext<'ctx, 'a>, + value: IntValue<'ctx>, + stop: IntValue<'ctx>, + step: IntValue<'ctx>, +) -> IntValue<'ctx> { + let sign = ctx.builder.build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), ""); + let lo = ctx.builder.build_select(sign, value, stop, "").into_int_value(); + let hi = ctx.builder.build_select(sign, stop, value, "").into_int_value(); + + ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp") +} + pub fn gen_for<'ctx, 'a, G: CodeGenerator>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, @@ -234,49 +254,62 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>( // if there is no orelse, we just go to cont_bb let orelse_bb = if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") }; + + // Whether the iterable is a range() expression + let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range); + // store loop bb information and restore it later - let loop_bb = ctx.loop_target.replace((test_bb, cont_bb)); + let loop_bb = if is_iterable_range_expr { + ctx.loop_target.replace((body_bb, cont_bb)) + } else { + ctx.loop_target.replace((test_bb, cont_bb)) + }; let iter_val = generator.gen_expr(ctx, iter)?.unwrap().to_basic_value_enum( ctx, generator, iter.custom.unwrap(), )?; - if ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range) { - // setup + if is_iterable_range_expr { let iter_val = iter_val.into_pointer_value(); + // Internal variable for loop; Cannot be assigned let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; - let user_i = generator.gen_store_target(ctx, target, Some("for.user_i.addr"))?; - let (start, end, step) = destructure_range(ctx, iter_val); - ctx.builder.build_store(i, ctx.builder.build_int_sub(start, step, "start_init")); - ctx.builder.build_unconditional_branch(test_bb); - ctx.builder.position_at_end(test_bb); - let sign = ctx.builder.build_int_compare( - inkwell::IntPredicate::SGT, - step, - int32.const_zero(), - "sign", - ); - // add and test - let tmp = ctx.builder.build_int_add( - ctx.builder.build_load(i, "i").into_int_value(), - step, - "start_loop", - ); - ctx.builder.build_store(i, tmp); - ctx.builder.build_store(user_i, tmp); - // // if step > 0, continue when i < end - let cmp1 = ctx.builder.build_int_compare(inkwell::IntPredicate::SLT, tmp, end, "cmp1"); - // if step < 0, continue when i > end - let cmp2 = ctx.builder.build_int_compare(inkwell::IntPredicate::SGT, tmp, end, "cmp2"); - let pos = ctx.builder.build_and(sign, cmp1, "pos"); - let neg = ctx.builder.build_and(ctx.builder.build_not(sign, "inv"), cmp2, "neg"); + // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed + let target_i = generator.gen_store_target(ctx, target, Some("for.target.addr"))?; + let (start, stop, step) = destructure_range(ctx, iter_val); + + ctx.builder.build_store(i, start); + + // Pre-Loop Checks: + // - step == 0 -> ValueError + // - start < stop for step > 0 || start > stop for step < 0 + // TODO: Generate step == 0 -> raise ValueError + ctx.builder.build_conditional_branch( - ctx.builder.build_or(pos, neg, "or"), + gen_in_range_check(ctx, start, stop, step), body_bb, orelse_bb, ); + ctx.builder.position_at_end(body_bb); + ctx.builder.build_store(target_i, ctx.builder.build_load(i, "").into_int_value()); + gen_block(generator, ctx, body.iter())?; + + // Test if next element is still in range + let next_i = ctx.builder.build_int_add( + ctx.builder.build_load(i, "").into_int_value(), + step, + "next_i", + ); + let cond_cont_bb = ctx.ctx.append_basic_block(current, "for.cond.cont"); + ctx.builder.build_conditional_branch( + gen_in_range_check(ctx, next_i, stop, step), + cond_cont_bb, + orelse_bb, + ); + + ctx.builder.position_at_end(cond_cont_bb); + ctx.builder.build_store(i, next_i); } else { let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("for.counter.addr"))?; // counter = -1 @@ -288,30 +321,39 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>( ) .into_int_value(); ctx.builder.build_unconditional_branch(test_bb); + ctx.builder.position_at_end(test_bb); let tmp = ctx.builder.build_load(counter, "i").into_int_value(); let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc"); ctx.builder.build_store(counter, tmp); - let cmp = ctx.builder.build_int_compare(inkwell::IntPredicate::SLT, tmp, len, "cmp"); + let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, len, "cmp"); ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb); + ctx.builder.position_at_end(body_bb); let arr_ptr = ctx .build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero]) .into_pointer_value(); let val = ctx.build_gep_and_load(arr_ptr, &[tmp]); generator.gen_assign(ctx, target, val.into())?; + + gen_block(generator, ctx, body.iter())?; } - gen_block(generator, ctx, body.iter())?; for (k, (_, _, counter)) in var_assignment.iter() { let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); if counter != counter2 { *static_val = None; } } + if !ctx.is_terminated() { - ctx.builder.build_unconditional_branch(test_bb); + if is_iterable_range_expr { + ctx.builder.build_unconditional_branch(body_bb); + } else { + ctx.builder.build_unconditional_branch(test_bb); + } } + if !orelse.is_empty() { ctx.builder.position_at_end(orelse_bb); gen_block(generator, ctx, orelse.iter())?; @@ -319,12 +361,20 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>( ctx.builder.build_unconditional_branch(cont_bb); } } + for (k, (_, _, counter)) in var_assignment.iter() { let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap(); if counter != counter2 { *static_val = None; } } + + // Clear test_bb if unused + if is_iterable_range_expr { + ctx.builder.position_at_end(test_bb); + ctx.builder.build_unreachable(); + } + ctx.builder.position_at_end(cont_bb); ctx.loop_target = loop_bb; } else { @@ -850,7 +900,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>( .builder .build_load(exn_type.into_pointer_value(), "expected_id") .into_int_value(); - let result = ctx.builder.build_int_compare(EQ, actual_id, expected_id, "exncheck"); + let result = ctx.builder.build_int_compare(IntPredicate::EQ, actual_id, expected_id, "exncheck"); ctx.builder.build_conditional_branch(result, handler_bb, dispatcher_cont); dispatcher_end = dispatcher_cont; } else { diff --git a/nac3standalone/demo/src/loop_decr.py b/nac3standalone/demo/src/loop_decr.py new file mode 100644 index 000000000..59afb1be7 --- /dev/null +++ b/nac3standalone/demo/src/loop_decr.py @@ -0,0 +1,12 @@ +# For Loop using a decreasing range() expression as its iterable + +@extern +def output_int32(x: int32): + ... + +def run() -> int32: + i = 0 + for i in range(10, 0, -1): + output_int32(i) + output_int32(i) + return 0 diff --git a/nac3standalone/demo/src/loop_mutate_var.py b/nac3standalone/demo/src/loop_mutate_var.py new file mode 100644 index 000000000..3ac5c2c12 --- /dev/null +++ b/nac3standalone/demo/src/loop_mutate_var.py @@ -0,0 +1,14 @@ +# For Loop using an range() expression as its iterable, additionally reassigning the target on each iteration + +@extern +def output_int32(x: int32): + ... + +def run() -> int32: + i = 0 + for i in range(10): + output_int32(i) + i = 0 + output_int32(i) + output_int32(i) + return 0