From 4a65d82db5531e59e34a0b3c50a73bb508229dde Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Sat, 8 Jan 2022 22:16:55 +0800 Subject: [PATCH] introduce IRRT, implement power based on code by Yijia https://git.m-labs.hk/M-Labs/nac3/pulls/160 --- Cargo.lock | 1 + nac3artiq/src/lib.rs | 3 + nac3core/Cargo.toml | 3 + nac3core/build.rs | 54 +++++++++++++++++ nac3core/src/codegen/expr.rs | 78 ++++++++++++++++++------- nac3core/src/codegen/irrt/irrt.c | 25 ++++++++ nac3core/src/codegen/irrt/mod.rs | 48 +++++++++++++++ nac3core/src/codegen/mod.rs | 1 + nac3core/src/codegen/stmt.rs | 27 ++------- nac3core/src/typecheck/magic_methods.rs | 3 +- 10 files changed, 200 insertions(+), 43 deletions(-) create mode 100644 nac3core/build.rs create mode 100644 nac3core/src/codegen/irrt/irrt.c create mode 100644 nac3core/src/codegen/irrt/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 399c90e7f..b4f54ad12 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -513,6 +513,7 @@ dependencies = [ "nac3parser", "parking_lot", "rayon", + "regex", "test-case", ] diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index e1c87f76c..7c94d2d67 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -21,6 +21,7 @@ use parking_lot::{Mutex, RwLock}; use nac3core::{ codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry}, + codegen::irrt::load_irrt, symbol_resolver::SymbolResolver, toplevel::{composer::{TopLevelComposer, ComposerConfig}, DefinitionId, GenCall, TopLevelDef}, typecheck::typedef::{FunSignature, FuncArg}, @@ -588,6 +589,8 @@ impl Nac3 { main.link_in_module(other) .map_err(|err| exceptions::PyRuntimeError::new_err(err.to_string()))?; } + main.link_in_module(load_irrt(&context)) + .map_err(|err| exceptions::PyRuntimeError::new_err(err.to_string()))?; let mut function_iter = main.get_first_function(); while let Some(func) = function_iter { diff --git a/nac3core/Cargo.toml b/nac3core/Cargo.toml index 64408af83..ab11b7b90 100644 --- a/nac3core/Cargo.toml +++ b/nac3core/Cargo.toml @@ -20,3 +20,6 @@ features = ["llvm13-0", "target-x86", "target-arm", "target-riscv", "no-libffi-l test-case = "1.2.0" indoc = "1.0" insta = "1.5" + +[build-dependencies] +regex = "1" diff --git a/nac3core/build.rs b/nac3core/build.rs new file mode 100644 index 000000000..c4dcfc8ef --- /dev/null +++ b/nac3core/build.rs @@ -0,0 +1,54 @@ +use regex::Regex; +use std::{ + env, + io::Write, + process::{Command, Stdio}, +}; + +fn main() { + let out_dir = env::var("OUT_DIR").unwrap(); + const FILE: &str = "src/codegen/irrt/irrt.c"; + println!("cargo:rerun-if-changed={}", FILE); + const FLAG: &[&str] = &[ + FILE, + "-O3", + "-emit-llvm", + "-S", + "-Wall", + "-Wextra", + "-Wno-implicit-function-declaration", + "-o", + "-", + ]; + let output = Command::new("clang") + .args(FLAG) + .output() + .map(|o| { + assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap()); + o + }) + .unwrap(); + + let output = std::str::from_utf8(&output.stdout).unwrap(); + let mut filtered_output = String::with_capacity(output.len()); + + let regex_filter = regex::Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)").unwrap(); + for f in regex_filter.captures_iter(output) { + assert!(f.len() == 1); + filtered_output.push_str(&f[0]); + filtered_output.push('\n'); + } + + let filtered_output = Regex::new("(#\\d+)|(, *![0-9A-Za-z.]+)|(![0-9A-Za-z.]+)|(!\".*?\")") + .unwrap() + .replace_all(&filtered_output, ""); + + let mut llvm_as = Command::new("llvm-as") + .stdin(Stdio::piped()) + .arg("-o") + .arg(&format!("{}/irrt.bc", out_dir)) + .spawn() + .unwrap(); + llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap(); + assert!(llvm_as.wait().unwrap().success()) +} diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 29eb8dec3..6c8af87db 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -3,7 +3,9 @@ use std::{collections::HashMap, convert::TryInto, iter::once}; use crate::{ codegen::{ concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, - get_llvm_type, CodeGenContext, CodeGenTask, + get_llvm_type, + irrt::integer_power, + CodeGenContext, CodeGenTask, }, symbol_resolver::{SymbolValue, ValueEnum}, toplevel::{DefinitionId, TopLevelDef}, @@ -186,8 +188,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { Operator::LShift => self.builder.build_left_shift(lhs, rhs, "lshift").into(), Operator::RShift => self.builder.build_right_shift(lhs, rhs, true, "rshift").into(), Operator::FloorDiv => self.builder.build_int_signed_div(lhs, rhs, "floordiv").into(), + Operator::Pow => integer_power(self, lhs, rhs).into(), // special implementation? - Operator::Pow => unimplemented!(), Operator::MatMult => unreachable!(), } } @@ -205,6 +207,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } else { unreachable!() }; + let float = self.ctx.f64_type(); match op { Operator::Add => self.builder.build_float_add(lhs, rhs, "fadd").into(), Operator::Sub => self.builder.build_float_sub(lhs, rhs, "fsub").into(), @@ -215,7 +218,6 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let div = self.builder.build_float_div(lhs, rhs, "fdiv"); let floor_intrinsic = self.module.get_function("llvm.floor.f64").unwrap_or_else(|| { - let float = self.ctx.f64_type(); let fn_type = float.fn_type(&[float.into()], false); self.module.add_function("llvm.floor.f64", fn_type, None) }); @@ -225,6 +227,16 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { .left() .unwrap() } + Operator::Pow => { + let pow_intrinsic = self.module.get_function("llvm.pow.f64").unwrap_or_else(|| { + let fn_type = float.fn_type(&[float.into(), float.into()], false); + self.module.add_function("llvm.pow.f64", fn_type, None) + }); + self.builder + .build_call(pow_intrinsic, &[lhs.into(), rhs.into()], "f_pow") + .try_as_basic_value() + .unwrap_left() + } // special implementation? _ => unimplemented!(), } @@ -630,6 +642,47 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>( } } +pub fn gen_binop_expr<'ctx, 'a, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + left: &Expr>, + op: &Operator, + right: &Expr>, +) -> ValueEnum<'ctx> { + let ty1 = ctx.unifier.get_representative(left.custom.unwrap()); + let ty2 = ctx.unifier.get_representative(right.custom.unwrap()); + let left = generator.gen_expr(ctx, left).unwrap().to_basic_value_enum(ctx, generator); + let right = generator.gen_expr(ctx, right).unwrap().to_basic_value_enum(ctx, generator); + + // we can directly compare the types, because we've got their representatives + // which would be unchanged until further unification, which we would never do + // when doing code generation for function instances + if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { + ctx.gen_int_ops(op, left, right) + } else if ty1 == ty2 && ctx.primitives.float == ty1 { + ctx.gen_float_ops(op, left, right) + } else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 { + // Pow is the only operator that would pass typecheck between float and int + assert!(*op == Operator::Pow); + // TODO: throw exception when rhs is out of i16 bound + // since llvm intrinsic only support to i16 for f64 + let i16_t = ctx.ctx.i16_type(); + let pow_intr = ctx.module.get_function("llvm.powi.f64.i16").unwrap_or_else(|| { + let f64_t = ctx.ctx.f64_type(); + let ty = f64_t.fn_type(&[f64_t.into(), i16_t.into()], false); + ctx.module.add_function("llvm.powi.f64.i16", ty, None) + }); + let right = ctx.builder.build_int_truncate(right.into_int_value(), i16_t, "r_pow"); + ctx.builder + .build_call(pow_intr, &[left.into(), right.into()], "f_pow_i") + .try_as_basic_value() + .unwrap_left() + } else { + unimplemented!() + } + .into() +} + pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, @@ -766,24 +819,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( phi.add_incoming(&[(&a, a_bb), (&b, b_bb)]); phi.as_basic_value().into() } - ExprKind::BinOp { op, left, right } => { - let ty1 = ctx.unifier.get_representative(left.custom.unwrap()); - let ty2 = ctx.unifier.get_representative(right.custom.unwrap()); - let left = generator.gen_expr(ctx, left).unwrap().to_basic_value_enum(ctx, generator); - let right = generator.gen_expr(ctx, right).unwrap().to_basic_value_enum(ctx, generator); - - // we can directly compare the types, because we've got their representatives - // which would be unchanged until further unification, which we would never do - // when doing code generation for function instances - if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { - ctx.gen_int_ops(op, left, right) - } else if ty1 == ty2 && ctx.primitives.float == ty1 { - ctx.gen_float_ops(op, left, right) - } else { - unimplemented!() - } - .into() - } + ExprKind::BinOp { op, left, right } => gen_binop_expr(generator, ctx, left, op, right), ExprKind::UnaryOp { op, operand } => { let ty = ctx.unifier.get_representative(operand.custom.unwrap()); let val = generator.gen_expr(ctx, operand).unwrap().to_basic_value_enum(ctx, generator); diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c new file mode 100644 index 000000000..4db916fda --- /dev/null +++ b/nac3core/src/codegen/irrt/irrt.c @@ -0,0 +1,25 @@ +typedef _ExtInt(8) int8_t; +typedef unsigned _ExtInt(8) uint8_t; +typedef _ExtInt(32) int32_t; +typedef unsigned _ExtInt(32) uint32_t; +typedef _ExtInt(64) int64_t; +typedef unsigned _ExtInt(64) uint64_t; + +// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c +// need to make sure `exp >= 0` before calling this function +#define DEF_INT_EXP(T) T __nac3_irrt_int_exp_##T( \ + T base, \ + T exp \ +) { \ + T res = (T)1; \ + /* repeated squaring method */ \ + do { \ + if (exp & 1) res *= base; /* for n odd */ \ + exp >>= 1; \ + base *= base; \ + } while (exp); \ + return res; \ +} \ + +DEF_INT_EXP(int32_t) +DEF_INT_EXP(int64_t) diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs new file mode 100644 index 000000000..af48bd203 --- /dev/null +++ b/nac3core/src/codegen/irrt/mod.rs @@ -0,0 +1,48 @@ +use super::CodeGenContext; +use inkwell::{ + context::Context, + attributes::AttributeLoc, + memory_buffer::MemoryBuffer, + module::Module, + values::IntValue, +}; + +pub fn load_irrt<'ctx>(ctx: &'ctx Context) -> Module<'ctx> { + let bitcode_buf = MemoryBuffer::create_from_memory_range( + include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")), + "irrt_bitcode_buffer", + ); + let irrt_mod = Module::parse_bitcode_from_buffer(&bitcode_buf, ctx).unwrap(); + // add alwaysinline attributes to power function to help them get inlined + // alwaysinline enum = 1, see release/13.x/llvm/include/llvm/IR/Attributes.td + for symbol in &["__nac3_irrt_int_exp_int32_t", "__nac3_irrt_int_exp_int64_t"] { + let function = irrt_mod.get_function(symbol).unwrap(); + function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(1, 0)); + } + return irrt_mod; +} + +// repeated squaring method adapted from GNU Scientific Library: +// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c +pub fn integer_power<'ctx, 'a>( + ctx: &mut CodeGenContext<'ctx, 'a>, + base: IntValue<'ctx>, + exp: IntValue<'ctx>, +) -> IntValue<'ctx> { + let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width()) { + (32, 32) => "__nac3_irrt_int_exp_int32_t", + (64, 64) => "__nac3_irrt_int_exp_int64_t", + _ => unreachable!(), + }; + let base_type = base.get_type(); + let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| { + let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false); + ctx.module.add_function(symbol, fn_type, None) + }); + // TODO: throw exception when exp < 0 + ctx.builder + .build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow") + .try_as_basic_value() + .unwrap_left() + .into_int_value() +} diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 1ff7c3cdb..427b3a0a3 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -31,6 +31,7 @@ pub mod concrete_type; pub mod expr; mod generator; pub mod stmt; +pub mod irrt; #[cfg(test)] mod test; diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index f205e51ac..333ed351c 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -1,7 +1,10 @@ use super::{ super::symbol_resolver::ValueEnum, expr::destructure_range, CodeGenContext, CodeGenerator, }; -use crate::typecheck::typedef::Type; +use crate::{ + codegen::expr::gen_binop_expr, + typecheck::typedef::Type, +}; use inkwell::{ types::BasicTypeEnum, values::{BasicValue, BasicValueEnum, PointerValue}, @@ -417,26 +420,8 @@ pub fn gen_stmt<'ctx, 'a, G: CodeGenerator>( StmtKind::For { .. } => return generator.gen_for(ctx, stmt), StmtKind::With { .. } => return generator.gen_with(ctx, stmt), StmtKind::AugAssign { target, op, value, .. } => { - let value = { - let ty1 = ctx.unifier.get_representative(target.custom.unwrap()); - let ty2 = ctx.unifier.get_representative(value.custom.unwrap()); - let left = - generator.gen_expr(ctx, target).unwrap().to_basic_value_enum(ctx, generator); - let right = - generator.gen_expr(ctx, value).unwrap().to_basic_value_enum(ctx, generator); - - // we can directly compare the types, because we've got their representatives - // which would be unchanged until further unification, which we would never do - // when doing code generation for function instances - if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { - ctx.gen_int_ops(op, left, right) - } else if ty1 == ty2 && ctx.primitives.float == ty1 { - ctx.gen_float_ops(op, left, right) - } else { - unimplemented!() - } - }; - generator.gen_assign(ctx, target, value.into()); + let value = gen_binop_expr(generator, ctx, target, op, value); + generator.gen_assign(ctx, target, value); } _ => unimplemented!(), }; diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 43d0da598..8de79d298 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -183,7 +183,7 @@ pub fn impl_cmpop( } } -/// Add, Sub, Mult, Pow +/// Add, Sub, Mult pub fn impl_basic_arithmetic( unifier: &mut Unifier, store: &PrimitiveStore, @@ -201,6 +201,7 @@ pub fn impl_basic_arithmetic( ) } +/// Pow pub fn impl_pow( unifier: &mut Unifier, store: &PrimitiveStore,