From 50230e61f3829881581b6ff554663b543eb82ad7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 6 Oct 2023 12:11:57 +0800 Subject: [PATCH] core: Simplify loop condition check for list comprehension --- nac3core/src/codegen/expr.rs | 20 +++++++++----------- nac3core/src/codegen/mod.rs | 28 ++++++++++++++++++++++++++++ nac3core/src/codegen/stmt.rs | 35 +++++------------------------------ 3 files changed, 42 insertions(+), 41 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 49909981..df7bb06a 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -3,6 +3,7 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; use crate::{ codegen::{ concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, + gen_in_range_check, get_llvm_type, get_llvm_abi_type, irrt::*, @@ -963,11 +964,14 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>( let i = generator.gen_store_target(ctx, target, Some("i.addr"))?; ctx.builder.build_store(i, ctx.builder.build_int_sub(start, step, "start_init")); - ctx.builder.build_unconditional_branch(test_bb); + + ctx.builder.build_conditional_branch( + gen_in_range_check(ctx, start, stop, step), + test_bb, + cont_bb, + ); ctx.builder.position_at_end(test_bb); - let sign = - ctx.builder.build_int_compare(IntPredicate::SGT, step, zero_32, "sign"); // add and test let tmp = ctx.builder.build_int_add( ctx.builder.build_load(i, "i").into_int_value(), @@ -975,14 +979,8 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>( "start_loop", ); ctx.builder.build_store(i, tmp); - // if step > 0, continue when i < end - let cmp1 = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, stop, "cmp1"); - // if step < 0, continue when i > end - let cmp2 = ctx.builder.build_int_compare(IntPredicate::SGT, tmp, stop, "cmp2"); - let pos = ctx.builder.build_and(sign, cmp1, "pos"); - let neg = ctx.builder.build_and(ctx.builder.build_not(sign, "inv"), cmp2, "neg"); ctx.builder.build_conditional_branch( - ctx.builder.build_or(pos, neg, "or"), + gen_in_range_check(ctx, tmp, stop, step), body_bb, cont_bb, ); @@ -1001,7 +999,7 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>( ctx.build_gep_and_load(list, &[zero_size_t, zero_32], Some("list_content")).into_pointer_value(); let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; // counter = -1 - ctx.builder.build_store(counter, size_t.const_int(u64::max_value(), true)); + ctx.builder.build_store(counter, size_t.const_int(u64::MAX, true)); ctx.builder.build_unconditional_branch(test_bb); ctx.builder.position_at_end(test_bb); diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index da5cee89..ddce7bcf 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -899,3 +899,31 @@ fn bool_to_i8<'ctx>( ), } } + +/// Generates a sequence of IR which checks whether `value` does not exceed the upper bound of the +/// range as defined by `stop` and `step`. +/// +/// Note that the generated IR will **not** check whether value is part of the range or whether +/// value exceeds the lower bound of the range (as evident by the missing `start` argument). +/// +/// The generated IR is equivalent to the following Rust code: +/// +/// ```rust,ignore +/// let sign = step > 0; +/// let (lo, hi) = if sign { (value, stop) } else { (stop, value) }; +/// let cmp = lo < hi; +/// ``` +/// +/// Returns an `i1` [IntValue] representing the result of whether the `value` is in the range. +fn gen_in_range_check<'ctx, 'a>( + ctx: &CodeGenContext<'ctx, 'a>, + value: IntValue<'ctx>, + stop: IntValue<'ctx>, + step: IntValue<'ctx>, +) -> IntValue<'ctx> { + let sign = ctx.builder.build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), ""); + let lo = ctx.builder.build_select(sign, value, stop, "").into_int_value(); + let hi = ctx.builder.build_select(sign, stop, value, "").into_int_value(); + + ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp") +} diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index f737c877..e985818d 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -5,7 +5,10 @@ use super::{ CodeGenContext, CodeGenerator, }; use crate::{ - codegen::expr::gen_binop_expr, + codegen::{ + expr::gen_binop_expr, + gen_in_range_check, + }, toplevel::{DefinitionId, TopLevelDef}, typecheck::typedef::{FunSignature, Type, TypeEnum}, }; @@ -13,7 +16,7 @@ use inkwell::{ attributes::{Attribute, AttributeLoc}, basic_block::BasicBlock, types::BasicTypeEnum, - values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue}, + values::{BasicValue, BasicValueEnum, FunctionValue, PointerValue}, IntPredicate, }; use nac3parser::ast::{ @@ -232,34 +235,6 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>( Ok(()) } -/// Generates a sequence of IR which checks whether `value` does not exceed the upper bound of the -/// range as defined by `stop` and `step`. -/// -/// Note that the generated IR will **not** check whether value is part of the range or whether -/// value exceeds the lower bound of the range (as evident by the missing `start` argument). -/// -/// The generated IR is equivalent to the following Rust code: -/// -/// ```rust,ignore -/// let sign = step > 0; -/// let (lo, hi) = if sign { (value, stop) } else { (stop, value) }; -/// let cmp = lo < hi; -/// ``` -/// -/// Returns an `i1` [IntValue] representing the result of whether the `value` is in the range. -fn gen_in_range_check<'ctx, 'a>( - ctx: &CodeGenContext<'ctx, 'a>, - value: IntValue<'ctx>, - stop: IntValue<'ctx>, - step: IntValue<'ctx>, -) -> IntValue<'ctx> { - let sign = ctx.builder.build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), ""); - let lo = ctx.builder.build_select(sign, value, stop, "").into_int_value(); - let hi = ctx.builder.build_select(sign, stop, value, "").into_int_value(); - - ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp") -} - /// See [CodeGenerator::gen_for]. pub fn gen_for<'ctx, 'a, G: CodeGenerator>( generator: &mut G,