forked from M-Labs/nac3
1
0
Fork 0

core: Extract LLVM intrinsic functions to their own functions

This commit is contained in:
David Mak 2024-02-22 01:47:26 +08:00
parent 4efdd17513
commit 82fdb02d13
7 changed files with 720 additions and 469 deletions

View File

@ -1,6 +1,7 @@
use nac3core::{ use nac3core::{
codegen::{ codegen::{
expr::gen_call, expr::gen_call,
llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave},
stmt::{gen_block, gen_with}, stmt::{gen_block, gen_with},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
@ -15,7 +16,7 @@ use inkwell::{
context::Context, context::Context,
module::Linkage, module::Linkage,
types::IntType, types::IntType,
values::{BasicValueEnum, CallSiteValue}, values::BasicValueEnum,
AddressSpace, AddressSpace,
}; };
@ -29,7 +30,6 @@ use std::{
hash::{Hash, Hasher}, hash::{Hash, Hasher},
sync::Arc, sync::Arc,
}; };
use itertools::Either;
/// The parallelism mode within a block. /// The parallelism mode within a block.
#[derive(Copy, Clone, Eq, PartialEq)] #[derive(Copy, Clone, Eq, PartialEq)]
@ -133,20 +133,12 @@ impl<'a> ArtiqCodeGenerator<'a> {
.unwrap() .unwrap()
.to_basic_value_enum(ctx, self, end.custom.unwrap())?; .to_basic_value_enum(ctx, self, end.custom.unwrap())?;
let now = self.timeline.emit_now_mu(ctx); let now = self.timeline.emit_now_mu(ctx);
let smax = ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| { let max = call_int_smax(
let i64 = ctx.ctx.i64_type(); ctx,
ctx.module.add_function( old_end.into_int_value(),
"llvm.smax.i64", now.into_int_value(),
i64.fn_type(&[i64.into(), i64.into()], false), Some("smax")
None, );
)
});
let max = ctx
.builder
.build_call(smax, &[old_end.into(), now.into()], "smax")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap();
let end_store = self.gen_store_target( let end_store = self.gen_store_target(
ctx, ctx,
&end, &end,
@ -471,18 +463,7 @@ fn rpc_codegen_callback_fn<'ctx>(
let arg_length = args.len() + usize::from(obj.is_some()); let arg_length = args.len() + usize::from(obj.is_some());
let stacksave = ctx.module.get_function("llvm.stacksave").unwrap_or_else(|| { let stackptr = call_stacksave(ctx, Some("rpc.stack"));
ctx.module.add_function("llvm.stacksave", ptr_type.fn_type(&[], false), None)
});
let stackrestore = ctx.module.get_function("llvm.stackrestore").unwrap_or_else(|| {
ctx.module.add_function(
"llvm.stackrestore",
ctx.ctx.void_type().fn_type(&[ptr_type.into()], false),
None,
)
});
let stackptr = ctx.builder.build_call(stacksave, &[], "rpc.stack").unwrap();
let args_ptr = ctx.builder let args_ptr = ctx.builder
.build_array_alloca( .build_array_alloca(
ptr_type, ptr_type,
@ -558,13 +539,7 @@ fn rpc_codegen_callback_fn<'ctx>(
.unwrap(); .unwrap();
// reclaim stack space used by arguments // reclaim stack space used by arguments
ctx.builder call_stackrestore(ctx, stackptr);
.build_call(
stackrestore,
&[stackptr.try_as_basic_value().unwrap_left().into()],
"rpc.stackrestore",
)
.unwrap();
// -- receive value: // -- receive value:
// T result = { // T result = {
@ -624,13 +599,7 @@ fn rpc_codegen_callback_fn<'ctx>(
let result = ctx.builder.build_load(slot, "rpc.result").unwrap(); let result = ctx.builder.build_load(slot, "rpc.result").unwrap();
if need_load { if need_load {
ctx.builder call_stackrestore(ctx, stackptr);
.build_call(
stackrestore,
&[stackptr.try_as_basic_value().unwrap_left().into()],
"rpc.stackrestore",
)
.unwrap();
} }
Ok(Some(result)) Ok(Some(result))
} }

View File

@ -1,13 +1,13 @@
use inkwell::{ use inkwell::{
IntPredicate, IntPredicate,
types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType}, types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType},
values::{ArrayValue, BasicValueEnum, CallSiteValue, IntValue, PointerValue}, values::{ArrayValue, BasicValueEnum, IntValue, PointerValue},
}; };
use itertools::Either;
use crate::codegen::{ use crate::codegen::{
CodeGenContext, CodeGenContext,
CodeGenerator, CodeGenerator,
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index, call_ndarray_flatten_index_const}, irrt::{call_ndarray_calc_size, call_ndarray_flatten_index, call_ndarray_flatten_index_const},
llvm_intrinsics::call_int_umin,
stmt::gen_for_callback, stmt::gen_for_callback,
}; };
@ -924,22 +924,7 @@ impl<'ctx> NDArrayDataProxy<'ctx> {
let indices_len = indices.load_size(ctx, None); let indices_len = indices.load_size(ctx, None);
let ndarray_len = self.0.load_ndims(ctx); let ndarray_len = self.0.load_ndims(ctx);
let min_fn_name = format!("llvm.umin.i{}", llvm_usize.get_bit_width()); let len = call_int_umin(ctx, indices_len, ndarray_len, None);
let min_fn = ctx.module.get_function(min_fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_usize.into(), llvm_usize.into()],
false
);
ctx.module.add_function(min_fn_name.as_str(), fn_type, None)
});
let len = ctx
.builder
.build_call(min_fn, &[indices_len.into(), ndarray_len.into()], "")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
let i = ctx.builder.build_load(i_addr, "") let i = ctx.builder.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)

