// Forked from M-Labs/nac3 (original listing: 221 lines, 8.6 KiB, Rust).
use inkwell::attributes::{Attribute, AttributeLoc};
|
|
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
|
|
use itertools::Either;
|
|
|
|
use crate::codegen::CodeGenContext;
|
|
|
|
/// Macro to generate an extern function wrapper.
///
/// Both the function return type and the function parameter types are `FloatValue`.
///
/// Arguments:
/// * `unary/binary`: Whether the extern function requires one (unary) or two (binary) operands
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$extern_fn:literal`: Name of underlying extern function
///
/// Optional Arguments:
/// * `$(,$attributes:literal)*`: Attributes linked with the extern function.
///   The default attributes are "mustprogress", "nofree", "nounwind", "willreturn", and
///   "writeonly"; these are used unless other attributes are specified.
/// * `$(,$args:ident)*`: Operands of the extern function (supplied by the `"unary"` and
///   `"binary"` arms). The data type of these operands will be set to `FloatValue`.
///
/// The generated function looks the extern function up by name in `ctx.module`; if it is not
/// present yet, it declares it (returning `f64`, one `f64` parameter per operand) and attaches
/// each attribute as a function attribute before emitting the call.
macro_rules! generate_extern_fn {
    ("unary", $fn_name:ident, $extern_fn:literal) => {
        generate_extern_fn!($fn_name, $extern_fn, arg, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
    };
    ("unary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
        generate_extern_fn!($fn_name, $extern_fn, arg $(,$attributes)*);
    };
    ("binary", $fn_name:ident, $extern_fn:literal) => {
        generate_extern_fn!($fn_name, $extern_fn, arg1, arg2, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
    };
    ("binary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
        generate_extern_fn!($fn_name, $extern_fn, arg1, arg2 $(,$attributes)*);
    };
    ($fn_name:ident, $extern_fn:literal $(,$args:ident)* $(,$attributes:literal)*) => {
        // `$extern_fn` is already a string literal, so it is handed to `concat!` as-is; wrapping
        // it in `stringify!` would put an extra pair of quote characters into the rendered docs.
        #[doc = concat!(
            "Invokes the [`", $extern_fn, "`](https://en.cppreference.com/w/c/numeric/math/",
            $extern_fn, ") function."
        )]
        pub fn $fn_name<'ctx>(
            ctx: &CodeGenContext<'ctx, '_>
            $(,$args: FloatValue<'ctx>)*,
            name: Option<&str>,
        ) -> FloatValue<'ctx> {
            const FN_NAME: &str = $extern_fn;

            let llvm_f64 = ctx.ctx.f64_type();
            $(debug_assert_eq!($args.get_type(), llvm_f64);)*

            // Reuse an existing declaration; otherwise declare it once with the given attributes.
            let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
                let fn_type = llvm_f64.fn_type(&[$($args.get_type().into()),*], false);
                let func = ctx.module.add_function(FN_NAME, fn_type, None);
                for attr in [$($attributes),*] {
                    func.add_attribute(
                        AttributeLoc::Function,
                        ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
                    );
                }
                func
            });

            ctx.builder
                .build_call(extern_fn, &[$($args.into()),*], name.unwrap_or_default())
                .map(CallSiteValue::try_as_basic_value)
                .map(|v| v.map_left(BasicValueEnum::into_float_value))
                .map(Either::unwrap_left)
                .unwrap()
        }
    };
}
|
|
|
|
generate_extern_fn!("unary", call_tan, "tan");
|
|
generate_extern_fn!("unary", call_asin, "asin");
|
|
generate_extern_fn!("unary", call_acos, "acos");
|
|
generate_extern_fn!("unary", call_atan, "atan");
|
|
generate_extern_fn!("unary", call_sinh, "sinh");
|
|
generate_extern_fn!("unary", call_cosh, "cosh");
|
|
generate_extern_fn!("unary", call_tanh, "tanh");
|
|
generate_extern_fn!("unary", call_asinh, "asinh");
|
|
generate_extern_fn!("unary", call_acosh, "acosh");
|
|
generate_extern_fn!("unary", call_atanh, "atanh");
|
|
generate_extern_fn!("unary", call_expm1, "expm1");
|
|
generate_extern_fn!(
|
|
"unary",
|
|
call_cbrt,
|
|
"cbrt",
|
|
"mustprogress",
|
|
"nofree",
|
|
"nosync",
|
|
"nounwind",
|
|
"readonly",
|
|
"willreturn"
|
|
);
|
|
generate_extern_fn!("unary", call_erf, "erf", "nounwind");
|
|
generate_extern_fn!("unary", call_erfc, "erfc", "nounwind");
|
|
generate_extern_fn!("unary", call_j1, "j1", "nounwind");
|
|
|
|
generate_extern_fn!("binary", call_atan2, "atan2");
|
|
generate_extern_fn!("binary", call_hypot, "hypot", "nounwind");
|
|
generate_extern_fn!("binary", call_nextafter, "nextafter", "nounwind");
|
|
|
|
/// Invokes the [`ldexp`](https://en.cppreference.com/w/c/numeric/math/ldexp) function.
|
|
pub fn call_ldexp<'ctx>(
|
|
ctx: &CodeGenContext<'ctx, '_>,
|
|
arg: FloatValue<'ctx>,
|
|
exp: IntValue<'ctx>,
|
|
name: Option<&str>,
|
|
) -> FloatValue<'ctx> {
|
|
const FN_NAME: &str = "ldexp";
|
|
|
|
let llvm_f64 = ctx.ctx.f64_type();
|
|
let llvm_i32 = ctx.ctx.i32_type();
|
|
debug_assert_eq!(arg.get_type(), llvm_f64);
|
|
debug_assert_eq!(exp.get_type(), llvm_i32);
|
|
|
|
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_i32.into()], false);
|
|
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
|
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
|
|
func.add_attribute(
|
|
AttributeLoc::Function,
|
|
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
|
);
|
|
}
|
|
|
|
func
|
|
});
|
|
|
|
ctx.builder
|
|
.build_call(extern_fn, &[arg.into(), exp.into()], name.unwrap_or_default())
|
|
.map(CallSiteValue::try_as_basic_value)
|
|
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
|
.map(Either::unwrap_left)
|
|
.unwrap()
|
|
}
|
|
|
|
/// Invokes the [`try_invert_to`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.try_invert_to.html) function
|
|
pub fn call_linalg_try_invert_to<'ctx>(
|
|
ctx: &CodeGenContext<'ctx, '_>,
|
|
dim0: IntValue<'ctx>,
|
|
dim1: IntValue<'ctx>,
|
|
data: PointerValue<'ctx>,
|
|
name: Option<&str>,
|
|
) -> IntValue<'ctx> {
|
|
const FN_NAME: &str = "linalg_try_invert_to";
|
|
|
|
let llvm_f64 = ctx.ctx.f64_type();
|
|
let allowed_indices = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
|
|
|
|
let allowed_dim0 = allowed_indices.iter().any(|p| *p == dim0.get_type());
|
|
let allowed_dim1 = allowed_indices.iter().any(|p| *p == dim1.get_type());
|
|
|
|
debug_assert!(allowed_dim0);
|
|
debug_assert!(allowed_dim1);
|
|
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);
|
|
|
|
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
|
let fn_type = ctx.ctx.i8_type().fn_type(
|
|
&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()],
|
|
false,
|
|
);
|
|
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
|
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
|
|
func.add_attribute(
|
|
AttributeLoc::Function,
|
|
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
|
);
|
|
}
|
|
|
|
func
|
|
});
|
|
|
|
ctx.builder
|
|
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
|
|
.map(CallSiteValue::try_as_basic_value)
|
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
|
.map(Either::unwrap_left)
|
|
.unwrap()
|
|
}
|
|
|
|
/// Invokes the [`wilkinson_shift`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.wilkinson_shift.html) function
|
|
pub fn call_linalg_wilkinson_shift<'ctx>(
|
|
ctx: &CodeGenContext<'ctx, '_>,
|
|
dim0: IntValue<'ctx>,
|
|
dim1: IntValue<'ctx>,
|
|
data: PointerValue<'ctx>,
|
|
name: Option<&str>,
|
|
) -> FloatValue<'ctx> {
|
|
const FN_NAME: &str = "linalg_wilkinson_shift";
|
|
|
|
let llvm_f64 = ctx.ctx.f64_type();
|
|
let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
|
|
|
|
let allowed_dim0 = allowed_index_types.iter().any(|p| *p == dim0.get_type());
|
|
let allowed_dim1 = allowed_index_types.iter().any(|p| *p == dim1.get_type());
|
|
|
|
debug_assert!(allowed_dim0);
|
|
debug_assert!(allowed_dim1);
|
|
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);
|
|
|
|
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
|
let fn_type = ctx.ctx.f64_type().fn_type(
|
|
&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()],
|
|
false,
|
|
);
|
|
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
|
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
|
|
func.add_attribute(
|
|
AttributeLoc::Function,
|
|
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
|
);
|
|
}
|
|
|
|
func
|
|
});
|
|
|
|
ctx.builder
|
|
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
|
|
.map(CallSiteValue::try_as_basic_value)
|
|
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
|
.map(Either::unwrap_left)
|
|
.unwrap()
|
|
}
|