2024-10-03 12:37:56 +08:00
|
|
|
use inkwell::{
|
|
|
|
attributes::{Attribute, AttributeLoc},
|
|
|
|
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
|
|
|
|
};
|
2024-04-24 17:40:25 +08:00
|
|
|
use itertools::Either;
|
|
|
|
|
|
|
|
use crate::codegen::CodeGenContext;
|
|
|
|
|
2024-07-09 16:31:08 +08:00
|
|
|
/// Macro to generate a Rust wrapper function which emits a call to an extern function.
/// Both function return type and function parameter type are `FloatValue`
///
/// Arguments:
/// * `"unary"`/`"binary"`: Whether the extern function requires one (unary) or two (binary) operands
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$extern_fn:literal`: Name of underlying extern function
///
/// Optional Arguments:
/// * `$(,$attributes:literal)*`: Attributes linked with the extern function.
///   The default attributes are "mustprogress", "nofree", "nounwind", "willreturn", and "writeonly".
///   These will be used unless other attributes are specified
/// * `$(,$args:ident)*`: Operands of the extern function
///   The data type of these operands will be set to `FloatValue`
///
macro_rules! generate_extern_fn {
    // Unary function with the default attribute set.
    ("unary", $fn_name:ident, $extern_fn:literal) => {
        generate_extern_fn!($fn_name, $extern_fn, arg, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
    };
    // Unary function with caller-specified attributes.
    ("unary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
        generate_extern_fn!($fn_name, $extern_fn, arg $(,$attributes)*);
    };
    // Binary function with the default attribute set.
    ("binary", $fn_name:ident, $extern_fn:literal) => {
        generate_extern_fn!($fn_name, $extern_fn, arg1, arg2, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
    };
    // Binary function with caller-specified attributes.
    ("binary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
        generate_extern_fn!($fn_name, $extern_fn, arg1, arg2 $(,$attributes)*);
    };
    // Implementation rule: all other rules forward to this one.
    ($fn_name:ident, $extern_fn:literal $(,$args:ident)* $(,$attributes:literal)*) => {
        // `$extern_fn` is passed to `concat!` directly (rather than through `stringify!`) so that
        // the rendered name and the link target do not contain the literal's quote characters.
        #[doc = concat!("Invokes the [`", $extern_fn, "`](https://en.cppreference.com/w/c/numeric/math/", $extern_fn, ") function.")]
        pub fn $fn_name<'ctx>(
            ctx: &CodeGenContext<'ctx, '_>
            $(,$args: FloatValue<'ctx>)*,
            name: Option<&str>,
        ) -> FloatValue<'ctx> {
            const FN_NAME: &str = $extern_fn;

            let llvm_f64 = ctx.ctx.f64_type();
            $(debug_assert_eq!($args.get_type(), llvm_f64);)*

            // Reuse the declaration if the module already has one; otherwise declare the
            // function and attach the requested function attributes.
            let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
                let fn_type = llvm_f64.fn_type(&[$($args.get_type().into()),*], false);
                let func = ctx.module.add_function(FN_NAME, fn_type, None);
                for attr in [$($attributes),*] {
                    func.add_attribute(
                        AttributeLoc::Function,
                        ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
                    );
                }
                func
            });

            // The call is expected to produce a float value; anything else panics.
            ctx.builder
                .build_call(extern_fn, &[$($args.into()),*], name.unwrap_or_default())
                .map(CallSiteValue::try_as_basic_value)
                .map(|v| v.map_left(BasicValueEnum::into_float_value))
                .map(Either::unwrap_left)
                .unwrap()
        }
    };
}
|
|
|
|
|
|
|
|
generate_extern_fn!("unary", call_tan, "tan");
|
|
|
|
generate_extern_fn!("unary", call_asin, "asin");
|
|
|
|
generate_extern_fn!("unary", call_acos, "acos");
|
|
|
|
generate_extern_fn!("unary", call_atan, "atan");
|
|
|
|
generate_extern_fn!("unary", call_sinh, "sinh");
|
|
|
|
generate_extern_fn!("unary", call_cosh, "cosh");
|
|
|
|
generate_extern_fn!("unary", call_tanh, "tanh");
|
|
|
|
generate_extern_fn!("unary", call_asinh, "asinh");
|
|
|
|
generate_extern_fn!("unary", call_acosh, "acosh");
|
|
|
|
generate_extern_fn!("unary", call_atanh, "atanh");
|
|
|
|
generate_extern_fn!("unary", call_expm1, "expm1");
|
|
|
|
generate_extern_fn!(
|
|
|
|
"unary",
|
|
|
|
call_cbrt,
|
|
|
|
"cbrt",
|
|
|
|
"mustprogress",
|
|
|
|
"nofree",
|
|
|
|
"nosync",
|
|
|
|
"nounwind",
|
|
|
|
"readonly",
|
|
|
|
"willreturn"
|
|
|
|
);
|
|
|
|
generate_extern_fn!("unary", call_erf, "erf", "nounwind");
|
|
|
|
generate_extern_fn!("unary", call_erfc, "erfc", "nounwind");
|
|
|
|
generate_extern_fn!("unary", call_j1, "j1", "nounwind");
|
|
|
|
|
|
|
|
generate_extern_fn!("binary", call_atan2, "atan2");
|
|
|
|
generate_extern_fn!("binary", call_hypot, "hypot", "nounwind");
|
|
|
|
generate_extern_fn!("binary", call_nextafter, "nextafter", "nounwind");
|
2024-04-24 17:40:25 +08:00
|
|
|
|
|
|
|
/// Invokes the [`ldexp`](https://en.cppreference.com/w/c/numeric/math/ldexp) function.
|
|
|
|
pub fn call_ldexp<'ctx>(
|
|
|
|
ctx: &CodeGenContext<'ctx, '_>,
|
|
|
|
arg: FloatValue<'ctx>,
|
|
|
|
exp: IntValue<'ctx>,
|
|
|
|
name: Option<&str>,
|
|
|
|
) -> FloatValue<'ctx> {
|
|
|
|
const FN_NAME: &str = "ldexp";
|
|
|
|
|
|
|
|
let llvm_f64 = ctx.ctx.f64_type();
|
|
|
|
let llvm_i32 = ctx.ctx.i32_type();
|
|
|
|
debug_assert_eq!(arg.get_type(), llvm_f64);
|
|
|
|
debug_assert_eq!(exp.get_type(), llvm_i32);
|
|
|
|
|
|
|
|
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
|
|
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_i32.into()], false);
|
|
|
|
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
|
|
|
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
|
|
|
|
func.add_attribute(
|
|
|
|
AttributeLoc::Function,
|
2024-06-12 14:45:03 +08:00
|
|
|
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
2024-04-24 17:40:25 +08:00
|
|
|
);
|
|
|
|
}
|
|
|
|
|
|
|
|
func
|
|
|
|
});
|
|
|
|
|
|
|
|
ctx.builder
|
|
|
|
.build_call(extern_fn, &[arg.into(), exp.into()], name.unwrap_or_default())
|
|
|
|
.map(CallSiteValue::try_as_basic_value)
|
|
|
|
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
|
|
|
.map(Either::unwrap_left)
|
|
|
|
.unwrap()
|
|
|
|
}
|
2024-07-25 12:16:53 +08:00
|
|
|
|
|
|
|
/// Macro to generate `np_linalg` and `sp_linalg` functions
/// The function takes as input `NDArray` and returns ()
///
/// Arguments:
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$extern_fn:literal`: Name of underlying extern function
/// * (2/3/4): Number of `NDArray` that function takes as input
///
/// Note:
/// The operands and resulting `NDArray` are both passed as input to the function
/// It is the responsibility of caller to ensure that output `NDArray` is properly allocated on stack
/// The function changes the content of the output `NDArray` in-place
macro_rules! generate_linalg_extern_fn {
    // Two `NDArray` parameters.
    ($fn_name:ident, $extern_fn:literal, 2) => {
        generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2);
    };
    // Three `NDArray` parameters.
    ($fn_name:ident, $extern_fn:literal, 3) => {
        generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3);
    };
    // Four `NDArray` parameters.
    ($fn_name:ident, $extern_fn:literal, 4) => {
        generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3, mat4);
    };
    // Implementation rule: all other rules forward to this one.
    ($fn_name:ident, $extern_fn:literal $(,$input_matrix:ident)*) => {
        // `$extern_fn` is passed to `concat!` directly so that the name is rendered without the
        // literal's quote characters, inside a properly closed code span.
        #[doc = concat!("Invokes the linalg `", $extern_fn, "` function.")]
        pub fn $fn_name<'ctx>(
            ctx: &mut CodeGenContext<'ctx, '_>
            $(,$input_matrix: BasicValueEnum<'ctx>)*,
            name: Option<&str>,
        ){
            const FN_NAME: &str = $extern_fn;

            // Reuse the declaration if the module already has one; otherwise declare a
            // void-returning function whose parameter types mirror the supplied values.
            let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
                let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false);

                let func = ctx.module.add_function(FN_NAME, fn_type, None);
                // NOTE(review): `writeonly` presumably tells LLVM that the callee never reads
                // memory, yet these routines take input `NDArray`s - confirm this attribute is
                // intended here.
                for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
                    func.add_attribute(
                        AttributeLoc::Function,
                        ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
                    );
                }
                func
            });

            // The callee returns nothing; only a failure to build the call is checked.
            ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap();
        }
    };
}
|
|
|
|
|
|
|
|
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
|
|
|
|
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
|
|
|
|
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
|
|
|
|
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
|
|
|
|
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
|
2024-07-31 18:02:54 +08:00
|
|
|
generate_linalg_extern_fn!(call_np_linalg_matrix_power, "np_linalg_matrix_power", 3);
|
|
|
|
generate_linalg_extern_fn!(call_np_linalg_det, "np_linalg_det", 2);
|
2024-07-25 12:16:53 +08:00
|
|
|
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
|
|
|
|
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
|
|
|
|
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);
|