nac3/nac3core/src/codegen/extern_fns.rs

221 lines
8.6 KiB
Rust
Raw Normal View History

use inkwell::attributes::{Attribute, AttributeLoc};
2024-07-22 13:19:01 +08:00
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
use itertools::Either;
use crate::codegen::CodeGenContext;
/// Macro to generate a Rust function that calls an extern function.
/// Both the function return type and every function parameter type are `FloatValue`
/// (checked against `f64` in debug builds).
///
/// Arguments:
/// * `"unary"`/`"binary"`: Whether the extern function requires one (unary) or two (binary)
///   operands
/// * `$fn_name:ident`: The identifier of the Rust function to be generated
/// * `$extern_fn:literal`: Name of the underlying extern function
///
/// Optional Arguments:
/// * `$(,$attributes:literal)*`: Attributes attached to the extern function declaration.
///   The default attributes are "mustprogress", "nofree", "nounwind", "willreturn", and
///   "writeonly". These are used only when no attributes are specified.
/// * `$(,$args:ident)*`: Operands of the extern function (only used by the internal arm; the
///   `"unary"` arm supplies `arg` and the `"binary"` arm supplies `arg1, arg2`).
///   The data type of these operands is `FloatValue`.
///
/// The generated function declares the extern function in `ctx.module` the first time it is
/// needed (attaching the attributes at that point) and reuses the existing declaration on later
/// calls. It panics if the call cannot be built or does not produce a float value.
macro_rules! generate_extern_fn {
    ("unary", $fn_name:ident, $extern_fn:literal) => {
        generate_extern_fn!($fn_name, $extern_fn, arg, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
    };
    ("unary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
        generate_extern_fn!($fn_name, $extern_fn, arg $(,$attributes)*);
    };
    ("binary", $fn_name:ident, $extern_fn:literal) => {
        generate_extern_fn!($fn_name, $extern_fn, arg1, arg2, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
    };
    ("binary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
        generate_extern_fn!($fn_name, $extern_fn, arg1, arg2 $(,$attributes)*);
    };
    ($fn_name:ident, $extern_fn:literal $(,$args:ident)* $(,$attributes:literal)*) => {
        // `$extern_fn` is already a string literal, so it is spliced into `concat!` directly:
        // `stringify!` would keep its quotes, and the function name doubles as the page name.
        #[doc = concat!("Invokes the [`", $extern_fn, "`](https://en.cppreference.com/w/c/numeric/math/", $extern_fn, ") function.")]
        pub fn $fn_name<'ctx>(
            ctx: &CodeGenContext<'ctx, '_>
            $(,$args: FloatValue<'ctx>)*,
            name: Option<&str>,
        ) -> FloatValue<'ctx> {
            const FN_NAME: &str = $extern_fn;

            let llvm_f64 = ctx.ctx.f64_type();
            $(debug_assert_eq!($args.get_type(), llvm_f64);)*

            // Declare the extern function on first use; attributes are only applied here.
            let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
                let fn_type = llvm_f64.fn_type(&[$($args.get_type().into()),*], false);
                let func = ctx.module.add_function(FN_NAME, fn_type, None);
                for attr in [$($attributes),*] {
                    func.add_attribute(
                        AttributeLoc::Function,
                        ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
                    );
                }
                func
            });

            ctx.builder
                .build_call(extern_fn, &[$($args.into()),*], name.unwrap_or_default())
                .map(CallSiteValue::try_as_basic_value)
                .map(|v| v.map_left(BasicValueEnum::into_float_value))
                .map(Either::unwrap_left)
                .unwrap()
        }
    };
}
// Unary wrappers whose extern declarations use the macro's default attribute set.
generate_extern_fn!("unary", call_tan, "tan");
generate_extern_fn!("unary", call_asin, "asin");
generate_extern_fn!("unary", call_acos, "acos");
generate_extern_fn!("unary", call_atan, "atan");
generate_extern_fn!("unary", call_sinh, "sinh");
generate_extern_fn!("unary", call_cosh, "cosh");
generate_extern_fn!("unary", call_tanh, "tanh");
generate_extern_fn!("unary", call_asinh, "asinh");
generate_extern_fn!("unary", call_acosh, "acosh");
generate_extern_fn!("unary", call_atanh, "atanh");
generate_extern_fn!("unary", call_expm1, "expm1");

// Binary wrapper whose extern declaration uses the macro's default attribute set.
generate_extern_fn!("binary", call_atan2, "atan2");

// Wrappers whose extern declarations carry only the `nounwind` attribute.
generate_extern_fn!("unary", call_erf, "erf", "nounwind");
generate_extern_fn!("unary", call_erfc, "erfc", "nounwind");
generate_extern_fn!("unary", call_j1, "j1", "nounwind");
generate_extern_fn!("binary", call_hypot, "hypot", "nounwind");
generate_extern_fn!("binary", call_nextafter, "nextafter", "nounwind");