View File

@ -8,6 +8,7 @@ use crate::{
get_llvm_type, get_llvm_type,
get_llvm_abi_type, get_llvm_abi_type,
irrt::*, irrt::*,
llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi},
stmt::{gen_raise, gen_var}, stmt::{gen_raise, gen_var},
CodeGenContext, CodeGenTask, CodeGenContext, CodeGenTask,
}, },
@ -30,7 +31,7 @@ use nac3parser::ast::{
self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
}; };
use super::{CodeGenerator, need_sret}; use super::{CodeGenerator, llvm_intrinsics::call_memcpy_generic, need_sret};
pub fn get_subst_key( pub fn get_subst_key(
unifier: &mut Unifier, unifier: &mut Unifier,
@ -371,7 +372,6 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) = (lhs, rhs) else { let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) = (lhs, rhs) else {
unreachable!() unreachable!()
}; };
let float = self.ctx.f64_type();
match op { match op {
Operator::Add => self.builder.build_float_add(lhs, rhs, "fadd").map(Into::into).unwrap(), Operator::Add => self.builder.build_float_add(lhs, rhs, "fadd").map(Into::into).unwrap(),
Operator::Sub => self.builder.build_float_sub(lhs, rhs, "fsub").map(Into::into).unwrap(), Operator::Sub => self.builder.build_float_sub(lhs, rhs, "fsub").map(Into::into).unwrap(),
@ -380,28 +380,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
Operator::Mod => self.builder.build_float_rem(lhs, rhs, "fmod").map(Into::into).unwrap(), Operator::Mod => self.builder.build_float_rem(lhs, rhs, "fmod").map(Into::into).unwrap(),
Operator::FloorDiv => { Operator::FloorDiv => {
let div = self.builder.build_float_div(lhs, rhs, "fdiv").unwrap(); let div = self.builder.build_float_div(lhs, rhs, "fdiv").unwrap();
let floor_intrinsic = call_float_floor(self, div, Some("floor")).into()
self.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
let fn_type = float.fn_type(&[float.into()], false);
self.module.add_function("llvm.floor.f64", fn_type, None)
});
self.builder
.build_call(floor_intrinsic, &[div.into()], "floor")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap()
}
Operator::Pow => {
let pow_intrinsic = self.module.get_function("llvm.pow.f64").unwrap_or_else(|| {
let fn_type = float.fn_type(&[float.into(), float.into()], false);
self.module.add_function("llvm.pow.f64", fn_type, None)
});
self.builder
.build_call(pow_intrinsic, &[lhs.into(), rhs.into()], "f_pow")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap()
} }
Operator::Pow => call_float_pow(self, lhs, rhs, Some("f_pow")).into(),
// special implementation? // special implementation?
_ => unimplemented!(), _ => unimplemented!(),
} }
@ -585,24 +566,11 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
) { ) {
let i1 = self.ctx.bool_type(); let i1 = self.ctx.bool_type();
let i1_true = i1.const_all_ones(); let i1_true = i1.const_all_ones();
let expect_fun = self.module.get_function("llvm.expect.i1").unwrap_or_else(|| {
self.module.add_function(
"llvm.expect.i1",
i1.fn_type(&[i1.into(), i1.into()], false),
None,
)
});
// we assume that the condition is most probably true, so the normal path is the most // we assume that the condition is most probably true, so the normal path is the most
// probable path // probable path
// even if this assumption is violated, it does not matter as exception unwinding is // even if this assumption is violated, it does not matter as exception unwinding is
// slow anyway... // slow anyway...
let cond = self let cond = call_expect(self, cond, i1_true, Some("expect"));
.builder
.build_call(expect_fun, &[cond.into(), i1_true.into()], "expect")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
let current_fun = self.builder.get_insert_block().unwrap().get_parent().unwrap(); let current_fun = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let then_block = self.ctx.append_basic_block(current_fun, "succ"); let then_block = self.ctx.append_basic_block(current_fun, "succ");
let exn_block = self.ctx.append_basic_block(current_fun, "fail"); let exn_block = self.ctx.append_basic_block(current_fun, "fail");
@ -1150,17 +1118,12 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
} else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 { } else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 {
// Pow is the only operator that would pass typecheck between float and int // Pow is the only operator that would pass typecheck between float and int
assert_eq!(*op, Operator::Pow); assert_eq!(*op, Operator::Pow);
let i32_t = ctx.ctx.i32_type(); let res = call_float_powi(
let pow_intr = ctx.module.get_function("llvm.powi.f64.i32").unwrap_or_else(|| { ctx,
let f64_t = ctx.ctx.f64_type(); left_val.into_float_value(),
let ty = f64_t.fn_type(&[f64_t.into(), i32_t.into()], false); right_val.into_int_value(),
ctx.module.add_function("llvm.powi.f64.i32", ty, None) Some("f_pow_i")
}); );
let res = ctx.builder
.build_call(pow_intr, &[left_val.into(), right_val.into()], "f_pow_i")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap();
Ok(Some(res.into())) Ok(Some(res.into()))
} else { } else {
let left_ty_enum = ctx.unifier.get_ty_immutable(left.custom.unwrap()); let left_ty_enum = ctx.unifier.get_ty_immutable(left.custom.unwrap());
@ -1229,11 +1192,8 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
v: NDArrayValue<'ctx>, v: NDArrayValue<'ctx>,
slice: &Expr<Option<Type>>, slice: &Expr<Option<Type>>,
) -> Result<Option<ValueEnum<'ctx>>, String> { ) -> Result<Option<ValueEnum<'ctx>>, String> {
let llvm_void = ctx.ctx.void_type();
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else { let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else {
unreachable!() unreachable!()
@ -1333,24 +1293,6 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims); ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims);
let memcpy_fn_name = format!(
"llvm.memcpy.p0i8.p0i8.i{}",
generator.get_size_type(ctx.ctx).get_bit_width(),
);
let memcpy_fn = ctx.module.get_function(memcpy_fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[
llvm_pi8.into(),
llvm_pi8.into(),
llvm_usize.into(),
llvm_i1.into(),
],
false,
);
ctx.module.add_function(memcpy_fn_name.as_str(), fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_num_dims = ndarray.load_ndims(ctx);
let v_dims_src_ptr = v.get_dims().ptr_offset( let v_dims_src_ptr = v.get_dims().ptr_offset(
ctx, ctx,
@ -1358,37 +1300,16 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
llvm_usize.const_int(1, false), llvm_usize.const_int(1, false),
None, None,
); );
ctx.builder.build_call( call_memcpy_generic(
memcpy_fn, ctx,
&[
ctx.builder
.build_bitcast(
ndarray.get_dims().get_ptr(ctx), ndarray.get_dims().get_ptr(ctx),
llvm_pi8,
"",
)
.map(Into::into)
.unwrap(),
ctx.builder
.build_bitcast(
v_dims_src_ptr, v_dims_src_ptr,
llvm_pi8,
"",
)
.map(Into::into)
.unwrap(),
ctx.builder ctx.builder
.build_int_mul( .build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
ndarray_num_dims,
llvm_usize.size_of(),
"",
)
.map(Into::into) .map(Into::into)
.unwrap(), .unwrap(),
llvm_i1.const_zero().into(), llvm_i1.const_zero(),
], );
"",
).unwrap();
let ndarray_num_elems = call_ndarray_calc_size( let ndarray_num_elems = call_ndarray_calc_size(
generator, generator,
@ -1404,37 +1325,16 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
ctx.ctx.i32_type().const_array(&[index]), ctx.ctx.i32_type().const_array(&[index]),
None None
); );
ctx.builder.build_call( call_memcpy_generic(
memcpy_fn, ctx,
&[
ctx.builder
.build_bitcast(
ndarray.get_data().get_ptr(ctx), ndarray.get_data().get_ptr(ctx),
llvm_pi8,
"",
)
.map(Into::into)
.unwrap(),
ctx.builder
.build_bitcast(
v_data_src_ptr, v_data_src_ptr,
llvm_pi8,
"",
)
.map(Into::into)
.unwrap(),
ctx.builder ctx.builder
.build_int_mul( .build_int_mul(ndarray_num_elems, llvm_ndarray_data_t.size_of().unwrap(), "")
ndarray_num_elems,
llvm_ndarray_data_t.size_of().unwrap(),
"",
)
.map(Into::into) .map(Into::into)
.unwrap(), .unwrap(),
llvm_i1.const_zero().into(), llvm_i1.const_zero(),
], );
"",
).unwrap();
Ok(Some(v.get_ptr().into())) Ok(Some(v.get_ptr().into()))
} }

