Compare commits

..

3 Commits

3 changed files with 55 additions and 68 deletions

View File

@ -3,6 +3,7 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use crate::{
codegen::{
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
gen_in_range_check,
get_llvm_type,
get_llvm_abi_type,
irrt::*,
@ -916,7 +917,9 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>(
ctx.builder.position_at_end(init_bb);
let Comprehension { target, iter, ifs, .. } = &generators[0];
let iter_val = generator.gen_expr(ctx, iter)?.unwrap().to_basic_value_enum(ctx, generator, iter.custom.unwrap())?;
let iter_val = generator.gen_expr(ctx, iter)?
.unwrap()
.to_basic_value_enum(ctx, generator, iter.custom.unwrap())?;
let int32 = ctx.ctx.i32_type();
let size_t = generator.get_size_type(ctx.ctx);
let zero_size_t = size_t.const_zero();
@ -932,8 +935,8 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>(
if is_range {
let iter_val = iter_val.into_pointer_value();
let (start, end, step) = destructure_range(ctx, iter_val);
let diff = ctx.builder.build_int_sub(end, start, "diff");
let (start, stop, step) = destructure_range(ctx, iter_val);
let diff = ctx.builder.build_int_sub(stop, start, "diff");
// add 1 to the length as the value is rounded to zero
// the length may be 1 more than the actual length if the division is exact, but the
// length is a upper bound only anyway so it does not matter.
@ -942,46 +945,33 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>(
// in case length is non-positive
let is_valid =
ctx.builder.build_int_compare(IntPredicate::SGT, length, zero_32, "check");
let normal = ctx.ctx.append_basic_block(current, "listcomp.normal_list");
let empty = ctx.ctx.append_basic_block(current, "listcomp.empty_list");
let list_init = ctx.ctx.append_basic_block(current, "listcomp.list_init");
ctx.builder.build_conditional_branch(is_valid, normal, empty);
// normal: allocate a list
ctx.builder.position_at_end(normal);
let list_a = allocate_list(
generator,
ctx,
elem_ty,
let list_alloc_size = ctx.builder.build_select(
is_valid,
ctx.builder.build_int_z_extend_or_bit_cast(length, size_t, "z_ext_len"),
Some("listcomp"),
zero_size_t,
"listcomp.alloc_size"
);
ctx.builder.build_unconditional_branch(list_init);
ctx.builder.position_at_end(empty);
let list_b = allocate_list(
list = allocate_list(
generator,
ctx,
elem_ty,
zero_size_t,
Some("list_b")
list_alloc_size.into_int_value(),
Some("listcomp.addr")
);
ctx.builder.build_unconditional_branch(list_init);
ctx.builder.position_at_end(list_init);
let phi = ctx.builder.build_phi(list_a.get_type(), "phi");
phi.add_incoming(&[(&list_a, normal), (&list_b, empty)]);
list = phi.as_basic_value().into_pointer_value();
list_content = ctx.build_gep_and_load(list, &[zero_size_t, zero_32], Some("list_content"))
list_content = ctx.build_gep_and_load(list, &[zero_size_t, zero_32], Some("listcomp.data.addr"))
.into_pointer_value();
let i = generator.gen_store_target(ctx, target, Some("i.addr"))?;
ctx.builder.build_store(i, ctx.builder.build_int_sub(start, step, "start_init"));
ctx.builder.build_unconditional_branch(test_bb);
ctx.builder.build_conditional_branch(
gen_in_range_check(ctx, start, stop, step),
test_bb,
cont_bb,
);
ctx.builder.position_at_end(test_bb);
let sign =
ctx.builder.build_int_compare(IntPredicate::SGT, step, zero_32, "sign");
// add and test
let tmp = ctx.builder.build_int_add(
ctx.builder.build_load(i, "i").into_int_value(),
@ -989,14 +979,8 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>(
"start_loop",
);
ctx.builder.build_store(i, tmp);
// if step > 0, continue when i < end
let cmp1 = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, end, "cmp1");
// if step < 0, continue when i > end
let cmp2 = ctx.builder.build_int_compare(IntPredicate::SGT, tmp, end, "cmp2");
let pos = ctx.builder.build_and(sign, cmp1, "pos");
let neg = ctx.builder.build_and(ctx.builder.build_not(sign, "inv"), cmp2, "neg");
ctx.builder.build_conditional_branch(
ctx.builder.build_or(pos, neg, "or"),
gen_in_range_check(ctx, tmp, stop, step),
body_bb,
cont_bb,
);
@ -1015,7 +999,7 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>(
ctx.build_gep_and_load(list, &[zero_size_t, zero_32], Some("list_content")).into_pointer_value();
let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?;
// counter = -1
ctx.builder.build_store(counter, size_t.const_int(u64::max_value(), true));
ctx.builder.build_store(counter, size_t.const_int(u64::MAX, true));
ctx.builder.build_unconditional_branch(test_bb);
ctx.builder.position_at_end(test_bb);

View File

@ -899,3 +899,31 @@ fn bool_to_i8<'ctx>(
),
}
}
/// Generates a sequence of IR which checks whether `value` does not exceed the upper bound of the
/// range as defined by `stop` and `step`.
///
/// Note that the generated IR will **not** check whether value is part of the range or whether
/// value exceeds the lower bound of the range (as evident by the missing `start` argument).
///
/// The generated IR is equivalent to the following Rust code:
///
/// ```rust,ignore
/// let sign = step > 0;
/// let (lo, hi) = if sign { (value, stop) } else { (stop, value) };
/// let cmp = lo < hi;
/// ```
///
/// Returns an `i1` [IntValue] representing the result of whether the `value` is in the range.
fn gen_in_range_check<'ctx, 'a>(
ctx: &CodeGenContext<'ctx, 'a>,
value: IntValue<'ctx>,
stop: IntValue<'ctx>,
step: IntValue<'ctx>,
) -> IntValue<'ctx> {
let sign = ctx.builder.build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), "");
let lo = ctx.builder.build_select(sign, value, stop, "").into_int_value();
let hi = ctx.builder.build_select(sign, stop, value, "").into_int_value();
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp")
}

View File

@ -5,7 +5,10 @@ use super::{
CodeGenContext, CodeGenerator,
};
use crate::{
codegen::expr::gen_binop_expr,
codegen::{
expr::gen_binop_expr,
gen_in_range_check,
},
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type, TypeEnum},
};
@ -13,7 +16,7 @@ use inkwell::{
attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock,
types::BasicTypeEnum,
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
values::{BasicValue, BasicValueEnum, FunctionValue, PointerValue},
IntPredicate,
};
use nac3parser::ast::{
@ -232,34 +235,6 @@ pub fn gen_assign<'ctx, 'a, G: CodeGenerator>(
Ok(())
}
/// Generates a sequence of IR which checks whether `value` does not exceed the upper bound of the
/// range as defined by `stop` and `step`.
///
/// Note that the generated IR will **not** check whether value is part of the range or whether
/// value exceeds the lower bound of the range (as evident by the missing `start` argument).
///
/// The generated IR is equivalent to the following Rust code:
///
/// ```rust,ignore
/// let sign = step > 0;
/// let (lo, hi) = if sign { (value, stop) } else { (stop, value) };
/// let cmp = lo < hi;
/// ```
///
/// Returns an `i1` [IntValue] representing the result of whether the `value` is in the range.
fn gen_in_range_check<'ctx, 'a>(
ctx: &CodeGenContext<'ctx, 'a>,
value: IntValue<'ctx>,
stop: IntValue<'ctx>,
step: IntValue<'ctx>,
) -> IntValue<'ctx> {
let sign = ctx.builder.build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), "");
let lo = ctx.builder.build_select(sign, value, stop, "").into_int_value();
let hi = ctx.builder.build_select(sign, stop, value, "").into_int_value();
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp")
}
/// See [CodeGenerator::gen_for].
pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
generator: &mut G,