From 2c6601d97c111f07278b51272ffd1539847cad9f Mon Sep 17 00:00:00 2001 From: ychenfo Date: Sun, 12 Dec 2021 05:52:22 +0800 Subject: [PATCH] nac3core: fix len on range with step of different sign --- nac3core/src/toplevel/builtins.rs | 106 ++++++++++++++++++++---------- 1 file changed, 71 insertions(+), 35 deletions(-) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index fc0c000..c2e5be4 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1,5 +1,5 @@ use std::cell::RefCell; -use inkwell::{IntPredicate, FloatPredicate, values::BasicValueEnum}; +use inkwell::{IntPredicate, FloatPredicate, values::{BasicValueEnum, IntValue}}; use crate::{symbol_resolver::SymbolValue, codegen::expr::destructure_range}; use super::*; @@ -573,7 +573,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { ret: int32, vars: vec![(list_var.1, list_var.0), (arg_ty.1, arg_ty.0)].into_iter().collect(), }))), - var_id: Default::default(), + var_id: vec![arg_ty.1], instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver: None, @@ -585,41 +585,9 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let int32 = ctx.ctx.i32_type(); let zero = int32.const_zero(); if ctx.unifier.unioned(arg_ty, range_ty) { - let int1 = ctx.ctx.bool_type(); - let one = int32.const_int(1, false); - let falze = int1.const_int(0, false); - let abs_intrinsic = - ctx.module.get_function("llvm.abs.i32").unwrap_or_else(|| { - let fn_type = int32.fn_type(&[int32.into(), int1.into()], false); - ctx.module.add_function("llvm.abs.i32", fn_type, None) - }); let arg = arg.into_pointer_value(); let (start, end, step) = destructure_range(ctx, arg); - let diff = ctx.builder.build_int_sub(end, start, "diff"); - let diff = if let BasicValueEnum::IntValue(val) = ctx - .builder - .build_call(abs_intrinsic, &[diff.into(), falze.into()], "absdiff") - .try_as_basic_value() - .left() - .unwrap() { - val - } else { - unreachable!(); - }; - let diff = ctx.builder.build_int_sub(diff, one, "diff"); - let step = if let BasicValueEnum::IntValue(val) = ctx - .builder - .build_call(abs_intrinsic, &[step.into(), falze.into()], "absstep") - .try_as_basic_value() - .left() - .unwrap() { - val - } else { - unreachable!(); - }; - let length = ctx.builder.build_int_signed_div(diff, step, "div"); - let length = ctx.builder.build_int_add(length, int32.const_int(1, false), "add1"); - Some(length.into()) + Some(calculate_len_for_slice_range(ctx, start, end, step).into()) } else { Some(ctx.build_gep_and_load(arg.into_pointer_value(), &[zero, zero])) } @@ -648,4 +616,72 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { "len", ] ) +} + +// equivalent code: +// def length(start, end, step != 0): +// diff = end - start +// # if diff == 0 OR `diff` and `step` are of different signs, always zero +// if diff * step <= 0: +// return 0 +// else: +// return ((abs(diff) - 1) // abs(step)) + 1 +pub fn calculate_len_for_slice_range<'ctx, 'a>( + ctx: &mut CodeGenContext<'ctx, 'a>, + start: IntValue<'ctx>, + end: IntValue<'ctx>, + step: IntValue<'ctx>, +) -> IntValue<'ctx> { + let int32 = ctx.ctx.i32_type(); + let int1 = ctx.ctx.bool_type(); + let falze = int1.const_int(0, false); + let abs_intrinsic = + ctx.module.get_function("llvm.abs.i32").unwrap_or_else(|| { + let fn_type = int32.fn_type(&[int32.into(), int1.into()], false); + ctx.module.add_function("llvm.abs.i32", fn_type, None) + }); + let diff = ctx.builder.build_int_sub(end, start, "diff"); + let test_mult = ctx.builder.build_int_mul(diff, step, "test_mult"); + let test = + ctx.builder.build_int_compare(inkwell::IntPredicate::SLE, test_mult, int32.const_zero(), "cmp"); + let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); + let then_bb = ctx.ctx.append_basic_block(current, "then"); + let else_bb = ctx.ctx.append_basic_block(current, "else"); + let cont_bb = ctx.ctx.append_basic_block(current, "cont"); + ctx.builder.build_conditional_branch(test, then_bb, else_bb); + + ctx.builder.position_at_end(then_bb); + let length_zero = int32.const_zero(); + ctx.builder.build_unconditional_branch(cont_bb); + + ctx.builder.position_at_end(else_bb); + let diff = if let BasicValueEnum::IntValue(val) = ctx + .builder + .build_call(abs_intrinsic, &[diff.into(), falze.into()], "absdiff") + .try_as_basic_value() + .left() + .unwrap() { + val + } else { + unreachable!(); + }; + let diff = ctx.builder.build_int_sub(diff, int32.const_int(1, false), "diff"); + let step = if let BasicValueEnum::IntValue(val) = ctx + .builder + .build_call(abs_intrinsic, &[step.into(), falze.into()], "absstep") + .try_as_basic_value() + .left() + .unwrap() { + val + } else { + unreachable!(); + }; + let length = ctx.builder.build_int_signed_div(diff, step, "div"); + let length = ctx.builder.build_int_add(length, int32.const_int(1, false), "add1"); + ctx.builder.build_unconditional_branch(cont_bb); + + ctx.builder.position_at_end(cont_bb); + let phi = ctx.builder.build_phi(length_zero.get_type(), "lenphi"); + phi.add_incoming(&[(&length_zero, then_bb), (&length, else_bb)]); + phi.as_basic_value().into_int_value() } \ No newline at end of file