From 82fdb02d13acf0e351f4e6c13c934a8f195479f8 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 22 Feb 2024 01:47:26 +0800 Subject: [PATCH] core: Extract LLVM intrinsic functions to their functions --- nac3artiq/src/codegen.rs | 53 +-- nac3core/src/codegen/classes.rs | 21 +- nac3core/src/codegen/expr.rs | 162 ++----- nac3core/src/codegen/llvm_intrinsics.rs | 562 ++++++++++++++++++++++++ nac3core/src/codegen/mod.rs | 1 + nac3core/src/toplevel/builtins.rs | 346 +++++---------- nac3core/src/toplevel/numpy.rs | 44 +- 7 files changed, 720 insertions(+), 469 deletions(-) create mode 100644 nac3core/src/codegen/llvm_intrinsics.rs diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index f9da2e9b..17d2593d 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -1,6 +1,7 @@ use nac3core::{ codegen::{ expr::gen_call, + llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave}, stmt::{gen_block, gen_with}, CodeGenContext, CodeGenerator, }, @@ -15,7 +16,7 @@ use inkwell::{ context::Context, module::Linkage, types::IntType, - values::{BasicValueEnum, CallSiteValue}, + values::BasicValueEnum, AddressSpace, }; @@ -29,7 +30,6 @@ use std::{ hash::{Hash, Hasher}, sync::Arc, }; -use itertools::Either; /// The parallelism mode within a block. #[derive(Copy, Clone, Eq, PartialEq)] @@ -133,20 +133,12 @@ impl<'a> ArtiqCodeGenerator<'a> { .unwrap() .to_basic_value_enum(ctx, self, end.custom.unwrap())?; let now = self.timeline.emit_now_mu(ctx); - let smax = ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| { - let i64 = ctx.ctx.i64_type(); - ctx.module.add_function( - "llvm.smax.i64", - i64.fn_type(&[i64.into(), i64.into()], false), - None, - ) - }); - let max = ctx - .builder - .build_call(smax, &[old_end.into(), now.into()], "smax") - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); + let max = call_int_smax( + ctx, + old_end.into_int_value(), + now.into_int_value(), + Some("smax") + ); let end_store = self.gen_store_target( ctx, &end, @@ -471,18 +463,7 @@ fn rpc_codegen_callback_fn<'ctx>( let arg_length = args.len() + usize::from(obj.is_some()); - let stacksave = ctx.module.get_function("llvm.stacksave").unwrap_or_else(|| { - ctx.module.add_function("llvm.stacksave", ptr_type.fn_type(&[], false), None) - }); - let stackrestore = ctx.module.get_function("llvm.stackrestore").unwrap_or_else(|| { - ctx.module.add_function( - "llvm.stackrestore", - ctx.ctx.void_type().fn_type(&[ptr_type.into()], false), - None, - ) - }); - - let stackptr = ctx.builder.build_call(stacksave, &[], "rpc.stack").unwrap(); + let stackptr = call_stacksave(ctx, Some("rpc.stack")); let args_ptr = ctx.builder .build_array_alloca( ptr_type, @@ -558,13 +539,7 @@ fn rpc_codegen_callback_fn<'ctx>( .unwrap(); // reclaim stack space used by arguments - ctx.builder - .build_call( - stackrestore, - &[stackptr.try_as_basic_value().unwrap_left().into()], - "rpc.stackrestore", - ) - .unwrap(); + call_stackrestore(ctx, stackptr); // -- receive value: // T result = { @@ -624,13 +599,7 @@ fn rpc_codegen_callback_fn<'ctx>( let result = ctx.builder.build_load(slot, "rpc.result").unwrap(); if need_load { - ctx.builder - .build_call( - stackrestore, - &[stackptr.try_as_basic_value().unwrap_left().into()], - "rpc.stackrestore", - ) - .unwrap(); + call_stackrestore(ctx, stackptr); } Ok(Some(result)) } diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 4394ce02..34463ef4 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -1,13 +1,13 @@ use inkwell::{ IntPredicate, types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType}, - values::{ArrayValue, BasicValueEnum, CallSiteValue, IntValue, PointerValue}, + values::{ArrayValue, BasicValueEnum, IntValue, PointerValue}, }; -use itertools::Either; use crate::codegen::{ CodeGenContext, CodeGenerator, irrt::{call_ndarray_calc_size, call_ndarray_flatten_index, call_ndarray_flatten_index_const}, + llvm_intrinsics::call_int_umin, stmt::gen_for_callback, }; @@ -924,22 +924,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> { let indices_len = indices.load_size(ctx, None); let ndarray_len = self.0.load_ndims(ctx); - let min_fn_name = format!("llvm.umin.i{}", llvm_usize.get_bit_width()); - let min_fn = ctx.module.get_function(min_fn_name.as_str()).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_usize.into(), llvm_usize.into()], - false - ); - ctx.module.add_function(min_fn_name.as_str(), fn_type, None) - }); - - let len = ctx - .builder - .build_call(min_fn, &[indices_len.into(), ndarray_len.into()], "") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); + let len = call_int_umin(ctx, indices_len, ndarray_len, None); let i = ctx.builder.build_load(i_addr, "") .map(BasicValueEnum::into_int_value) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 9f07247e..e0df6f0b 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -8,6 +8,7 @@ use crate::{ get_llvm_type, get_llvm_abi_type, irrt::*, + llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi}, stmt::{gen_raise, gen_var}, CodeGenContext, CodeGenTask, }, @@ -30,7 +31,7 @@ use nac3parser::ast::{ self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, }; -use super::{CodeGenerator, need_sret}; +use super::{CodeGenerator, llvm_intrinsics::call_memcpy_generic, need_sret}; pub fn get_subst_key( unifier: &mut Unifier, @@ -371,7 +372,6 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) = (lhs, rhs) else { unreachable!() }; - let float = self.ctx.f64_type(); match op { Operator::Add => self.builder.build_float_add(lhs, rhs, "fadd").map(Into::into).unwrap(), Operator::Sub => self.builder.build_float_sub(lhs, rhs, "fsub").map(Into::into).unwrap(), @@ -380,28 +380,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { Operator::Mod => self.builder.build_float_rem(lhs, rhs, "fmod").map(Into::into).unwrap(), Operator::FloorDiv => { let div = self.builder.build_float_div(lhs, rhs, "fdiv").unwrap(); - let floor_intrinsic = - self.module.get_function("llvm.floor.f64").unwrap_or_else(|| { - let fn_type = float.fn_type(&[float.into()], false); - self.module.add_function("llvm.floor.f64", fn_type, None) - }); - self.builder - .build_call(floor_intrinsic, &[div.into()], "floor") - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap() - } - Operator::Pow => { - let pow_intrinsic = self.module.get_function("llvm.pow.f64").unwrap_or_else(|| { - let fn_type = float.fn_type(&[float.into(), float.into()], false); - self.module.add_function("llvm.pow.f64", fn_type, None) - }); - self.builder - .build_call(pow_intrinsic, &[lhs.into(), rhs.into()], "f_pow") - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap() + call_float_floor(self, div, Some("floor")).into() } + Operator::Pow => call_float_pow(self, lhs, rhs, Some("f_pow")).into(), // special implementation? _ => unimplemented!(), } @@ -585,24 +566,11 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { ) { let i1 = self.ctx.bool_type(); let i1_true = i1.const_all_ones(); - let expect_fun = self.module.get_function("llvm.expect.i1").unwrap_or_else(|| { - self.module.add_function( - "llvm.expect.i1", - i1.fn_type(&[i1.into(), i1.into()], false), - None, - ) - }); // we assume that the condition is most probably true, so the normal path is the most // probable path // even if this assumption is violated, it does not matter as exception unwinding is // slow anyway... - let cond = self - .builder - .build_call(expect_fun, &[cond.into(), i1_true.into()], "expect") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); + let cond = call_expect(self, cond, i1_true, Some("expect")); let current_fun = self.builder.get_insert_block().unwrap().get_parent().unwrap(); let then_block = self.ctx.append_basic_block(current_fun, "succ"); let exn_block = self.ctx.append_basic_block(current_fun, "fail"); @@ -1150,17 +1118,12 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>( } else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 { // Pow is the only operator that would pass typecheck between float and int assert_eq!(*op, Operator::Pow); - let i32_t = ctx.ctx.i32_type(); - let pow_intr = ctx.module.get_function("llvm.powi.f64.i32").unwrap_or_else(|| { - let f64_t = ctx.ctx.f64_type(); - let ty = f64_t.fn_type(&[f64_t.into(), i32_t.into()], false); - ctx.module.add_function("llvm.powi.f64.i32", ty, None) - }); - let res = ctx.builder - .build_call(pow_intr, &[left_val.into(), right_val.into()], "f_pow_i") - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); + let res = call_float_powi( + ctx, + left_val.into_float_value(), + right_val.into_int_value(), + Some("f_pow_i") + ); Ok(Some(res.into())) } else { let left_ty_enum = ctx.unifier.get_ty_immutable(left.custom.unwrap()); @@ -1229,11 +1192,8 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( v: NDArrayValue<'ctx>, slice: &Expr>, ) -> Result>, String> { - let llvm_void = ctx.ctx.void_type(); let llvm_i1 = ctx.ctx.bool_type(); - let llvm_i8 = ctx.ctx.i8_type(); let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else { unreachable!() @@ -1333,24 +1293,6 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( let ndarray_num_dims = ndarray.load_ndims(ctx); ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims); - let memcpy_fn_name = format!( - "llvm.memcpy.p0i8.p0i8.i{}", - generator.get_size_type(ctx.ctx).get_bit_width(), - ); - let memcpy_fn = ctx.module.get_function(memcpy_fn_name.as_str()).unwrap_or_else(|| { - let fn_type = llvm_void.fn_type( - &[ - llvm_pi8.into(), - llvm_pi8.into(), - llvm_usize.into(), - llvm_i1.into(), - ], - false, - ); - - ctx.module.add_function(memcpy_fn_name.as_str(), fn_type, None) - }); - let ndarray_num_dims = ndarray.load_ndims(ctx); let v_dims_src_ptr = v.get_dims().ptr_offset( ctx, @@ -1358,37 +1300,16 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( llvm_usize.const_int(1, false), None, ); - ctx.builder.build_call( - memcpy_fn, - &[ - ctx.builder - .build_bitcast( - ndarray.get_dims().get_ptr(ctx), - llvm_pi8, - "", - ) - .map(Into::into) - .unwrap(), - ctx.builder - .build_bitcast( - v_dims_src_ptr, - llvm_pi8, - "", - ) - .map(Into::into) - .unwrap(), - ctx.builder - .build_int_mul( - ndarray_num_dims, - llvm_usize.size_of(), - "", - ) - .map(Into::into) - .unwrap(), - llvm_i1.const_zero().into(), - ], - "", - ).unwrap(); + call_memcpy_generic( + ctx, + ndarray.get_dims().get_ptr(ctx), + v_dims_src_ptr, + ctx.builder + .build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "") + .map(Into::into) + .unwrap(), + llvm_i1.const_zero(), + ); let ndarray_num_elems = call_ndarray_calc_size( generator, @@ -1404,37 +1325,16 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( ctx.ctx.i32_type().const_array(&[index]), None ); - ctx.builder.build_call( - memcpy_fn, - &[ - ctx.builder - .build_bitcast( - ndarray.get_data().get_ptr(ctx), - llvm_pi8, - "", - ) - .map(Into::into) - .unwrap(), - ctx.builder - .build_bitcast( - v_data_src_ptr, - llvm_pi8, - "", - ) - .map(Into::into) - .unwrap(), - ctx.builder - .build_int_mul( - ndarray_num_elems, - llvm_ndarray_data_t.size_of().unwrap(), - "", - ) - .map(Into::into) - .unwrap(), - llvm_i1.const_zero().into(), - ], - "", - ).unwrap(); + call_memcpy_generic( + ctx, + ndarray.get_data().get_ptr(ctx), + v_data_src_ptr, + ctx.builder + .build_int_mul(ndarray_num_elems, llvm_ndarray_data_t.size_of().unwrap(), "") + .map(Into::into) + .unwrap(), + llvm_i1.const_zero(), + ); Ok(Some(v.get_ptr().into())) } diff --git a/nac3core/src/codegen/llvm_intrinsics.rs b/nac3core/src/codegen/llvm_intrinsics.rs new file mode 100644 index 00000000..3be6c7ef --- /dev/null +++ b/nac3core/src/codegen/llvm_intrinsics.rs @@ -0,0 +1,562 @@ +use inkwell::AddressSpace; +use inkwell::context::Context; +use inkwell::types::AnyTypeEnum::IntType; +use inkwell::types::FloatType; +use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue}; +use itertools::Either; +use crate::codegen::CodeGenContext; + +/// Returns the string representation for the floating-point type `ft` when used in intrinsic +/// functions. +fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str { + // Standard LLVM floating-point types + if ft == ctx.f16_type() { + return "f16" + } + if ft == ctx.f32_type() { + return "f32" + } + if ft == ctx.f64_type() { + return "f64" + } + if ft == ctx.f128_type() { + return "f128" + } + + // Non-standard floating-point types + if ft == ctx.x86_f80_type() { + return "f80" + } + if ft == ctx.ppc_f128_type() { + return "ppcf128" + } + + unreachable!() +} + +/// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic) +/// intrinsic. +pub fn call_stacksave<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + name: Option<&str>, +) -> PointerValue<'ctx> { + const FN_NAME: &str = "llvm.stacksave"; + + let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { + let llvm_i8 = ctx.ctx.i8_type(); + let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default()); + + let fn_type = llvm_p0i8.fn_type(&[], false); + + ctx.module.add_function(FN_NAME, fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_pointer_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Invokes the +/// [`llvm.stackrestore`](https://llvm.org/docs/LangRef.html#llvm-stackrestore-intrinsic) intrinsic. +pub fn call_stackrestore<'ctx>(ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>) { + const FN_NAME: &str = "llvm.stackrestore"; + + let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { + let llvm_void = ctx.ctx.void_type(); + let llvm_i8 = ctx.ctx.i8_type(); + let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default()); + + let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false); + + ctx.module.add_function(FN_NAME, fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[ptr.into()], "") + .unwrap(); +} + +/// Invokes the [`llvm.abs`](https://llvm.org/docs/LangRef.html#llvm-abs-intrinsic) intrinsic. +/// +/// * `src` - The value for which the absolute value is to be returned. +/// * `is_int_min_poison` - Whether `poison` is to be returned if `src` is `INT_MIN`. +pub fn call_int_abs<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + src: IntValue<'ctx>, + is_int_min_poison: IntValue<'ctx>, + name: Option<&str>, +) -> IntValue<'ctx> { + debug_assert_eq!(is_int_min_poison.get_type().get_bit_width(), 1); + debug_assert!(is_int_min_poison.is_const()); + + let llvm_src_t = src.get_type(); + + let fn_name = format!("llvm.abs.i{}", llvm_src_t.get_bit_width()); + + let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| { + let llvm_i1 = ctx.ctx.bool_type(); + + let fn_type = llvm_src_t.fn_type(&[llvm_src_t.into(), llvm_i1.into()], false); + + ctx.module.add_function(fn_name.as_str(), fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[src.into(), is_int_min_poison.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Invokes the [`llvm.smax`](https://llvm.org/docs/LangRef.html#llvm-smax-intrinsic) intrinsic. +pub fn call_int_smax<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + a: IntValue<'ctx>, + b: IntValue<'ctx>, + name: Option<&str>, +) -> IntValue<'ctx> { + debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width()); + + let llvm_int_t = a.get_type(); + + let fn_name = format!("llvm.smax.i{}", llvm_int_t.get_bit_width()); + + let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_int_t.fn_type(&[llvm_int_t.into(), llvm_int_t.into()], false); + + ctx.module.add_function(fn_name.as_str(), fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Invokes the [`llvm.smin`](https://llvm.org/docs/LangRef.html#llvm-smin-intrinsic) intrinsic. +pub fn call_int_smin<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + a: IntValue<'ctx>, + b: IntValue<'ctx>, + name: Option<&str>, +) -> IntValue<'ctx> { + debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width()); + + let llvm_int_t = a.get_type(); + + let fn_name = format!("llvm.smin.i{}", llvm_int_t.get_bit_width()); + + let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_int_t.fn_type(&[llvm_int_t.into(), llvm_int_t.into()], false); + + ctx.module.add_function(fn_name.as_str(), fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Invokes the [`llvm.umax`](https://llvm.org/docs/LangRef.html#llvm-umax-intrinsic) intrinsic. +pub fn call_int_umax<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + a: IntValue<'ctx>, + b: IntValue<'ctx>, + name: Option<&str>, +) -> IntValue<'ctx> { + debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width()); + + let llvm_int_t = a.get_type(); + + let fn_name = format!("llvm.umax.i{}", llvm_int_t.get_bit_width()); + + let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_int_t.fn_type(&[llvm_int_t.into(), llvm_int_t.into()], false); + + ctx.module.add_function(fn_name.as_str(), fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Invokes the [`llvm.umin`](https://llvm.org/docs/LangRef.html#llvm-umin-intrinsic) intrinsic. +pub fn call_int_umin<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + a: IntValue<'ctx>, + b: IntValue<'ctx>, + name: Option<&str>, +) -> IntValue<'ctx> { + debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width()); + + let llvm_int_t = a.get_type(); + + let fn_name = format!("llvm.umin.i{}", llvm_int_t.get_bit_width()); + + let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_int_t.fn_type(&[llvm_int_t.into(), llvm_int_t.into()], false); + + ctx.module.add_function(fn_name.as_str(), fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Invokes the [`llvm.memcpy`](https://llvm.org/docs/LangRef.html#llvm-memcpy-intrinsic) intrinsic. +/// +/// * `dest` - The pointer to the destination. Must be a pointer to an integer type. +/// * `src` - The pointer to the source. Must be a pointer to an integer type. +/// * `len` - The number of bytes to copy. +/// * `is_volatile` - Whether the `memcpy` operation should be `volatile`. +pub fn call_memcpy<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + dest: PointerValue<'ctx>, + src: PointerValue<'ctx>, + len: IntValue<'ctx>, + is_volatile: IntValue<'ctx>, +) { + debug_assert!(dest.get_type().get_element_type().is_int_type()); + debug_assert!(src.get_type().get_element_type().is_int_type()); + debug_assert_eq!( + dest.get_type().get_element_type().into_int_type().get_bit_width(), + src.get_type().get_element_type().into_int_type().get_bit_width(), + ); + debug_assert!(matches!(len.get_type().get_bit_width(), 32 | 64)); + debug_assert_eq!(is_volatile.get_type().get_bit_width(), 1); + + let llvm_dest_t = dest.get_type(); + let llvm_src_t = src.get_type(); + let llvm_len_t = len.get_type(); + + let fn_name = format!( + "llvm.memcpy.p0i{}.p0i{}.i{}", + llvm_dest_t.get_element_type().into_int_type().get_bit_width(), + llvm_src_t.get_element_type().into_int_type().get_bit_width(), + llvm_len_t.get_bit_width(), + ); + + let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| { + let llvm_void = ctx.ctx.void_type(); + + let fn_type = llvm_void.fn_type( + &[ + llvm_dest_t.into(), + llvm_src_t.into(), + llvm_len_t.into(), + is_volatile.get_type().into(), + ], + false, + ); + + ctx.module.add_function(fn_name.as_str(), fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[dest.into(), src.into(), len.into(), is_volatile.into()], "") + .unwrap(); +} + +/// Invokes the `llvm.memcpy` intrinsic. +/// +/// Unlike [`call_memcpy`], this function accepts any type of pointer value. If `dest` or `src` is +/// not a pointer to an integer, the pointer(s) will be cast to `i8*` before invoking `memcpy`. +pub fn call_memcpy_generic<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + dest: PointerValue<'ctx>, + src: PointerValue<'ctx>, + len: IntValue<'ctx>, + is_volatile: IntValue<'ctx>, +) { + let llvm_i8 = ctx.ctx.i8_type(); + let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default()); + + let dest_elem_t = dest.get_type().get_element_type(); + let src_elem_t = src.get_type().get_element_type(); + + let dest = if matches!(dest_elem_t, IntType(t) if t.get_bit_width() == 8) { + dest + } else { + ctx.builder + .build_bitcast(dest, llvm_p0i8, "") + .map(BasicValueEnum::into_pointer_value) + .unwrap() + }; + let src = if matches!(src_elem_t, IntType(t) if t.get_bit_width() == 8) { + src + } else { + ctx.builder + .build_bitcast(src, llvm_p0i8, "") + .map(BasicValueEnum::into_pointer_value) + .unwrap() + }; + + call_memcpy(ctx, dest, src, len, is_volatile); +} + +/// Invokes the [`llvm.powi`](https://llvm.org/docs/LangRef.html#llvm-powi-intrinsic) intrinsic. +pub fn call_float_powi<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + val: FloatValue<'ctx>, + power: IntValue<'ctx>, + name: Option<&str>, +) -> FloatValue<'ctx> { + let llvm_val_t = val.get_type(); + let llvm_power_t = power.get_type(); + + let fn_name = format!( + "llvm.powi.{}.i{}", + get_float_intrinsic_repr(ctx.ctx, llvm_val_t), + llvm_power_t.get_bit_width(), + ); + let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_val_t.fn_type(&[llvm_val_t.into(), llvm_power_t.into()], false); + + ctx.module.add_function(fn_name.as_str(), fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[val.into(), power.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Invokes the [`llvm.pow`](https://llvm.org/docs/LangRef.html#llvm-pow-intrinsic) intrinsic. +pub fn call_float_pow<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + val: FloatValue<'ctx>, + power: FloatValue<'ctx>, + name: Option<&str>, +) -> FloatValue<'ctx> { + debug_assert_eq!(val.get_type(), power.get_type()); + + let llvm_float_t = val.get_type(); + + let fn_name = format!("llvm.pow.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t)); + let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into(), llvm_float_t.into()], false); + + ctx.module.add_function(fn_name.as_str(), fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[val.into(), power.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Invokes the [`llvm.fabs`](https://llvm.org/docs/LangRef.html#llvm-fabs-intrinsic) intrinsic. +pub fn call_float_fabs<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + src: FloatValue<'ctx>, + name: Option<&str>, +) -> FloatValue<'ctx> { + let llvm_src_t = src.get_type(); + + let fn_name = format!("llvm.fabs.{}", get_float_intrinsic_repr(ctx.ctx, llvm_src_t)); + + let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_src_t.fn_type(&[llvm_src_t.into()], false); + + ctx.module.add_function(fn_name.as_str(), fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[src.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Invokes the [`llvm.minnum`](https://llvm.org/docs/LangRef.html#llvm-minnum-intrinsic) intrinsic. +pub fn call_float_minnum<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + val1: FloatValue<'ctx>, + val2: FloatValue<'ctx>, + name: Option<&str>, +) -> FloatValue<'ctx> { + debug_assert_eq!(val1.get_type(), val2.get_type()); + + let llvm_float_t = val1.get_type(); + + let fn_name = format!("llvm.minnum.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t)); + let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into(), llvm_float_t.into()], false); + + ctx.module.add_function(fn_name.as_str(), fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[val1.into(), val2.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Invokes the [`llvm.maxnum`](https://llvm.org/docs/LangRef.html#llvm-maxnum-intrinsic) intrinsic. +pub fn call_float_maxnum<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + val1: FloatValue<'ctx>, + val2: FloatValue<'ctx>, + name: Option<&str>, +) -> FloatValue<'ctx> { + debug_assert_eq!(val1.get_type(), val2.get_type()); + + let llvm_float_t = val1.get_type(); + + let fn_name = format!("llvm.maxnum.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t)); + let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into(), llvm_float_t.into()], false); + + ctx.module.add_function(fn_name.as_str(), fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[val1.into(), val2.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Invokes the [`llvm.floor`](https://llvm.org/docs/LangRef.html#llvm-floor-intrinsic) intrinsic. +pub fn call_float_floor<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + val: FloatValue<'ctx>, + name: Option<&str>, +) -> FloatValue<'ctx> { + let llvm_float_t = val.get_type(); + + let fn_name = format!("llvm.floor.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t)); + let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false); + + ctx.module.add_function(fn_name.as_str(), fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Invokes the [`llvm.ceil`](https://llvm.org/docs/LangRef.html#llvm-ceil-intrinsic) intrinsic. +pub fn call_float_ceil<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + val: FloatValue<'ctx>, + name: Option<&str>, +) -> FloatValue<'ctx> { + let llvm_float_t = val.get_type(); + + let fn_name = format!("llvm.ceil.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t)); + let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false); + + ctx.module.add_function(fn_name.as_str(), fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Invokes the [`llvm.round`](https://llvm.org/docs/LangRef.html#llvm-round-intrinsic) intrinsic. +pub fn call_float_round<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + val: FloatValue<'ctx>, + name: Option<&str>, +) -> FloatValue<'ctx> { + let llvm_float_t = val.get_type(); + + let fn_name = format!("llvm.round.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t)); + let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false); + + ctx.module.add_function(fn_name.as_str(), fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Invokes the +/// [`llvm.roundeven`](https://llvm.org/docs/LangRef.html#llvm-roundeven-intrinsic) intrinsic. +pub fn call_float_roundeven<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + val: FloatValue<'ctx>, + name: Option<&str>, +) -> FloatValue<'ctx> { + let llvm_float_t = val.get_type(); + + let fn_name = format!("llvm.roundeven.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t)); + let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false); + + ctx.module.add_function(fn_name.as_str(), fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Invokes the [`llvm.expect`](https://llvm.org/docs/LangRef.html#llvm-expect-intrinsic) intrinsic. +pub fn call_expect<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + val: IntValue<'ctx>, + expected_val: IntValue<'ctx>, + name: Option<&str>, +) -> IntValue<'ctx> { + debug_assert_eq!(val.get_type().get_bit_width(), expected_val.get_type().get_bit_width()); + + let llvm_int_t = val.get_type(); + + let fn_name = format!("llvm.expect.i{}", llvm_int_t.get_bit_width()); + let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_int_t.fn_type(&[llvm_int_t.into(), llvm_int_t.into()], false); + + ctx.module.add_function(fn_name.as_str(), fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[val.into(), expected_val.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap() +} diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index d69ac2b3..bcc1f4be 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -39,6 +39,7 @@ pub mod concrete_type; pub mod expr; mod generator; pub mod irrt; +pub mod llvm_intrinsics; pub mod stmt; #[cfg(test)] diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index c060ddc2..881ea5f0 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -3,25 +3,12 @@ use crate::{ codegen::{ classes::RangeValue, expr::destructure_range, - irrt::{ - calculate_len_for_slice_range, - call_gamma, - call_gammaln, - call_isinf, - call_isnan, - call_j0, - }, + irrt::*, + llvm_intrinsics::*, stmt::exn_constructor, }, symbol_resolver::SymbolValue, - toplevel::numpy::{ - gen_ndarray_empty, - gen_ndarray_eye, - gen_ndarray_full, - gen_ndarray_identity, - gen_ndarray_ones, - gen_ndarray_zeros, - }, + toplevel::numpy::*, }; use inkwell::{ attributes::{Attribute, AttributeLoc}, @@ -1010,26 +997,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { int32, &[(float, "n")], Box::new(|ctx, _, _, args, generator| { - let llvm_f64 = ctx.ctx.f64_type(); let llvm_i32 = ctx.ctx.i32_type(); let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + .to_basic_value_enum(ctx, generator, ctx.primitives.float)? + .into_float_value(); - let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - - ctx.module.add_function("llvm.round.f64", fn_type, None) - }); - - let val = ctx - .builder - .build_call(intrinsic_fn, &[arg.into()], "") - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); + let val = call_float_round(ctx, arg, None); let val_toint = ctx.builder - .build_float_to_signed_int(val.into_float_value(), llvm_i32, "round") + .build_float_to_signed_int(val, llvm_i32, "round") .unwrap(); Ok(Some(val_toint.into())) }), @@ -1041,26 +1017,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { int64, &[(float, "n")], Box::new(|ctx, _, _, args, generator| { - let llvm_f64 = ctx.ctx.f64_type(); let llvm_i64 = ctx.ctx.i64_type(); let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + .to_basic_value_enum(ctx, generator, ctx.primitives.float)? + .into_float_value(); - let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - - ctx.module.add_function("llvm.round.f64", fn_type, None) - }); - - let val = ctx - .builder - .build_call(intrinsic_fn, &[arg.into()], "") - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); + let val = call_float_round(ctx, arg, None); let val_toint = ctx.builder - .build_float_to_signed_int(val.into_float_value(), llvm_i64, "round") + .build_float_to_signed_int(val, llvm_i64, "round") .unwrap(); Ok(Some(val_toint.into())) }), @@ -1072,24 +1037,13 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { float, &[(float, "n")], Box::new(|ctx, _, _, args, generator| { - let llvm_f64 = ctx.ctx.f64_type(); - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + .to_basic_value_enum(ctx, generator, ctx.primitives.float)? + .into_float_value(); - let intrinsic_fn = ctx.module.get_function("llvm.roundeven.f64").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + let val = call_float_roundeven(ctx, arg, None); - ctx.module.add_function("llvm.roundeven.f64", fn_type, None) - }); - - let val = ctx - .builder - .build_call(intrinsic_fn, &[arg.into()], "") - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); - Ok(Some(val)) + Ok(Some(val.into())) }), ), Arc::new(RwLock::new(TopLevelDef::Function { @@ -1290,26 +1244,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { int32, &[(float, "n")], Box::new(|ctx, _, _, args, generator| { - let llvm_f64 = ctx.ctx.f64_type(); let llvm_i32 = ctx.ctx.i32_type(); let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + .to_basic_value_enum(ctx, generator, ctx.primitives.float)? + .into_float_value(); - let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - - ctx.module.add_function("llvm.floor.f64", fn_type, None) - }); - - let val = ctx - .builder - .build_call(intrinsic_fn, &[arg.into()], "") - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); + let val = call_float_floor(ctx, arg, None); let val_toint = ctx.builder - .build_float_to_signed_int(val.into_float_value(), llvm_i32, "floor") + .build_float_to_signed_int(val, llvm_i32, "floor") .unwrap(); Ok(Some(val_toint.into())) }), @@ -1321,26 +1264,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { int64, &[(float, "n")], Box::new(|ctx, _, _, args, generator| { - let llvm_f64 = ctx.ctx.f64_type(); let llvm_i64 = ctx.ctx.i64_type(); let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + .to_basic_value_enum(ctx, generator, ctx.primitives.float)? + .into_float_value(); - let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - - ctx.module.add_function("llvm.floor.f64", fn_type, None) - }); - - let val = ctx - .builder - .build_call(intrinsic_fn, &[arg.into()], "") - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); + let val = call_float_floor(ctx, arg, None); let val_toint = ctx.builder - .build_float_to_signed_int(val.into_float_value(), llvm_i64, "floor") + .build_float_to_signed_int(val, llvm_i64, "floor") .unwrap(); Ok(Some(val_toint.into())) }), @@ -1352,24 +1284,12 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { float, &[(float, "n")], Box::new(|ctx, _, _, args, generator| { - let llvm_f64 = ctx.ctx.f64_type(); - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + .to_basic_value_enum(ctx, generator, ctx.primitives.float)? + .into_float_value(); - let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - - ctx.module.add_function("llvm.floor.f64", fn_type, None) - }); - - let val = ctx - .builder - .build_call(intrinsic_fn, &[arg.into()], "") - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); - Ok(Some(val)) + let val = call_float_floor(ctx, arg, None); + Ok(Some(val.into())) }), ), create_fn_by_codegen( @@ -1379,26 +1299,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { int32, &[(float, "n")], Box::new(|ctx, _, _, args, generator| { - let llvm_f64 = ctx.ctx.f64_type(); let llvm_i32 = ctx.ctx.i32_type(); let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + .to_basic_value_enum(ctx, generator, ctx.primitives.float)? + .into_float_value(); - let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - - ctx.module.add_function("llvm.ceil.f64", fn_type, None) - }); - - let val = ctx - .builder - .build_call(intrinsic_fn, &[arg.into()], "") - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); + let val = call_float_ceil(ctx, arg, None); let val_toint = ctx.builder - .build_float_to_signed_int(val.into_float_value(), llvm_i32, "ceil") + .build_float_to_signed_int(val, llvm_i32, "ceil") .unwrap(); Ok(Some(val_toint.into())) }), @@ -1410,26 +1319,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { int64, &[(float, "n")], Box::new(|ctx, _, _, args, generator| { - let llvm_f64 = ctx.ctx.f64_type(); let llvm_i64 = ctx.ctx.i64_type(); let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + .to_basic_value_enum(ctx, generator, ctx.primitives.float)? + .into_float_value(); - let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - - ctx.module.add_function("llvm.ceil.f64", fn_type, None) - }); - - let val = ctx - .builder - .build_call(intrinsic_fn, &[arg.into()], "") - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); + let val = call_float_ceil(ctx, arg, None); let val_toint = ctx.builder - .build_float_to_signed_int(val.into_float_value(), llvm_i64, "ceil") + .build_float_to_signed_int(val, llvm_i64, "ceil") .unwrap(); Ok(Some(val_toint.into())) }), @@ -1441,24 +1339,12 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { float, &[(float, "n")], Box::new(|ctx, _, _, args, generator| { - let llvm_f64 = ctx.ctx.f64_type(); - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + .to_basic_value_enum(ctx, generator, ctx.primitives.float)? + .into_float_value(); - let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - - ctx.module.add_function("llvm.ceil.f64", fn_type, None) - }); - - let val = ctx - .builder - .build_call(intrinsic_fn, &[arg.into()], "") - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); - Ok(Some(val)) + let val = call_float_ceil(ctx, arg, None); + Ok(Some(val.into())) }), ), Arc::new(RwLock::new({ @@ -1568,40 +1454,38 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let uint32 = ctx.primitives.uint32; let uint64 = ctx.primitives.uint64; let float = ctx.primitives.float; - let llvm_i8 = ctx.ctx.i8_type().as_basic_type_enum(); - let llvm_i32 = ctx.ctx.i32_type().as_basic_type_enum(); - let llvm_i64 = ctx.ctx.i64_type().as_basic_type_enum(); - let llvm_f64 = ctx.ctx.f64_type().as_basic_type_enum(); let m_ty = fun.0.args[0].ty; let n_ty = fun.0.args[1].ty; let m_val = args[0].1.clone().to_basic_value_enum(ctx, generator, m_ty)?; let n_val = args[1].1.clone().to_basic_value_enum(ctx, generator, n_ty)?; let mut is_type = |a: Type, b: Type| ctx.unifier.unioned(a, b); - let (fun_name, arg_ty) = if is_type(m_ty, n_ty) && is_type(n_ty, boolean) { - ("llvm.umin.i8", llvm_i8) - } else if is_type(m_ty, n_ty) && is_type(n_ty, int32) { - ("llvm.smin.i32", llvm_i32) - } else if is_type(m_ty, n_ty) && is_type(n_ty, int64) { - ("llvm.smin.i64", llvm_i64) - } else if is_type(m_ty, n_ty) && is_type(n_ty, uint32) { - ("llvm.umin.i32", llvm_i32) - } else if is_type(m_ty, n_ty) && is_type(n_ty, uint64) { - ("llvm.umin.i64", llvm_i64) + if !is_type(m_ty, n_ty) { + unreachable!() + } + let val: BasicValueEnum = if [boolean, uint32, uint64].iter().any(|t| is_type(n_ty, *t)) { + call_int_umin( + ctx, + m_val.into_int_value(), + n_val.into_int_value(), + Some("min"), + ).into() + } else if [int32, int64].iter().any(|t| is_type(n_ty, *t)) { + call_int_smin( + ctx, + m_val.into_int_value(), + n_val.into_int_value(), + Some("min"), + ).into() } else if is_type(m_ty, n_ty) && is_type(n_ty, float) { - ("llvm.minnum.f64", llvm_f64) + call_float_minnum( + ctx, + m_val.into_float_value(), + n_val.into_float_value(), + Some("min"), + ).into() } else { unreachable!() }; - let intrinsic = ctx.module.get_function(fun_name).unwrap_or_else(|| { - let fn_type = arg_ty.fn_type(&[arg_ty.into(), arg_ty.into()], false); - ctx.module.add_function(fun_name, fn_type, None) - }); - let val = ctx - .builder - .build_call(intrinsic, &[m_val.into(), n_val.into()], "min") - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); Ok(val.into()) }, )))), @@ -1630,40 +1514,38 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let uint32 = ctx.primitives.uint32; let uint64 = ctx.primitives.uint64; let float = ctx.primitives.float; - let llvm_i8 = ctx.ctx.i8_type().as_basic_type_enum(); - let llvm_i32 = ctx.ctx.i32_type().as_basic_type_enum(); - let llvm_i64 = ctx.ctx.i64_type().as_basic_type_enum(); - let llvm_f64 = ctx.ctx.f64_type().as_basic_type_enum(); let m_ty = fun.0.args[0].ty; let n_ty = fun.0.args[1].ty; let m_val = args[0].1.clone().to_basic_value_enum(ctx, generator, m_ty)?; let n_val = args[1].1.clone().to_basic_value_enum(ctx, generator, n_ty)?; let mut is_type = |a: Type, b: Type| ctx.unifier.unioned(a, b); - let (fun_name, arg_ty) = if is_type(m_ty, n_ty) && is_type(n_ty, boolean) { - ("llvm.umax.i8", llvm_i8) - } else if is_type(m_ty, n_ty) && is_type(n_ty, int32) { - ("llvm.smax.i32", llvm_i32) - } else if is_type(m_ty, n_ty) && is_type(n_ty, int64) { - ("llvm.smax.i64", llvm_i64) - } else if is_type(m_ty, n_ty) && is_type(n_ty, uint32) { - ("llvm.umax.i32", llvm_i32) - } else if is_type(m_ty, n_ty) && is_type(n_ty, uint64) { - ("llvm.umax.i64", llvm_i64) + if !is_type(m_ty, n_ty) { + unreachable!() + } + let val: BasicValueEnum = if [boolean, uint32, uint64].iter().any(|t| is_type(n_ty, *t)) { + call_int_umax( + ctx, + m_val.into_int_value(), + n_val.into_int_value(), + Some("max"), + ).into() + } else if [int32, int64].iter().any(|t| is_type(n_ty, *t)) { + call_int_smax( + ctx, + m_val.into_int_value(), + n_val.into_int_value(), + Some("max"), + ).into() } else if is_type(m_ty, n_ty) && is_type(n_ty, float) { - ("llvm.maxnum.f64", llvm_f64) + call_float_maxnum( + ctx, + m_val.into_float_value(), + n_val.into_float_value(), + Some("max"), + ).into() } else { unreachable!() }; - let intrinsic = ctx.module.get_function(fun_name).unwrap_or_else(|| { - let fn_type = arg_ty.fn_type(&[arg_ty.into(), arg_ty.into()], false); - ctx.module.add_function(fun_name, fn_type, None) - }); - let val = ctx - .builder - .build_call(intrinsic, &[m_val.into(), n_val.into()], "max") - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); Ok(val.into()) }, )))), @@ -1690,49 +1572,27 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let uint64 = ctx.primitives.uint64; let float = ctx.primitives.float; let llvm_i1 = ctx.ctx.bool_type(); - let llvm_i32 = ctx.ctx.i32_type().as_basic_type_enum(); - let llvm_i64 = ctx.ctx.i64_type().as_basic_type_enum(); - let llvm_f64 = ctx.ctx.f64_type().as_basic_type_enum(); let n_ty = fun.0.args[0].ty; let n_val = args[0].1.clone().to_basic_value_enum(ctx, generator, n_ty)?; let mut is_type = |a: Type, b: Type| ctx.unifier.unioned(a, b); - let mut is_float = false; - let (fun_name, arg_ty) = - if is_type(n_ty, boolean) || is_type(n_ty, uint32) || is_type(n_ty, uint64) - { - return Ok(n_val.into()); - } else if is_type(n_ty, int32) { - ("llvm.abs.i32", llvm_i32) - } else if is_type(n_ty, int64) { - ("llvm.abs.i64", llvm_i64) - } else if is_type(n_ty, float) { - is_float = true; - ("llvm.fabs.f64", llvm_f64) - } else { - unreachable!() - }; - let intrinsic = ctx.module.get_function(fun_name).unwrap_or_else(|| { - let fn_type = if is_float { - arg_ty.fn_type(&[arg_ty.into()], false) - } else { - arg_ty.fn_type(&[arg_ty.into(), llvm_i1.into()], false) - }; - ctx.module.add_function(fun_name, fn_type, None) - }); - let val = ctx - .builder - .build_call( - intrinsic, - &if is_float { - vec![n_val.into()] - } else { - vec![n_val.into(), llvm_i1.const_int(0, false).into()] - }, - "abs", - ) - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); + let val: BasicValueEnum = if [boolean, uint32, uint64].iter().any(|t| is_type(n_ty, *t)) { + n_val + } else if [int32, int64].iter().any(|t| is_type(n_ty, *t)) { + call_int_abs( + ctx, + n_val.into_int_value(), + llvm_i1.const_zero(), + Some("abs"), + ).into() + } else if is_type(n_ty, float) { + call_float_fabs( + ctx, + n_val.into_float_value(), + Some("abs"), + ).into() + } else { + unreachable!() + }; Ok(val.into()) }, )))), diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index a842fbf8..a823f4fb 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -1,4 +1,4 @@ -use inkwell::{AddressSpace, IntPredicate, types::BasicType, values::{BasicValueEnum, PointerValue}}; +use inkwell::{IntPredicate, types::BasicType, values::{BasicValueEnum, PointerValue}}; use inkwell::values::{AggregateValueEnum, ArrayValue, IntValue}; use nac3parser::ast::StrRef; use crate::{ @@ -11,6 +11,7 @@ use crate::{ call_ndarray_calc_size, call_ndarray_init_dims, }, + llvm_intrinsics::call_memcpy_generic, stmt::gen_for_callback }, symbol_resolver::ValueEnum, @@ -406,7 +407,7 @@ fn call_ndarray_ones_impl<'ctx>( Ok(ndarray) } -/// LLVM-typed implementation for generating the implementation for `ndarray.ones`. +/// LLVM-typed implementation for generating the implementation for `ndarray.full`. /// /// * `elem_ty` - The element type of the `NDArray`. /// * `shape` - The `shape` parameter used to construct the `NDArray`. @@ -424,44 +425,17 @@ fn call_ndarray_full_impl<'ctx>( ndarray, |generator, ctx, _| { let value = if fill_value.is_pointer_value() { - let llvm_void = ctx.ctx.void_type(); let llvm_i1 = ctx.ctx.bool_type(); - let llvm_i8 = ctx.ctx.i8_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?; - let memcpy_fn_name = format!( - "llvm.memcpy.p0i8.p0i8.i{}", - generator.get_size_type(ctx.ctx).get_bit_width(), + call_memcpy_generic( + ctx, + copy, + fill_value.into_pointer_value(), + fill_value.get_type().size_of().map(Into::into).unwrap(), + llvm_i1.const_zero(), ); - let memcpy_fn = ctx.module.get_function(memcpy_fn_name.as_str()).unwrap_or_else(|| { - let fn_type = llvm_void.fn_type( - &[ - llvm_pi8.into(), - llvm_pi8.into(), - llvm_usize.into(), - llvm_i1.into(), - ], - false, - ); - - ctx.module.add_function(memcpy_fn_name.as_str(), fn_type, None) - }); - - ctx.builder - .build_call( - memcpy_fn, - &[ - copy.into(), - fill_value.into(), - fill_value.get_type().size_of().unwrap().into(), - llvm_i1.const_zero().into(), - ], - "", - ) - .unwrap(); copy.into() } else if fill_value.is_int_value() || fill_value.is_float_value() { -- 2.44.2