1
0
forked from M-Labs/nac3

nac3core: use official implementation for len

This commit is contained in:
ychenfo 2021-12-13 04:02:30 +08:00 committed by Gitea
parent 2c6601d97c
commit b5637a04e9

View File

@ -1,5 +1,5 @@
use std::cell::RefCell; use std::cell::RefCell;
use inkwell::{IntPredicate, FloatPredicate, values::{BasicValueEnum, IntValue}}; use inkwell::{IntPredicate::{self, *}, FloatPredicate, values::IntValue};
use crate::{symbol_resolver::SymbolValue, codegen::expr::destructure_range}; use crate::{symbol_resolver::SymbolValue, codegen::expr::destructure_range};
use super::*; use super::*;
@ -570,7 +570,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
ty: arg_ty.0, ty: arg_ty.0,
default_value: None default_value: None
}], }],
ret: int32, ret: int64,
vars: vec![(list_var.1, list_var.0), (arg_ty.1, arg_ty.0)].into_iter().collect(), vars: vec![(list_var.1, list_var.0), (arg_ty.1, arg_ty.0)].into_iter().collect(),
}))), }))),
var_id: vec![arg_ty.1], var_id: vec![arg_ty.1],
@ -582,13 +582,13 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
let range_ty = ctx.primitives.range; let range_ty = ctx.primitives.range;
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg = args[0].1; let arg = args[0].1;
let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero();
if ctx.unifier.unioned(arg_ty, range_ty) { if ctx.unifier.unioned(arg_ty, range_ty) {
let arg = arg.into_pointer_value(); let arg = arg.into_pointer_value();
let (start, end, step) = destructure_range(ctx, arg); let (start, end, step) = destructure_range(ctx, arg);
Some(calculate_len_for_slice_range(ctx, start, end, step).into()) Some(calculate_len_for_slice_range(ctx, start, end, step).into())
} else { } else {
let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero();
Some(ctx.build_gep_and_load(arg.into_pointer_value(), &[zero, zero])) Some(ctx.build_gep_and_load(arg.into_pointer_value(), &[zero, zero]))
} }
}, },
@ -621,67 +621,74 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
// equivalent code: // equivalent code:
// def length(start, end, step != 0): // def length(start, end, step != 0):
// diff = end - start // diff = end - start
// # if diff == 0 OR `diff` and `step` are of different signs, always zero // if diff > 0 and step > 0:
// if diff * step <= 0: // return ((diff - 1) // step) + 1
// return 0 // elif diff < 0 and step < 0:
// return ((diff + 1) // step) + 1
// else: // else:
// return ((abs(diff) - 1) // abs(step)) + 1 // return 0
pub fn calculate_len_for_slice_range<'ctx, 'a>( pub fn calculate_len_for_slice_range<'ctx, 'a>(
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
start: IntValue<'ctx>, start: IntValue<'ctx>,
end: IntValue<'ctx>, end: IntValue<'ctx>,
step: IntValue<'ctx>, step: IntValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let int32 = ctx.ctx.i32_type(); let int64 = ctx.ctx.i64_type();
let int1 = ctx.ctx.bool_type(); let start = ctx.builder.build_int_s_extend(start, int64, "start");
let falze = int1.const_int(0, false); let end = ctx.builder.build_int_s_extend(end, int64, "end");
let abs_intrinsic = let step = ctx.builder.build_int_s_extend(step, int64, "step");
ctx.module.get_function("llvm.abs.i32").unwrap_or_else(|| {
let fn_type = int32.fn_type(&[int32.into(), int1.into()], false);
ctx.module.add_function("llvm.abs.i32", fn_type, None)
});
let diff = ctx.builder.build_int_sub(end, start, "diff"); let diff = ctx.builder.build_int_sub(end, start, "diff");
let test_mult = ctx.builder.build_int_mul(diff, step, "test_mult");
let test = let diff_pos = ctx.builder.build_int_compare(SGT, diff, int64.const_zero(), "diffpos");
ctx.builder.build_int_compare(inkwell::IntPredicate::SLE, test_mult, int32.const_zero(), "cmp"); let step_pos = ctx.builder.build_int_compare(SGT, step, int64.const_zero(), "steppos");
let test_1 = ctx.builder.build_and(diff_pos, step_pos, "bothpos");
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let then_bb = ctx.ctx.append_basic_block(current, "then"); let then_bb = ctx.ctx.append_basic_block(current, "then");
let else_bb = ctx.ctx.append_basic_block(current, "else"); let else_bb = ctx.ctx.append_basic_block(current, "else");
let then_bb_2 = ctx.ctx.append_basic_block(current, "then_2");
let else_bb_2 = ctx.ctx.append_basic_block(current, "else_2");
let cont_bb_2 = ctx.ctx.append_basic_block(current, "cont_2");
let cont_bb = ctx.ctx.append_basic_block(current, "cont"); let cont_bb = ctx.ctx.append_basic_block(current, "cont");
ctx.builder.build_conditional_branch(test, then_bb, else_bb); ctx.builder.build_conditional_branch(test_1, then_bb, else_bb);
ctx.builder.position_at_end(then_bb); ctx.builder.position_at_end(then_bb);
let length_zero = int32.const_zero(); let length_pos = {
let diff_pos_min_1 = ctx.builder.build_int_sub(diff, int64.const_int(1, false), "diffminone");
let length_pos = ctx.builder.build_int_signed_div(diff_pos_min_1, step, "div");
ctx.builder.build_int_add(length_pos, int64.const_int(1, false), "add1")
};
ctx.builder.build_unconditional_branch(cont_bb); ctx.builder.build_unconditional_branch(cont_bb);
ctx.builder.position_at_end(else_bb); ctx.builder.position_at_end(else_bb);
let diff = if let BasicValueEnum::IntValue(val) = ctx let phi_1 = {
.builder let diff_neg = ctx.builder.build_int_compare(SLT, diff, int64.const_zero(), "diffneg");
.build_call(abs_intrinsic, &[diff.into(), falze.into()], "absdiff") let step_neg = ctx.builder.build_int_compare(SLT, step, int64.const_zero(), "stepneg");
.try_as_basic_value() let test_2 = ctx.builder.build_and(diff_neg, step_neg, "bothneg");
.left()
.unwrap() { ctx.builder.build_conditional_branch(test_2, then_bb_2, else_bb_2);
val
} else { ctx.builder.position_at_end(then_bb_2);
unreachable!(); let length_neg = {
let diff_neg_add_1 = ctx.builder.build_int_add(diff, int64.const_int(1, false), "diffminone");
let length_neg = ctx.builder.build_int_signed_div(diff_neg_add_1, step, "div");
ctx.builder.build_int_add(length_neg, int64.const_int(1, false), "add1")
}; };
let diff = ctx.builder.build_int_sub(diff, int32.const_int(1, false), "diff"); ctx.builder.build_unconditional_branch(cont_bb_2);
let step = if let BasicValueEnum::IntValue(val) = ctx
.builder ctx.builder.position_at_end(else_bb_2);
.build_call(abs_intrinsic, &[step.into(), falze.into()], "absstep") let length_zero = int64.const_zero();
.try_as_basic_value() ctx.builder.build_unconditional_branch(cont_bb_2);
.left()
.unwrap() { ctx.builder.position_at_end(cont_bb_2);
val let phi_1 = ctx.builder.build_phi(int64, "lenphi1");
} else { phi_1.add_incoming(&[(&length_neg, then_bb_2), (&length_zero, else_bb_2)]);
unreachable!(); phi_1.as_basic_value().into_int_value()
}; };
let length = ctx.builder.build_int_signed_div(diff, step, "div");
let length = ctx.builder.build_int_add(length, int32.const_int(1, false), "add1");
ctx.builder.build_unconditional_branch(cont_bb); ctx.builder.build_unconditional_branch(cont_bb);
ctx.builder.position_at_end(cont_bb); ctx.builder.position_at_end(cont_bb);
let phi = ctx.builder.build_phi(length_zero.get_type(), "lenphi"); let phi = ctx.builder.build_phi(int64, "lenphi");
phi.add_incoming(&[(&length_zero, then_bb), (&length, else_bb)]); phi.add_incoming(&[(&length_pos, then_bb), (&phi_1, cont_bb_2)]);
phi.as_basic_value().into_int_value() phi.as_basic_value().into_int_value()
} }