diff --git a/Cargo.lock b/Cargo.lock index 399c90e..b4f54ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -513,6 +513,7 @@ dependencies = [ "nac3parser", "parking_lot", "rayon", + "regex", "test-case", ] diff --git a/nac3core/Cargo.toml b/nac3core/Cargo.toml index 64408af..48d0f48 100644 --- a/nac3core/Cargo.toml +++ b/nac3core/Cargo.toml @@ -20,3 +20,6 @@ features = ["llvm13-0", "target-x86", "target-arm", "target-riscv", "no-libffi-l test-case = "1.2.0" indoc = "1.0" insta = "1.5" + +[build-dependencies] +regex = "1" \ No newline at end of file diff --git a/nac3core/build.rs b/nac3core/build.rs new file mode 100644 index 0000000..c4dcfc8 --- /dev/null +++ b/nac3core/build.rs @@ -0,0 +1,54 @@ +use regex::Regex; +use std::{ + env, + io::Write, + process::{Command, Stdio}, +}; + +fn main() { + let out_dir = env::var("OUT_DIR").unwrap(); + const FILE: &str = "src/codegen/irrt/irrt.c"; + println!("cargo:rerun-if-changed={}", FILE); + const FLAG: &[&str] = &[ + FILE, + "-O3", + "-emit-llvm", + "-S", + "-Wall", + "-Wextra", + "-Wno-implicit-function-declaration", + "-o", + "-", + ]; + let output = Command::new("clang") + .args(FLAG) + .output() + .map(|o| { + assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap()); + o + }) + .unwrap(); + + let output = std::str::from_utf8(&output.stdout).unwrap(); + let mut filtered_output = String::with_capacity(output.len()); + + let regex_filter = regex::Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)").unwrap(); + for f in regex_filter.captures_iter(output) { + assert!(f.len() == 1); + filtered_output.push_str(&f[0]); + filtered_output.push('\n'); + } + + let filtered_output = Regex::new("(#\\d+)|(, *![0-9A-Za-z.]+)|(![0-9A-Za-z.]+)|(!\".*?\")") + .unwrap() + .replace_all(&filtered_output, ""); + + let mut llvm_as = Command::new("llvm-as") + .stdin(Stdio::piped()) + .arg("-o") + .arg(&format!("{}/irrt.bc", out_dir)) + .spawn() + .unwrap(); + llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap(); + assert!(llvm_as.wait().unwrap().success()) +} diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 29eb8de..a9b270b 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -3,7 +3,9 @@ use std::{collections::HashMap, convert::TryInto, iter::once}; use crate::{ codegen::{ concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, - get_llvm_type, CodeGenContext, CodeGenTask, + get_llvm_type, + irrt::*, + CodeGenContext, CodeGenTask, }, symbol_resolver::{SymbolValue, ValueEnum}, toplevel::{DefinitionId, TopLevelDef}, @@ -186,8 +188,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { Operator::LShift => self.builder.build_left_shift(lhs, rhs, "lshift").into(), Operator::RShift => self.builder.build_right_shift(lhs, rhs, true, "rshift").into(), Operator::FloorDiv => self.builder.build_int_signed_div(lhs, rhs, "floordiv").into(), + Operator::Pow => integer_power(self, lhs, rhs).into(), // special implementation? - Operator::Pow => unimplemented!(), Operator::MatMult => unreachable!(), } } @@ -205,6 +207,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } else { unreachable!() }; + let float = self.ctx.f64_type(); match op { Operator::Add => self.builder.build_float_add(lhs, rhs, "fadd").into(), Operator::Sub => self.builder.build_float_sub(lhs, rhs, "fsub").into(), @@ -215,7 +218,6 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let div = self.builder.build_float_div(lhs, rhs, "fdiv"); let floor_intrinsic = self.module.get_function("llvm.floor.f64").unwrap_or_else(|| { - let float = self.ctx.f64_type(); let fn_type = float.fn_type(&[float.into()], false); self.module.add_function("llvm.floor.f64", fn_type, None) }); @@ -225,6 +227,16 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { .left() .unwrap() } + Operator::Pow => { + let pow_intrinsic = self.module.get_function("llvm.pow.f64").unwrap_or_else(|| { + let fn_type = float.fn_type(&[float.into(), float.into()], false); + self.module.add_function("llvm.pow.f64", fn_type, None) + }); + self.builder + .build_call(pow_intrinsic, &[lhs.into(), rhs.into()], "f_pow") + .try_as_basic_value() + .unwrap_left() + } // special implementation? _ => unimplemented!(), } @@ -436,7 +448,8 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator>( if let Some(obj) = &obj { args.insert(0, FuncArg { name: "self".into(), ty: obj.0, default_value: None }); } - let params = args.iter().map(|arg| ctx.get_llvm_type(generator, arg.ty).into()).collect_vec(); + let params = + args.iter().map(|arg| ctx.get_llvm_type(generator, arg.ty).into()).collect_vec(); let fun_ty = if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) { ctx.ctx.void_type().fn_type(¶ms, false) } else { @@ -630,6 +643,45 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>( } } +pub fn gen_binop_expr<'ctx, 'a, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + left: &Expr>, + op: &Operator, + right: &Expr>, +) -> ValueEnum<'ctx> { + let ty1 = ctx.unifier.get_representative(left.custom.unwrap()); + let ty2 = ctx.unifier.get_representative(right.custom.unwrap()); + let left = generator.gen_expr(ctx, left).unwrap().to_basic_value_enum(ctx, generator); + let right = generator.gen_expr(ctx, right).unwrap().to_basic_value_enum(ctx, generator); + + // we can directly compare the types, because we've got their representatives + // which would be unchanged until further unification, which we would never do + // when doing code generation for function instances + if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { + ctx.gen_int_ops(op, left, right) + } else if ty1 == ty2 && ctx.primitives.float == ty1 { + ctx.gen_float_ops(op, left, right) + } else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 { + // TODO: throw exception when rhs is out of i16 bound + // since llvm intrinsic only support to i16 for f64 + let i16_t = ctx.ctx.i16_type(); + let pow_intr = ctx.module.get_function("llvm.powi.f64.i16").unwrap_or_else(|| { + let f64_t = ctx.ctx.f64_type(); + let ty = f64_t.fn_type(&[f64_t.into(), i16_t.into()], false); + ctx.module.add_function("llvm.powi.f64.i16", ty, None) + }); + let right = ctx.builder.build_int_truncate(right.into_int_value(), i16_t, "r_pow"); + ctx.builder + .build_call(pow_intr, &[left.into(), right.into()], "f_pow_i") + .try_as_basic_value() + .unwrap_left() + } else { + unimplemented!() + } + .into() +} + pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, @@ -766,24 +818,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( phi.add_incoming(&[(&a, a_bb), (&b, b_bb)]); phi.as_basic_value().into() } - ExprKind::BinOp { op, left, right } => { - let ty1 = ctx.unifier.get_representative(left.custom.unwrap()); - let ty2 = ctx.unifier.get_representative(right.custom.unwrap()); - let left = generator.gen_expr(ctx, left).unwrap().to_basic_value_enum(ctx, generator); - let right = generator.gen_expr(ctx, right).unwrap().to_basic_value_enum(ctx, generator); - - // we can directly compare the types, because we've got their representatives - // which would be unchanged until further unification, which we would never do - // when doing code generation for function instances - if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { - ctx.gen_int_ops(op, left, right) - } else if ty1 == ty2 && ctx.primitives.float == ty1 { - ctx.gen_float_ops(op, left, right) - } else { - unimplemented!() - } - .into() - } + ExprKind::BinOp { op, left, right } => gen_binop_expr(generator, ctx, left, op, right), ExprKind::UnaryOp { op, operand } => { let ty = ctx.unifier.get_representative(operand.custom.unwrap()); let val = generator.gen_expr(ctx, operand).unwrap().to_basic_value_enum(ctx, generator); diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c new file mode 100644 index 0000000..227bc57 --- /dev/null +++ b/nac3core/src/codegen/irrt/irrt.c @@ -0,0 +1,41 @@ +typedef _ExtInt(8) int8_t; +typedef unsigned _ExtInt(8) uint8_t; +typedef _ExtInt(32) int32_t; +typedef unsigned _ExtInt(32) uint32_t; +typedef _ExtInt(64) int64_t; +typedef unsigned _ExtInt(64) uint64_t; + +# define MAX(a, b) (a > b ? a : b) +# define MIN(a, b) (a > b ? b : a) + +int32_t __nac3_irrt_range_slice_len(const int32_t start, const int32_t end, const int32_t step) { + int32_t diff = end - start; + if (diff > 0 && step > 0) { + return ((diff - 1) / step) + 1; + } else if (diff < 0 && step < 0) { + return ((diff + 1) / step) + 1; + } else { + return 0; + } +} + +// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c +// need to make sure `exp >= 0` before calling this function +# define \ + DEF_INT_EXP(T) \ +T __nac3_irrt_int_exp_##T( \ + T base, \ + T exp \ +) { \ + T res = (T)1; \ + /* repeated squaring method */ \ + do { \ + if (exp & 1) res *= base; /* for n odd */ \ + exp >>= 1; \ + base *= base; \ + } while (exp); \ + return res; \ +} \ + +DEF_INT_EXP(int32_t) +DEF_INT_EXP(int64_t) diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs new file mode 100644 index 0000000..2f8b5b3 --- /dev/null +++ b/nac3core/src/codegen/irrt/mod.rs @@ -0,0 +1,85 @@ +use super::*; +use inkwell::{ + attributes::AttributeLoc, + memory_buffer::MemoryBuffer, + module::{Linkage, Module}, + values::IntValue, +}; + +pub struct IrrtSymbolTable; +impl IrrtSymbolTable { + const LEN: &'static str = "__nac3_irrt_range_slice_len"; + const POWER_I32: &'static str = "__nac3_irrt_int_exp_int32_t"; + const POWER_I64: &'static str = "__nac3_irrt_int_exp_int64_t"; +} +pub const ALL_IRRT_SYMBOLS: &[&str] = + &[IrrtSymbolTable::LEN, IrrtSymbolTable::POWER_I32, IrrtSymbolTable::POWER_I64]; + +fn load_irrt<'ctx, 'a>(ctx: &CodeGenContext<'ctx, 'a>, fun: &str) -> FunctionValue<'ctx> { + let bitcode_buf = MemoryBuffer::create_from_memory_range( + include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")), + "irrt_bitcode_buffer", + ); + let irrt_mod = Module::parse_bitcode_from_buffer(&bitcode_buf, ctx.ctx).unwrap(); + irrt_mod.set_data_layout(&ctx.module.get_data_layout()); + irrt_mod.set_triple(&ctx.module.get_triple()); + ctx.module.link_in_module(irrt_mod).unwrap(); + for f in ALL_IRRT_SYMBOLS { + let fun = ctx.module.get_function(f).unwrap(); + fun.set_linkage(Linkage::Private); + if f == &IrrtSymbolTable::POWER_I32 || f == &IrrtSymbolTable::POWER_I64 { + // add alwaysinline attributes to power function to help them get inlined + // alwaysinline enum = 1, see release/13.x/llvm/include/llvm/IR/Attributes.td + fun.add_attribute(AttributeLoc::Function, ctx.ctx.create_enum_attribute(1, 0)); + } + } + ctx.module.get_function(fun).unwrap() +} + +// equivalent code: +// def length(start, end, step != 0): +// diff = end - start +// if diff > 0 and step > 0: +// return ((diff - 1) // step) + 1 +// elif diff < 0 and step < 0: +// return ((diff + 1) // step) + 1 +// else: +// return 0 +pub fn calculate_len_for_slice_range<'ctx, 'a>( + ctx: &mut CodeGenContext<'ctx, 'a>, + start: IntValue<'ctx>, + end: IntValue<'ctx>, + step: IntValue<'ctx>, +) -> IntValue<'ctx> { + const FUN_SYMBOL: &str = IrrtSymbolTable::LEN; + let len_func = + ctx.module.get_function(FUN_SYMBOL).unwrap_or_else(|| load_irrt(ctx, FUN_SYMBOL)); + + // TODO: throw exception when step == 0 + ctx.builder + .build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len") + .try_as_basic_value() + .left() + .unwrap() + .into_int_value() +} +// repeated squaring method adapted from GNU Scientific Library: +// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c +pub fn integer_power<'ctx, 'a>( + ctx: &mut CodeGenContext<'ctx, 'a>, + base: IntValue<'ctx>, + exp: IntValue<'ctx>, +) -> IntValue<'ctx> { + let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width()) { + (32, 32) => IrrtSymbolTable::POWER_I32, + (64, 64) => IrrtSymbolTable::POWER_I64, + _ => unreachable!(), + }; + let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| load_irrt(ctx, symbol)); + // TODO: throw exception when exp < 0 + ctx.builder + .build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow") + .try_as_basic_value() + .unwrap_left() + .into_int_value() +} diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 1ff7c3c..427b3a0 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -31,6 +31,7 @@ pub mod concrete_type; pub mod expr; mod generator; pub mod stmt; +pub mod irrt; #[cfg(test)] mod test; diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index f205e51..333ed35 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -1,7 +1,10 @@ use super::{ super::symbol_resolver::ValueEnum, expr::destructure_range, CodeGenContext, CodeGenerator, }; -use crate::typecheck::typedef::Type; +use crate::{ + codegen::expr::gen_binop_expr, + typecheck::typedef::Type, +}; use inkwell::{ types::BasicTypeEnum, values::{BasicValue, BasicValueEnum, PointerValue}, @@ -417,26 +420,8 @@ pub fn gen_stmt<'ctx, 'a, G: CodeGenerator>( StmtKind::For { .. } => return generator.gen_for(ctx, stmt), StmtKind::With { .. } => return generator.gen_with(ctx, stmt), StmtKind::AugAssign { target, op, value, .. } => { - let value = { - let ty1 = ctx.unifier.get_representative(target.custom.unwrap()); - let ty2 = ctx.unifier.get_representative(value.custom.unwrap()); - let left = - generator.gen_expr(ctx, target).unwrap().to_basic_value_enum(ctx, generator); - let right = - generator.gen_expr(ctx, value).unwrap().to_basic_value_enum(ctx, generator); - - // we can directly compare the types, because we've got their representatives - // which would be unchanged until further unification, which we would never do - // when doing code generation for function instances - if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { - ctx.gen_int_ops(op, left, right) - } else if ty1 == ty2 && ctx.primitives.float == ty1 { - ctx.gen_float_ops(op, left, right) - } else { - unimplemented!() - } - }; - generator.gen_assign(ctx, target, value.into()); + let value = gen_binop_expr(generator, ctx, target, op, value); + generator.gen_assign(ctx, target, value); } _ => unimplemented!(), }; diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index bc85b0b..38286f0 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1,7 +1,10 @@ -use std::cell::RefCell; -use inkwell::{IntPredicate::{self, *}, FloatPredicate, values::IntValue}; -use crate::{symbol_resolver::SymbolValue, codegen::expr::destructure_range}; use super::*; +use crate::{ + codegen::{expr::destructure_range, irrt::*}, + symbol_resolver::SymbolValue, +}; +use inkwell::{FloatPredicate, IntPredicate}; +use std::cell::RefCell; type BuiltinInfo = ( Vec<(Arc>, Option)>, @@ -17,7 +20,6 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let string = primitives.0.str; let num_ty = primitives.1.get_fresh_var_with_range(&[int32, int64, float, boolean]); let var_map: HashMap<_, _> = vec![(num_ty.1, num_ty.0)].into_iter().collect(); - let top_level_def_list = vec![ Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( 0, @@ -622,78 +624,3 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { ] ) } - -// equivalent code: -// def length(start, end, step != 0): -// diff = end - start -// if diff > 0 and step > 0: -// return ((diff - 1) // step) + 1 -// elif diff < 0 and step < 0: -// return ((diff + 1) // step) + 1 -// else: -// return 0 -pub fn calculate_len_for_slice_range<'ctx, 'a>( - ctx: &mut CodeGenContext<'ctx, 'a>, - start: IntValue<'ctx>, - end: IntValue<'ctx>, - step: IntValue<'ctx>, -) -> IntValue<'ctx> { - let int32 = ctx.ctx.i32_type(); - let start = ctx.builder.build_int_s_extend(start, int32, "start"); - let end = ctx.builder.build_int_s_extend(end, int32, "end"); - let step = ctx.builder.build_int_s_extend(step, int32, "step"); - let diff = ctx.builder.build_int_sub(end, start, "diff"); - - let diff_pos = ctx.builder.build_int_compare(SGT, diff, int32.const_zero(), "diffpos"); - let step_pos = ctx.builder.build_int_compare(SGT, step, int32.const_zero(), "steppos"); - let test_1 = ctx.builder.build_and(diff_pos, step_pos, "bothpos"); - - let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); - let then_bb = ctx.ctx.append_basic_block(current, "then"); - let else_bb = ctx.ctx.append_basic_block(current, "else"); - let then_bb_2 = ctx.ctx.append_basic_block(current, "then_2"); - let else_bb_2 = ctx.ctx.append_basic_block(current, "else_2"); - let cont_bb_2 = ctx.ctx.append_basic_block(current, "cont_2"); - let cont_bb = ctx.ctx.append_basic_block(current, "cont"); - ctx.builder.build_conditional_branch(test_1, then_bb, else_bb); - - ctx.builder.position_at_end(then_bb); - let length_pos = { - let diff_pos_min_1 = ctx.builder.build_int_sub(diff, int32.const_int(1, false), "diffminone"); - let length_pos = ctx.builder.build_int_signed_div(diff_pos_min_1, step, "div"); - ctx.builder.build_int_add(length_pos, int32.const_int(1, false), "add1") - }; - ctx.builder.build_unconditional_branch(cont_bb); - - ctx.builder.position_at_end(else_bb); - let phi_1 = { - let diff_neg = ctx.builder.build_int_compare(SLT, diff, int32.const_zero(), "diffneg"); - let step_neg = ctx.builder.build_int_compare(SLT, step, int32.const_zero(), "stepneg"); - let test_2 = ctx.builder.build_and(diff_neg, step_neg, "bothneg"); - - ctx.builder.build_conditional_branch(test_2, then_bb_2, else_bb_2); - - ctx.builder.position_at_end(then_bb_2); - let length_neg = { - let diff_neg_add_1 = ctx.builder.build_int_add(diff, int32.const_int(1, false), "diffminone"); - let length_neg = ctx.builder.build_int_signed_div(diff_neg_add_1, step, "div"); - ctx.builder.build_int_add(length_neg, int32.const_int(1, false), "add1") - }; - ctx.builder.build_unconditional_branch(cont_bb_2); - - ctx.builder.position_at_end(else_bb_2); - let length_zero = int32.const_zero(); - ctx.builder.build_unconditional_branch(cont_bb_2); - - ctx.builder.position_at_end(cont_bb_2); - let phi_1 = ctx.builder.build_phi(int32, "lenphi1"); - phi_1.add_incoming(&[(&length_neg, then_bb_2), (&length_zero, else_bb_2)]); - phi_1.as_basic_value().into_int_value() - }; - ctx.builder.build_unconditional_branch(cont_bb); - - ctx.builder.position_at_end(cont_bb); - let phi = ctx.builder.build_phi(int32, "lenphi"); - phi.add_incoming(&[(&length_pos, then_bb), (&phi_1, cont_bb_2)]); - phi.as_basic_value().into_int_value() -} diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 43d0da5..f2989d0 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -183,7 +183,7 @@ pub fn impl_cmpop( } } -/// Add, Sub, Mult, Pow +/// Add, Sub, Mult pub fn impl_basic_arithmetic( unifier: &mut Unifier, store: &PrimitiveStore, diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 702c03a..05f131d 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -268,6 +268,7 @@ fn main() { let builder = PassManagerBuilder::create(); builder.set_optimization_level(OptimizationLevel::Aggressive); let passes = PassManager::create(()); + builder.set_inliner_with_threshold(255); builder.populate_module_pass_manager(&passes); passes.run_on(module);