// `cbrt` gets its own attribute set: no `writeonly`, plus `nosync` and `readonly`.
generate_extern_fn!(
    "unary",
    call_cbrt,
    "cbrt",
    "mustprogress",
    "nofree",
    "nosync",
    "nounwind",
    "readonly",
    "willreturn"
);
/// Invokes the [`ldexp`](https://en.cppreference.com/w/c/numeric/math/ldexp) function.
pub fn call_ldexp<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
exp: IntValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "ldexp";
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
debug_assert_eq!(exp.get_type(), llvm_i32);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_i32.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
func.add_attribute(
AttributeLoc::Function,
2024-06-12 14:45:03 +08:00
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into(), exp.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
2024-07-22 13:19:01 +08:00
/// Invokes the [`try_invert_to`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.try_invert_to.html) function
///
/// Calls the extern `linalg_try_invert_to` function, declaring it in `ctx.module` on first use.
///
/// * `dim0`, `dim1`: Index values; each must be an `i32` or `i64` (checked in debug builds).
/// * `data`: Pointer whose element type must be `f64` (checked in debug builds).
/// * `name`: Optional name for the resulting LLVM value.
///
/// Returns the `i8` value produced by the extern call. Panics if the call cannot be built or does
/// not yield an integer value.
pub fn call_linalg_try_invert_to<'ctx>(
    ctx: &CodeGenContext<'ctx, '_>,
    dim0: IntValue<'ctx>,
    dim1: IntValue<'ctx>,
    data: PointerValue<'ctx>,
    name: Option<&str>,
) -> IntValue<'ctx> {
    const FN_NAME: &str = "linalg_try_invert_to";

    let llvm_f64 = ctx.ctx.f64_type();
    let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
    let allowed_dim0 = allowed_index_types.iter().any(|p| *p == dim0.get_type());
    let allowed_dim1 = allowed_index_types.iter().any(|p| *p == dim1.get_type());
    debug_assert!(allowed_dim0);
    debug_assert!(allowed_dim1);
    debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);

    let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
        // Each parameter takes the type of the argument passed in that position, so the
        // declaration matches this call even when `dim0` and `dim1` have different widths.
        // NOTE(review): an existing declaration is reused as-is; callers presumably pass the
        // same index widths on every call — confirm.
        let fn_type = ctx.ctx.i8_type().fn_type(
            &[dim0.get_type().into(), dim1.get_type().into(), data.get_type().into()],
            false,
        );
        let func = ctx.module.add_function(FN_NAME, fn_type, None);
        for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
            func.add_attribute(
                AttributeLoc::Function,
                ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
            );
        }
        func
    });

    ctx.builder
        .build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
        .map(CallSiteValue::try_as_basic_value)
        .map(|v| v.map_left(BasicValueEnum::into_int_value))
        .map(Either::unwrap_left)
        .unwrap()
}
/// Invokes the [`wilkinson_shift`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.wilkinson_shift.html) function
///
/// Calls the extern `linalg_wilkinson_shift` function, declaring it in `ctx.module` on first use.
///
/// * `dim0`, `dim1`: Index values; each must be an `i32` or `i64` (checked in debug builds).
/// * `data`: Pointer whose element type must be `f64` (checked in debug builds).
/// * `name`: Optional name for the resulting LLVM value.
///
/// Returns the `f64` value produced by the extern call. Panics if the call cannot be built or
/// does not yield a float value.
pub fn call_linalg_wilkinson_shift<'ctx>(
    ctx: &CodeGenContext<'ctx, '_>,
    dim0: IntValue<'ctx>,
    dim1: IntValue<'ctx>,
    data: PointerValue<'ctx>,
    name: Option<&str>,
) -> FloatValue<'ctx> {
    const FN_NAME: &str = "linalg_wilkinson_shift";

    let llvm_f64 = ctx.ctx.f64_type();
    let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
    let allowed_dim0 = allowed_index_types.iter().any(|p| *p == dim0.get_type());
    let allowed_dim1 = allowed_index_types.iter().any(|p| *p == dim1.get_type());
    debug_assert!(allowed_dim0);
    debug_assert!(allowed_dim1);
    debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);

    let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
        // Each parameter takes the type of the argument passed in that position, so the
        // declaration matches this call even when `dim0` and `dim1` have different widths.
        // NOTE(review): an existing declaration is reused as-is; callers presumably pass the
        // same index widths on every call — confirm.
        let fn_type = llvm_f64.fn_type(
            &[dim0.get_type().into(), dim1.get_type().into(), data.get_type().into()],
            false,
        );
        let func = ctx.module.add_function(FN_NAME, fn_type, None);
        for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
            func.add_attribute(
                AttributeLoc::Function,
                ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
            );
        }
        func
    });

    ctx.builder
        .build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
        .map(CallSiteValue::try_as_basic_value)
        .map(|v| v.map_left(BasicValueEnum::into_float_value))
        .map(Either::unwrap_left)
        .unwrap()
}