core: revise nac3core/codegen/builtin_fns.rs numpy macros

This commit is contained in:
lyken 2024-06-20 10:26:15 +08:00
parent ef2502e7b4
commit 777077a723
1 changed files with 231 additions and 67 deletions

View File

@ -1044,25 +1044,25 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
})
}
/// Helper function to create a builtin that takes in either an ndarray or a value and returns a value of the same structure.
/// (e.g, `float` to `float`, `float` to `int`, `ndarray<float>` to `ndarray<float>`, or even `ndarray<float>` to `ndarray<int>`).
/// Helper function to create a built-in elementwise unary numpy function that takes in either an ndarray or a scalar.
///
/// * `(arg_ty, arg_val)`: The [`Type`] and llvm value of the input argument.
/// * `fn_name`: The name of the function, only used when throwing an error with [`unsupported_type`]
/// * `get_ret_elem_type`: A function that takes in the input element [`Type`], and returns the correct return [`Type`].
/// * `get_ret_elem_type`: A function that takes in the input scalar [`Type`], and returns the function's return scalar [`Type`].
/// Return a constant [`Type`] here if the return type does not depend on the input type.
/// * `on_elem`: The function to be called when the input argument is not an ndarray. Returns [`Option::None`]
/// if the element type & value are faulty and should panic with [`unsupported_type`].
fn helper_call_numpy<'ctx, OnElemFn, RetElemFn, G: CodeGenerator + ?Sized>(
/// * `on_scalar`: The function that acts on the scalars of the input. Returns [`Option::None`]
/// if the scalar type & value are faulty and should panic with [`unsupported_type`].
fn helper_call_numpy_unary_elementwise<'ctx, OnScalarFn, RetElemFn, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
(arg_ty, arg_val): (Type, BasicValueEnum<'ctx>),
fn_name: &str,
get_ret_elem_type: &RetElemFn,
on_elem: &OnElemFn,
on_scalar: &OnScalarFn,
) -> Result<BasicValueEnum<'ctx>, String>
where
OnElemFn: Fn(
G: CodeGenerator + ?Sized,
OnScalarFn: Fn(
&mut G,
&mut CodeGenContext<'ctx, '_>,
Type,
@ -1085,20 +1085,20 @@ where
None,
NDArrayValue::from_ptr_val(x, llvm_usize, None),
|generator, ctx, elem_val| {
helper_call_numpy(
helper_call_numpy_unary_elementwise(
generator,
ctx,
(arg_elem_ty, elem_val),
fn_name,
get_ret_elem_type,
on_elem,
on_scalar,
)
},
)?;
ndarray.as_base_value().into()
}
_ => on_elem(generator, ctx, arg_ty, arg_val)
_ => on_scalar(generator, ctx, arg_ty, arg_val)
.unwrap_or_else(|| unsupported_type(ctx, fn_name, &[arg_ty])),
};
@ -1110,12 +1110,12 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
ctx: &mut CodeGenContext<'ctx, '_>,
arg: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
let fn_name: &str = "abs";
helper_call_numpy(
const FN_NAME: &str = "abs";
helper_call_numpy_unary_elementwise(
generator,
ctx,
arg,
fn_name,
FN_NAME,
&|_ctx, elem_ty| elem_ty,
&|_generator, ctx, val_ty, val| match val {
BasicValueEnum::IntValue(n) => Some({
@ -1137,7 +1137,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
ctx,
n,
ctx.ctx.bool_type().const_zero(),
Some(fn_name),
Some(FN_NAME),
)
.into()
} else {
@ -1148,7 +1148,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::FloatValue(n) => Some({
debug_assert!(ctx.unifier.unioned(val_ty, ctx.primitives.float));
llvm_intrinsics::call_float_fabs(ctx, n, Some(fn_name)).into()
llvm_intrinsics::call_float_fabs(ctx, n, Some(FN_NAME)).into()
}),
_ => None,
@ -1156,40 +1156,88 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
)
}
/// Helper macro. Used so we don't have to keep typing out the type signature of the function.
macro_rules! call_numpy {
($name:ident, $fn_name:literal, $get_ret_elem_type:expr, $on_elem:expr) => {
/// Macro to conveniently generate numpy functions with [`helper_call_numpy_unary_elementwise`].
///
/// Arguments:
/// * `$name:ident`: The identifier of the rust function to be generated.
/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`]
/// * `$get_ret_elem_type:expr`: To be passed to the `get_ret_elem_type` parameter of [`helper_call_numpy_unary_elementwise`].
/// But there is no need to make it a reference.
/// * `$on_scalar:expr`: To be passed to the `on_scalar` parameter of [`helper_call_numpy_unary_elementwise`].
/// But there is no need to make it a reference.
macro_rules! create_helper_call_numpy_unary_elementwise {
($name:ident, $fn_name:literal, $get_ret_elem_type:expr, $on_scalar:expr) => {
#[allow(clippy::redundant_closure_call)]
pub fn $name<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
arg: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
helper_call_numpy(generator, ctx, arg, $fn_name, &$get_ret_elem_type, &$on_elem)
helper_call_numpy_unary_elementwise(
generator,
ctx,
arg,
$fn_name,
&$get_ret_elem_type,
&$on_scalar,
)
}
};
}
/// Helper macro, only used by [`call_numpy_isnan`] and [`call_numpy_isinf`]
macro_rules! call_numpy_ret_bool {
($name:ident, $fn_name:literal, $elem_call:expr) => {
call_numpy!($name, $fn_name, |ctx, _| ctx.primitives.bool, |generator, ctx, n_ty, val| {
match val {
BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
/// A specialized version of [`create_helper_call_numpy_unary_elementwise`] to generate functions that takes in float and returns boolean (as an `i8`) elementwise.
///
/// Arguments:
/// * `$name:ident`: The identifier of the rust function to be generated.
/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`].
/// * `$on_scalar:expr`: The closure (see below for its type) that acts on float scalar values and returns
/// the boolean results of LLVM type `i1`. The returned `i1` value will be converted into an `i8`.
/// ```rust
/// // Type of `$on_scalar:expr`
/// fn on_scalar<'ctx, G: CodeGenerator + ?Sized>(
/// generator: &mut G,
/// ctx: &mut CodeGenContext<'ctx, '_>,
/// arg: FloatValue<'ctx>
/// ) -> IntValue<'ctx> // of LLVM type `i1`
/// ```
macro_rules! create_helper_call_numpy_unary_elementwise_float_to_bool {
($name:ident, $fn_name:literal, $on_scalar:expr) => {
create_helper_call_numpy_unary_elementwise!(
$name,
$fn_name,
|ctx, _| ctx.primitives.bool,
|generator, ctx, n_ty, val| {
match val {
BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
let ret = $elem_call(generator, ctx, n);
Some(generator.bool_to_i8(ctx, ret).into())
let ret = $on_scalar(generator, ctx, n);
Some(generator.bool_to_i8(ctx, ret).into())
}
_ => None,
}
_ => None,
}
});
);
};
}
macro_rules! call_numpy_float {
/// A specialized version of [`create_helper_call_numpy_unary_elementwise`] to generate functions that takes in float and returns float elementwise.
///
/// Arguments:
/// * `$name:ident`: The identifier of the rust function to be generated.
/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`].
/// * `$on_scalar:expr`: The closure (see below for its type) that acts on float scalar values and returns float results.
/// ```rust
/// // Type of `$on_scalar:expr`
/// fn on_scalar<'ctx, G: CodeGenerator + ?Sized>(
/// generator: &mut G,
/// ctx: &mut CodeGenContext<'ctx, '_>,
/// arg: FloatValue<'ctx>
/// ) -> FloatValue<'ctx>
/// ```
macro_rules! create_helper_call_numpy_unary_elementwise_float_to_float {
($name:ident, $fn_name:literal, $elem_call:expr) => {
call_numpy!(
create_helper_call_numpy_unary_elementwise!(
$name,
$fn_name,
|ctx, _| ctx.primitives.float,
@ -1207,49 +1255,165 @@ macro_rules! call_numpy_float {
};
}
call_numpy_ret_bool!(call_numpy_isnan, "np_isnan", irrt::call_isnan);
call_numpy_ret_bool!(call_numpy_isinf, "np_isinf", irrt::call_isinf);
create_helper_call_numpy_unary_elementwise_float_to_bool!(
call_numpy_isnan,
"np_isnan",
irrt::call_isnan
);
create_helper_call_numpy_unary_elementwise_float_to_bool!(
call_numpy_isinf,
"np_isinf",
irrt::call_isinf
);
call_numpy_float!(call_numpy_sin, "np_sin", llvm_intrinsics::call_float_sin);
call_numpy_float!(call_numpy_cos, "np_cos", llvm_intrinsics::call_float_cos);
call_numpy_float!(call_numpy_tan, "np_tan", extern_fns::call_tan);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_sin,
"np_sin",
llvm_intrinsics::call_float_sin
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_cos,
"np_cos",
llvm_intrinsics::call_float_cos
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_tan,
"np_tan",
extern_fns::call_tan
);
call_numpy_float!(call_numpy_arcsin, "np_arcsin", extern_fns::call_asin);
call_numpy_float!(call_numpy_arccos, "np_arccos", extern_fns::call_acos);
call_numpy_float!(call_numpy_arctan, "np_arctan", extern_fns::call_atan);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_arcsin,
"np_arcsin",
extern_fns::call_asin
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_arccos,
"np_arccos",
extern_fns::call_acos
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_arctan,
"np_arctan",
extern_fns::call_atan
);
call_numpy_float!(call_numpy_sinh, "np_sinh", extern_fns::call_sinh);
call_numpy_float!(call_numpy_cosh, "np_cosh", extern_fns::call_cosh);
call_numpy_float!(call_numpy_tanh, "np_tanh", extern_fns::call_tanh);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_sinh,
"np_sinh",
extern_fns::call_sinh
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_cosh,
"np_cosh",
extern_fns::call_cosh
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_tanh,
"np_tanh",
extern_fns::call_tanh
);
call_numpy_float!(call_numpy_arcsinh, "np_arcsinh", extern_fns::call_asinh);
call_numpy_float!(call_numpy_arccosh, "np_arccosh", extern_fns::call_acosh);
call_numpy_float!(call_numpy_arctanh, "np_arctanh", extern_fns::call_atanh);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_arcsinh,
"np_arcsinh",
extern_fns::call_asinh
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_arccosh,
"np_arccosh",
extern_fns::call_acosh
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_arctanh,
"np_arctanh",
extern_fns::call_atanh
);
call_numpy_float!(call_numpy_exp, "np_exp", llvm_intrinsics::call_float_exp);
call_numpy_float!(call_numpy_exp2, "np_exp2", llvm_intrinsics::call_float_exp2);
call_numpy_float!(call_numpy_expm1, "np_expm1", extern_fns::call_expm1);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_exp,
"np_exp",
llvm_intrinsics::call_float_exp
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_exp2,
"np_exp2",
llvm_intrinsics::call_float_exp2
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_expm1,
"np_expm1",
extern_fns::call_expm1
);
call_numpy_float!(call_numpy_log, "np_log", llvm_intrinsics::call_float_log);
call_numpy_float!(call_numpy_log2, "np_log2", llvm_intrinsics::call_float_log2);
call_numpy_float!(call_numpy_log10, "np_log10", llvm_intrinsics::call_float_log10);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_log,
"np_log",
llvm_intrinsics::call_float_log
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_log2,
"np_log2",
llvm_intrinsics::call_float_log2
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_log10,
"np_log10",
llvm_intrinsics::call_float_log10
);
call_numpy_float!(call_numpy_sqrt, "np_sqrt", llvm_intrinsics::call_float_sqrt);
call_numpy_float!(call_numpy_cbrt, "np_cbrt", extern_fns::call_cbrt);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_sqrt,
"np_sqrt",
llvm_intrinsics::call_float_sqrt
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_cbrt,
"np_cbrt",
extern_fns::call_cbrt
);
call_numpy_float!(call_numpy_fabs, "np_fabs", llvm_intrinsics::call_float_fabs);
call_numpy_float!(call_numpy_rint, "np_rint", llvm_intrinsics::call_float_roundeven);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_fabs,
"np_fabs",
llvm_intrinsics::call_float_fabs
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_rint,
"np_rint",
llvm_intrinsics::call_float_roundeven
);
call_numpy_float!(call_scipy_special_erf, "sp_spec_erf", extern_fns::call_erf);
call_numpy_float!(call_scipy_special_erfc, "sp_spec_erfc", extern_fns::call_erfc);
call_numpy_float!(call_scipy_special_gamma, "sp_spec_gamma", |ctx, val, _| irrt::call_gamma(
ctx, val
));
call_numpy_float!(call_scipy_special_gammaln, "sp_spec_gammaln", |ctx, val, _| irrt::call_gammaln(
ctx, val
));
call_numpy_float!(call_scipy_special_j0, "sp_spec_j0", |ctx, val, _| irrt::call_j0(ctx, val));
call_numpy_float!(call_scipy_special_j1, "sp_spec_j1", extern_fns::call_j1);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_scipy_special_erf,
"sp_spec_erf",
extern_fns::call_erf
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_scipy_special_erfc,
"sp_spec_erfc",
extern_fns::call_erfc
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_scipy_special_gamma,
"sp_spec_gamma",
|ctx, val, _| irrt::call_gamma(ctx, val)
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_scipy_special_gammaln,
"sp_spec_gammaln",
|ctx, val, _| irrt::call_gammaln(ctx, val)
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_scipy_special_j0,
"sp_spec_j0",
|ctx, val, _| irrt::call_j0(ctx, val)
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_scipy_special_j1,
"sp_spec_j1",
extern_fns::call_j1
);
/// Invokes the `np_arctan2` builtin function.
pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(