View File

@ -0,0 +1,562 @@
use inkwell::AddressSpace;
use inkwell::context::Context;
use inkwell::types::AnyTypeEnum::IntType;
use inkwell::types::FloatType;
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
use itertools::Either;
use crate::codegen::CodeGenContext;
/// Maps the floating-point type `ft` to the type suffix used in the names of overloaded LLVM
/// intrinsics (for example the `f64` in `llvm.floor.f64`).
///
/// `ft` is compared, in order, against the floating-point types obtained from `ctx`.
///
/// # Panics
///
/// Hits `unreachable!()` if `ft` equals none of the floating-point types checked below.
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
    // Standard LLVM floating-point types are tested first, followed by the non-standard ones.
    if ft == ctx.f16_type() {
        "f16"
    } else if ft == ctx.f32_type() {
        "f32"
    } else if ft == ctx.f64_type() {
        "f64"
    } else if ft == ctx.f128_type() {
        "f128"
    } else if ft == ctx.x86_f80_type() {
        // Non-standard floating-point types start here.
        "f80"
    } else if ft == ctx.ppc_f128_type() {
        "ppcf128"
    } else {
        unreachable!()
    }
}
/// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic)
/// intrinsic.
pub fn call_stacksave<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
name: Option<&str>,
) -> PointerValue<'ctx> {
const FN_NAME: &str = "llvm.stacksave";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_p0i8.fn_type(&[], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_pointer_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Emits a call to the
/// [`llvm.stackrestore`](https://llvm.org/docs/LangRef.html#llvm-stackrestore-intrinsic)
/// intrinsic, passing `ptr` as its only argument.
///
/// The intrinsic is looked up in the current module by name; if it has not been declared yet, it
/// is declared as a `void` function taking a single `i8*` (default address space) parameter.
///
/// * `ptr` - The pointer handed to the intrinsic.
///
/// # Panics
///
/// Panics if the builder fails to emit the call.
pub fn call_stackrestore<'ctx>(ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>) {
    const INTRINSIC: &str = "llvm.stackrestore";

    let callee = match ctx.module.get_function(INTRINSIC) {
        Some(declared) => declared,
        None => {
            let byte_ptr_ty = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
            let signature = ctx.ctx.void_type().fn_type(&[byte_ptr_ty.into()], false);
            ctx.module.add_function(INTRINSIC, signature, None)
        }
    };

    // The call produces no value, so the instruction is left unnamed and its result is dropped.
    ctx.builder.build_call(callee, &[ptr.into()], "").unwrap();
}
/// Invokes the [`llvm.abs`](https://llvm.org/docs/LangRef.html#llvm-abs-intrinsic) intrinsic.
///
/// * `src` - The value for which the absolute value is to be returned.
/// * `is_int_min_poison` - Whether `poison` is to be returned if `src` is `INT_MIN`.
pub fn call_int_abs<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
src: IntValue<'ctx>,
is_int_min_poison: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(is_int_min_poison.get_type().get_bit_width(), 1);
debug_assert!(is_int_min_poison.is_const());
let llvm_src_t = src.get_type();
let fn_name = format!("llvm.abs.i{}", llvm_src_t.get_bit_width());
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let llvm_i1 = ctx.ctx.bool_type();
let fn_type = llvm_src_t.fn_type(&[llvm_src_t.into(), llvm_i1.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[src.into(), is_int_min_poison.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.smax`](https://llvm.org/docs/LangRef.html#llvm-smax-intrinsic) intrinsic.
pub fn call_int_smax<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
a: IntValue<'ctx>,
b: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width());
let llvm_int_t = a.get_type();
let fn_name = format!("llvm.smax.i{}", llvm_int_t.get_bit_width());
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_int_t.fn_type(&[llvm_int_t.into(), llvm_int_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.smin`](https://llvm.org/docs/LangRef.html#llvm-smin-intrinsic) intrinsic.
pub fn call_int_smin<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
a: IntValue<'ctx>,
b: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width());
let llvm_int_t = a.get_type();
let fn_name = format!("llvm.smin.i{}", llvm_int_t.get_bit_width());
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_int_t.fn_type(&[llvm_int_t.into(), llvm_int_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.umax`](https://llvm.org/docs/LangRef.html#llvm-umax-intrinsic) intrinsic.
pub fn call_int_umax<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
a: IntValue<'ctx>,
b: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width());
let llvm_int_t = a.get_type();
let fn_name = format!("llvm.umax.i{}", llvm_int_t.get_bit_width());
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_int_t.fn_type(&[llvm_int_t.into(), llvm_int_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.umin`](https://llvm.org/docs/LangRef.html#llvm-umin-intrinsic) intrinsic.
pub fn call_int_umin<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
a: IntValue<'ctx>,
b: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width());
let llvm_int_t = a.get_type();
let fn_name = format!("llvm.umin.i{}", llvm_int_t.get_bit_width());
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_int_t.fn_type(&[llvm_int_t.into(), llvm_int_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Emits a call to the [`llvm.memcpy`](https://llvm.org/docs/LangRef.html#llvm-memcpy-intrinsic)
/// intrinsic.
///
/// The overload name is built from the bit widths of the pointee types of `dest` and `src` and
/// from the bit width of `len`; the matching declaration is added to the current module on first
/// use.
///
/// * `dest` - The pointer to the destination. Must be a pointer to an integer type.
/// * `src` - The pointer to the source. Must be a pointer to an integer type whose bit width
///   equals that of the destination's pointee (checked in debug builds).
/// * `len` - The number of bytes to copy. Debug builds assert a bit width of 32 or 64.
/// * `is_volatile` - Whether the `memcpy` operation should be `volatile`. Debug builds assert a
///   bit width of 1.
///
/// NOTE(review): the generated name always uses the `p0` prefix, so pointers are presumably
/// expected to be in address space 0 - confirm against callers.
///
/// # Panics
///
/// Panics if the pointee type of `dest` or `src` is not an integer type, or if the builder fails
/// to emit the call.
pub fn call_memcpy<'ctx>(
    ctx: &CodeGenContext<'ctx, '_>,
    dest: PointerValue<'ctx>,
    src: PointerValue<'ctx>,
    len: IntValue<'ctx>,
    is_volatile: IntValue<'ctx>,
) {
    debug_assert!(dest.get_type().get_element_type().is_int_type());
    debug_assert!(src.get_type().get_element_type().is_int_type());

    let dest_ptr_ty = dest.get_type();
    let src_ptr_ty = src.get_type();
    let dest_bits = dest_ptr_ty.get_element_type().into_int_type().get_bit_width();
    let src_bits = src_ptr_ty.get_element_type().into_int_type().get_bit_width();
    debug_assert_eq!(dest_bits, src_bits);

    debug_assert!(matches!(len.get_type().get_bit_width(), 32 | 64));
    debug_assert_eq!(is_volatile.get_type().get_bit_width(), 1);

    let len_ty = len.get_type();
    let len_bits = len_ty.get_bit_width();
    let symbol = format!("llvm.memcpy.p0i{dest_bits}.p0i{src_bits}.i{len_bits}");

    let callee = match ctx.module.get_function(&symbol) {
        Some(declared) => declared,
        None => {
            let param_tys = [
                dest_ptr_ty.into(),
                src_ptr_ty.into(),
                len_ty.into(),
                is_volatile.get_type().into(),
            ];
            let signature = ctx.ctx.void_type().fn_type(&param_tys, false);
            ctx.module.add_function(&symbol, signature, None)
        }
    };

    // The call produces no value, so the instruction is left unnamed and its result is dropped.
    ctx.builder
        .build_call(callee, &[dest.into(), src.into(), len.into(), is_volatile.into()], "")
        .unwrap();
}
/// Emits a call to the `llvm.memcpy` intrinsic for pointers of arbitrary pointee type.
///
/// Unlike [`call_memcpy`], this function accepts any type of pointer value. Each of `dest` and
/// `src` that does not already point to an 8-bit integer is bitcast to `i8*` (default address
/// space) before the copy is delegated to [`call_memcpy`].
///
/// * `dest` - The pointer to the destination.
/// * `src` - The pointer to the source.
/// * `len` - Forwarded unchanged to [`call_memcpy`].
/// * `is_volatile` - Forwarded unchanged to [`call_memcpy`].
///
/// # Panics
///
/// Panics if the builder fails to emit a required bitcast.
pub fn call_memcpy_generic<'ctx>(
    ctx: &CodeGenContext<'ctx, '_>,
    dest: PointerValue<'ctx>,
    src: PointerValue<'ctx>,
    len: IntValue<'ctx>,
    is_volatile: IntValue<'ctx>,
) {
    let byte_ptr_ty = ctx.ctx.i8_type().ptr_type(AddressSpace::default());

    // Passes `i8*` pointers through untouched and bitcasts every other pointer to `i8*`.
    let as_byte_ptr = |ptr: PointerValue<'ctx>| -> PointerValue<'ctx> {
        match ptr.get_type().get_element_type() {
            IntType(elem_ty) if elem_ty.get_bit_width() == 8 => ptr,
            _ => ctx.builder.build_bitcast(ptr, byte_ptr_ty, "").unwrap().into_pointer_value(),
        }
    };

    // Destination is converted before source so the bitcasts are emitted in that order.
    let dest = as_byte_ptr(dest);
    let src = as_byte_ptr(src);

    call_memcpy(ctx, dest, src, len, is_volatile);
}
/// Invokes the [`llvm.powi`](https://llvm.org/docs/LangRef.html#llvm-powi-intrinsic) intrinsic.
pub fn call_float_powi<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
power: IntValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_val_t = val.get_type();
let llvm_power_t = power.get_type();
let fn_name = format!(
"llvm.powi.{}.i{}",
get_float_intrinsic_repr(ctx.ctx, llvm_val_t),
llvm_power_t.get_bit_width(),
);
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_val_t.fn_type(&[llvm_val_t.into(), llvm_power_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into(), power.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.pow`](https://llvm.org/docs/LangRef.html#llvm-pow-intrinsic) intrinsic.
pub fn call_float_pow<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
power: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
debug_assert_eq!(val.get_type(), power.get_type());
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.pow.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into(), llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into(), power.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.fabs`](https://llvm.org/docs/LangRef.html#llvm-fabs-intrinsic) intrinsic.
pub fn call_float_fabs<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
src: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_src_t = src.get_type();
let fn_name = format!("llvm.fabs.{}", get_float_intrinsic_repr(ctx.ctx, llvm_src_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_src_t.fn_type(&[llvm_src_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[src.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.minnum`](https://llvm.org/docs/LangRef.html#llvm-minnum-intrinsic) intrinsic.
pub fn call_float_minnum<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val1: FloatValue<'ctx>,
val2: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
debug_assert_eq!(val1.get_type(), val2.get_type());
let llvm_float_t = val1.get_type();
let fn_name = format!("llvm.minnum.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into(), llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val1.into(), val2.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.maxnum`](https://llvm.org/docs/LangRef.html#llvm-maxnum-intrinsic) intrinsic.
pub fn call_float_maxnum<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val1: FloatValue<'ctx>,
val2: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
debug_assert_eq!(val1.get_type(), val2.get_type());
let llvm_float_t = val1.get_type();
let fn_name = format!("llvm.maxnum.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into(), llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val1.into(), val2.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.floor`](https://llvm.org/docs/LangRef.html#llvm-floor-intrinsic) intrinsic.
pub fn call_float_floor<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.floor.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.ceil`](https://llvm.org/docs/LangRef.html#llvm-ceil-intrinsic) intrinsic.
pub fn call_float_ceil<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.ceil.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.round`](https://llvm.org/docs/LangRef.html#llvm-round-intrinsic) intrinsic.
pub fn call_float_round<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.round.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the
/// [`llvm.roundeven`](https://llvm.org/docs/LangRef.html#llvm-roundeven-intrinsic) intrinsic.
pub fn call_float_roundeven<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_float_t = val.get_type();
let fn_name = format!("llvm.roundeven.{}", get_float_intrinsic_repr(ctx.ctx, llvm_float_t));
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_float_t.fn_type(&[llvm_float_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.expect`](https://llvm.org/docs/LangRef.html#llvm-expect-intrinsic) intrinsic.
pub fn call_expect<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: IntValue<'ctx>,
expected_val: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(val.get_type().get_bit_width(), expected_val.get_type().get_bit_width());
let llvm_int_t = val.get_type();
let fn_name = format!("llvm.expect.i{}", llvm_int_t.get_bit_width());
let intrinsic_fn = ctx.module.get_function(fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_int_t.fn_type(&[llvm_int_t.into(), llvm_int_t.into()], false);
ctx.module.add_function(fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[val.into(), expected_val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -39,6 +39,7 @@ pub mod concrete_type;
pub mod expr; pub mod expr;
mod generator; mod generator;
pub mod irrt; pub mod irrt;
pub mod llvm_intrinsics;
pub mod stmt; pub mod stmt;
#[cfg(test)] #[cfg(test)]

View File

@ -3,25 +3,12 @@ use crate::{
codegen::{ codegen::{
classes::RangeValue, classes::RangeValue,
expr::destructure_range, expr::destructure_range,
irrt::{ irrt::*,
calculate_len_for_slice_range, llvm_intrinsics::*,
call_gamma,
call_gammaln,
call_isinf,
call_isnan,
call_j0,
},
stmt::exn_constructor, stmt::exn_constructor,
}, },
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::numpy::{ toplevel::numpy::*,
gen_ndarray_empty,
gen_ndarray_eye,
gen_ndarray_full,
gen_ndarray_identity,
gen_ndarray_ones,
gen_ndarray_zeros,
},
}; };
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
@ -1010,26 +997,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
int32, int32,
&[(float, "n")], &[(float, "n")],
Box::new(|ctx, _, _, args, generator| { Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let arg = args[0].1.clone() let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?; .to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| { let val = call_float_round(ctx, arg, None);
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.round.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap();
let val_toint = ctx.builder let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i32, "round") .build_float_to_signed_int(val, llvm_i32, "round")
.unwrap(); .unwrap();
Ok(Some(val_toint.into())) Ok(Some(val_toint.into()))
}), }),
@ -1041,26 +1017,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
int64, int64,
&[(float, "n")], &[(float, "n")],
Box::new(|ctx, _, _, args, generator| { Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i64 = ctx.ctx.i64_type(); let llvm_i64 = ctx.ctx.i64_type();
let arg = args[0].1.clone() let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?; .to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| { let val = call_float_round(ctx, arg, None);
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.round.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap();
let val_toint = ctx.builder let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i64, "round") .build_float_to_signed_int(val, llvm_i64, "round")
.unwrap(); .unwrap();
Ok(Some(val_toint.into())) Ok(Some(val_toint.into()))
}), }),
@ -1072,24 +1037,13 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
float, float,
&[(float, "n")], &[(float, "n")],
Box::new(|ctx, _, _, args, generator| { Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let arg = args[0].1.clone() let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?; .to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
let intrinsic_fn = ctx.module.get_function("llvm.roundeven.f64").unwrap_or_else(|| { let val = call_float_roundeven(ctx, arg, None);
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.roundeven.f64", fn_type, None) Ok(Some(val.into()))
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap();
Ok(Some(val))
}), }),
), ),
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
@ -1290,26 +1244,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
int32, int32,
&[(float, "n")], &[(float, "n")],
Box::new(|ctx, _, _, args, generator| { Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let arg = args[0].1.clone() let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?; .to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { let val = call_float_floor(ctx, arg, None);
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.floor.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap();
let val_toint = ctx.builder let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i32, "floor") .build_float_to_signed_int(val, llvm_i32, "floor")
.unwrap(); .unwrap();
Ok(Some(val_toint.into())) Ok(Some(val_toint.into()))
}), }),
@ -1321,26 +1264,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
int64, int64,
&[(float, "n")], &[(float, "n")],
Box::new(|ctx, _, _, args, generator| { Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i64 = ctx.ctx.i64_type(); let llvm_i64 = ctx.ctx.i64_type();
let arg = args[0].1.clone() let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?; .to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { let val = call_float_floor(ctx, arg, None);
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.floor.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap();
let val_toint = ctx.builder let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i64, "floor") .build_float_to_signed_int(val, llvm_i64, "floor")
.unwrap(); .unwrap();
Ok(Some(val_toint.into())) Ok(Some(val_toint.into()))
}), }),
@ -1352,24 +1284,12 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
float, float,
&[(float, "n")], &[(float, "n")],
Box::new(|ctx, _, _, args, generator| { Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let arg = args[0].1.clone() let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?; .to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { let val = call_float_floor(ctx, arg, None);
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); Ok(Some(val.into()))
ctx.module.add_function("llvm.floor.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap();
Ok(Some(val))
}), }),
), ),
create_fn_by_codegen( create_fn_by_codegen(
@ -1379,26 +1299,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
int32, int32,
&[(float, "n")], &[(float, "n")],
Box::new(|ctx, _, _, args, generator| { Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let arg = args[0].1.clone() let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?; .to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { let val = call_float_ceil(ctx, arg, None);
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.ceil.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap();
let val_toint = ctx.builder let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i32, "ceil") .build_float_to_signed_int(val, llvm_i32, "ceil")
.unwrap(); .unwrap();
Ok(Some(val_toint.into())) Ok(Some(val_toint.into()))
}), }),
@ -1410,26 +1319,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
int64, int64,
&[(float, "n")], &[(float, "n")],
Box::new(|ctx, _, _, args, generator| { Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i64 = ctx.ctx.i64_type(); let llvm_i64 = ctx.ctx.i64_type();
let arg = args[0].1.clone() let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?; .to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { let val = call_float_ceil(ctx, arg, None);
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.ceil.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap();
let val_toint = ctx.builder let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i64, "ceil") .build_float_to_signed_int(val, llvm_i64, "ceil")
.unwrap(); .unwrap();
Ok(Some(val_toint.into())) Ok(Some(val_toint.into()))
}), }),
@ -1441,24 +1339,12 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
float, float,
&[(float, "n")], &[(float, "n")],
Box::new(|ctx, _, _, args, generator| { Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let arg = args[0].1.clone() let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?; .to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { let val = call_float_ceil(ctx, arg, None);
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); Ok(Some(val.into()))
ctx.module.add_function("llvm.ceil.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap();
Ok(Some(val))
}), }),
), ),
Arc::new(RwLock::new({ Arc::new(RwLock::new({
@ -1568,40 +1454,38 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
let uint32 = ctx.primitives.uint32; let uint32 = ctx.primitives.uint32;
let uint64 = ctx.primitives.uint64; let uint64 = ctx.primitives.uint64;
let float = ctx.primitives.float; let float = ctx.primitives.float;
let llvm_i8 = ctx.ctx.i8_type().as_basic_type_enum();
let llvm_i32 = ctx.ctx.i32_type().as_basic_type_enum();
let llvm_i64 = ctx.ctx.i64_type().as_basic_type_enum();
let llvm_f64 = ctx.ctx.f64_type().as_basic_type_enum();
let m_ty = fun.0.args[0].ty; let m_ty = fun.0.args[0].ty;
let n_ty = fun.0.args[1].ty; let n_ty = fun.0.args[1].ty;
let m_val = args[0].1.clone().to_basic_value_enum(ctx, generator, m_ty)?; let m_val = args[0].1.clone().to_basic_value_enum(ctx, generator, m_ty)?;
let n_val = args[1].1.clone().to_basic_value_enum(ctx, generator, n_ty)?; let n_val = args[1].1.clone().to_basic_value_enum(ctx, generator, n_ty)?;
let mut is_type = |a: Type, b: Type| ctx.unifier.unioned(a, b); let mut is_type = |a: Type, b: Type| ctx.unifier.unioned(a, b);
let (fun_name, arg_ty) = if is_type(m_ty, n_ty) && is_type(n_ty, boolean) { if !is_type(m_ty, n_ty) {
("llvm.umin.i8", llvm_i8) unreachable!()
} else if is_type(m_ty, n_ty) && is_type(n_ty, int32) { }
("llvm.smin.i32", llvm_i32) let val: BasicValueEnum = if [boolean, uint32, uint64].iter().any(|t| is_type(n_ty, *t)) {
} else if is_type(m_ty, n_ty) && is_type(n_ty, int64) { call_int_umin(
("llvm.smin.i64", llvm_i64) ctx,
} else if is_type(m_ty, n_ty) && is_type(n_ty, uint32) { m_val.into_int_value(),
("llvm.umin.i32", llvm_i32) n_val.into_int_value(),
} else if is_type(m_ty, n_ty) && is_type(n_ty, uint64) { Some("min"),
("llvm.umin.i64", llvm_i64) ).into()
} else if [int32, int64].iter().any(|t| is_type(n_ty, *t)) {
call_int_smin(
ctx,
m_val.into_int_value(),
n_val.into_int_value(),
Some("min"),
).into()
} else if is_type(m_ty, n_ty) && is_type(n_ty, float) { } else if is_type(m_ty, n_ty) && is_type(n_ty, float) {
("llvm.minnum.f64", llvm_f64) call_float_minnum(
ctx,
m_val.into_float_value(),
n_val.into_float_value(),
Some("min"),
).into()
} else { } else {
unreachable!() unreachable!()
}; };
let intrinsic = ctx.module.get_function(fun_name).unwrap_or_else(|| {
let fn_type = arg_ty.fn_type(&[arg_ty.into(), arg_ty.into()], false);
ctx.module.add_function(fun_name, fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic, &[m_val.into(), n_val.into()], "min")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap();
Ok(val.into()) Ok(val.into())
}, },
)))), )))),
@ -1630,40 +1514,38 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
let uint32 = ctx.primitives.uint32; let uint32 = ctx.primitives.uint32;
let uint64 = ctx.primitives.uint64; let uint64 = ctx.primitives.uint64;
let float = ctx.primitives.float; let float = ctx.primitives.float;
let llvm_i8 = ctx.ctx.i8_type().as_basic_type_enum();
let llvm_i32 = ctx.ctx.i32_type().as_basic_type_enum();
let llvm_i64 = ctx.ctx.i64_type().as_basic_type_enum();
let llvm_f64 = ctx.ctx.f64_type().as_basic_type_enum();
let m_ty = fun.0.args[0].ty; let m_ty = fun.0.args[0].ty;
let n_ty = fun.0.args[1].ty; let n_ty = fun.0.args[1].ty;
let m_val = args[0].1.clone().to_basic_value_enum(ctx, generator, m_ty)?; let m_val = args[0].1.clone().to_basic_value_enum(ctx, generator, m_ty)?;
let n_val = args[1].1.clone().to_basic_value_enum(ctx, generator, n_ty)?; let n_val = args[1].1.clone().to_basic_value_enum(ctx, generator, n_ty)?;
let mut is_type = |a: Type, b: Type| ctx.unifier.unioned(a, b); let mut is_type = |a: Type, b: Type| ctx.unifier.unioned(a, b);
let (fun_name, arg_ty) = if is_type(m_ty, n_ty) && is_type(n_ty, boolean) { if !is_type(m_ty, n_ty) {
("llvm.umax.i8", llvm_i8) unreachable!()
} else if is_type(m_ty, n_ty) && is_type(n_ty, int32) { }
("llvm.smax.i32", llvm_i32) let val: BasicValueEnum = if [boolean, uint32, uint64].iter().any(|t| is_type(n_ty, *t)) {
} else if is_type(m_ty, n_ty) && is_type(n_ty, int64) { call_int_umax(
("llvm.smax.i64", llvm_i64) ctx,
} else if is_type(m_ty, n_ty) && is_type(n_ty, uint32) { m_val.into_int_value(),
("llvm.umax.i32", llvm_i32) n_val.into_int_value(),
} else if is_type(m_ty, n_ty) && is_type(n_ty, uint64) { Some("max"),
("llvm.umax.i64", llvm_i64) ).into()
} else if [int32, int64].iter().any(|t| is_type(n_ty, *t)) {
call_int_smax(
ctx,
m_val.into_int_value(),
n_val.into_int_value(),
Some("max"),
).into()
} else if is_type(m_ty, n_ty) && is_type(n_ty, float) { } else if is_type(m_ty, n_ty) && is_type(n_ty, float) {
("llvm.maxnum.f64", llvm_f64) call_float_maxnum(
ctx,
m_val.into_float_value(),
n_val.into_float_value(),
Some("max"),
).into()
} else { } else {
unreachable!() unreachable!()
}; };
let intrinsic = ctx.module.get_function(fun_name).unwrap_or_else(|| {
let fn_type = arg_ty.fn_type(&[arg_ty.into(), arg_ty.into()], false);
ctx.module.add_function(fun_name, fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic, &[m_val.into(), n_val.into()], "max")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap();
Ok(val.into()) Ok(val.into())
}, },
)))), )))),
@ -1690,49 +1572,27 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
let uint64 = ctx.primitives.uint64; let uint64 = ctx.primitives.uint64;
let float = ctx.primitives.float; let float = ctx.primitives.float;
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
let llvm_i32 = ctx.ctx.i32_type().as_basic_type_enum();
let llvm_i64 = ctx.ctx.i64_type().as_basic_type_enum();
let llvm_f64 = ctx.ctx.f64_type().as_basic_type_enum();
let n_ty = fun.0.args[0].ty; let n_ty = fun.0.args[0].ty;
let n_val = args[0].1.clone().to_basic_value_enum(ctx, generator, n_ty)?; let n_val = args[0].1.clone().to_basic_value_enum(ctx, generator, n_ty)?;
let mut is_type = |a: Type, b: Type| ctx.unifier.unioned(a, b); let mut is_type = |a: Type, b: Type| ctx.unifier.unioned(a, b);
let mut is_float = false; let val: BasicValueEnum = if [boolean, uint32, uint64].iter().any(|t| is_type(n_ty, *t)) {
let (fun_name, arg_ty) = n_val
if is_type(n_ty, boolean) || is_type(n_ty, uint32) || is_type(n_ty, uint64) } else if [int32, int64].iter().any(|t| is_type(n_ty, *t)) {
{ call_int_abs(
return Ok(n_val.into()); ctx,
} else if is_type(n_ty, int32) { n_val.into_int_value(),
("llvm.abs.i32", llvm_i32) llvm_i1.const_zero(),
} else if is_type(n_ty, int64) { Some("abs"),
("llvm.abs.i64", llvm_i64) ).into()
} else if is_type(n_ty, float) { } else if is_type(n_ty, float) {
is_float = true; call_float_fabs(
("llvm.fabs.f64", llvm_f64) ctx,
n_val.into_float_value(),
Some("abs"),
).into()
} else { } else {
unreachable!() unreachable!()
}; };
let intrinsic = ctx.module.get_function(fun_name).unwrap_or_else(|| {
let fn_type = if is_float {
arg_ty.fn_type(&[arg_ty.into()], false)
} else {
arg_ty.fn_type(&[arg_ty.into(), llvm_i1.into()], false)
};
ctx.module.add_function(fun_name, fn_type, None)
});
let val = ctx
.builder
.build_call(
intrinsic,
&if is_float {
vec![n_val.into()]
} else {
vec![n_val.into(), llvm_i1.const_int(0, false).into()]
},
"abs",
)
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap();
Ok(val.into()) Ok(val.into())
}, },
)))), )))),

View File

@ -1,4 +1,4 @@
use inkwell::{AddressSpace, IntPredicate, types::BasicType, values::{BasicValueEnum, PointerValue}}; use inkwell::{IntPredicate, types::BasicType, values::{BasicValueEnum, PointerValue}};
use inkwell::values::{AggregateValueEnum, ArrayValue, IntValue}; use inkwell::values::{AggregateValueEnum, ArrayValue, IntValue};
use nac3parser::ast::StrRef; use nac3parser::ast::StrRef;
use crate::{ use crate::{
@ -11,6 +11,7 @@ use crate::{
call_ndarray_calc_size, call_ndarray_calc_size,
call_ndarray_init_dims, call_ndarray_init_dims,
}, },
llvm_intrinsics::call_memcpy_generic,
stmt::gen_for_callback stmt::gen_for_callback
}, },
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
@ -406,7 +407,7 @@ fn call_ndarray_ones_impl<'ctx>(
Ok(ndarray) Ok(ndarray)
} }
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`. /// LLVM-typed implementation for generating the implementation for `ndarray.full`.
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the `NDArray`. /// * `shape` - The `shape` parameter used to construct the `NDArray`.
@ -424,44 +425,17 @@ fn call_ndarray_full_impl<'ctx>(
ndarray, ndarray,
|generator, ctx, _| { |generator, ctx, _| {
let value = if fill_value.is_pointer_value() { let value = if fill_value.is_pointer_value() {
let llvm_void = ctx.ctx.void_type();
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?; let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?;
let memcpy_fn_name = format!( call_memcpy_generic(
"llvm.memcpy.p0i8.p0i8.i{}", ctx,
generator.get_size_type(ctx.ctx).get_bit_width(), copy,
fill_value.into_pointer_value(),
fill_value.get_type().size_of().map(Into::into).unwrap(),
llvm_i1.const_zero(),
); );
let memcpy_fn = ctx.module.get_function(memcpy_fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[
llvm_pi8.into(),
llvm_pi8.into(),
llvm_usize.into(),
llvm_i1.into(),
],
false,
);
ctx.module.add_function(memcpy_fn_name.as_str(), fn_type, None)
});
ctx.builder
.build_call(
memcpy_fn,
&[
copy.into(),
fill_value.into(),
fill_value.get_type().size_of().unwrap().into(),
llvm_i1.const_zero().into(),
],
"",
)
.unwrap();
copy.into() copy.into()
} else if fill_value.is_int_value() || fill_value.is_float_value() { } else if fill_value.is_int_value() || fill_value.is_float_value() {