core: revise nac3core/codegen/builtin_fns.rs numpy macros

This commit is contained in:
lyken 2024-06-20 10:26:15 +08:00
parent ef2502e7b4
commit 777077a723
1 changed files with 231 additions and 67 deletions

View File

@ -1044,25 +1044,25 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
}) })
} }
/// Helper function to create a builtin that takes in either an ndarray or a value and returns a value of the same structure. /// Helper function to create a built-in elementwise unary numpy function that takes in either an ndarray or a scalar.
/// (e.g, `float` to `float`, `float` to `int`, `ndarray<float>` to `ndarray<float>`, or even `ndarray<float>` to `ndarray<int>`).
/// ///
/// * `(arg_ty, arg_val)`: The [`Type`] and llvm value of the input argument. /// * `(arg_ty, arg_val)`: The [`Type`] and llvm value of the input argument.
/// * `fn_name`: The name of the function, only used when throwing an error with [`unsupported_type`] /// * `fn_name`: The name of the function, only used when throwing an error with [`unsupported_type`]
/// * `get_ret_elem_type`: A function that takes in the input element [`Type`], and returns the correct return [`Type`]. /// * `get_ret_elem_type`: A function that takes in the input scalar [`Type`], and returns the function's return scalar [`Type`].
/// Return a constant [`Type`] here if the return type does not depend on the input type. /// Return a constant [`Type`] here if the return type does not depend on the input type.
/// * `on_elem`: The function to be called when the input argument is not an ndarray. Returns [`Option::None`] /// * `on_scalar`: The function that acts on the scalars of the input. Returns [`Option::None`]
/// if the element type & value are faulty and should panic with [`unsupported_type`]. /// if the scalar type & value are faulty and should panic with [`unsupported_type`].
fn helper_call_numpy<'ctx, OnElemFn, RetElemFn, G: CodeGenerator + ?Sized>( fn helper_call_numpy_unary_elementwise<'ctx, OnScalarFn, RetElemFn, G>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
(arg_ty, arg_val): (Type, BasicValueEnum<'ctx>), (arg_ty, arg_val): (Type, BasicValueEnum<'ctx>),
fn_name: &str, fn_name: &str,
get_ret_elem_type: &RetElemFn, get_ret_elem_type: &RetElemFn,
on_elem: &OnElemFn, on_scalar: &OnScalarFn,
) -> Result<BasicValueEnum<'ctx>, String> ) -> Result<BasicValueEnum<'ctx>, String>
where where
OnElemFn: Fn( G: CodeGenerator + ?Sized,
OnScalarFn: Fn(
&mut G, &mut G,
&mut CodeGenContext<'ctx, '_>, &mut CodeGenContext<'ctx, '_>,
Type, Type,
@ -1085,20 +1085,20 @@ where
None, None,
NDArrayValue::from_ptr_val(x, llvm_usize, None), NDArrayValue::from_ptr_val(x, llvm_usize, None),
|generator, ctx, elem_val| { |generator, ctx, elem_val| {
helper_call_numpy( helper_call_numpy_unary_elementwise(
generator, generator,
ctx, ctx,
(arg_elem_ty, elem_val), (arg_elem_ty, elem_val),
fn_name, fn_name,
get_ret_elem_type, get_ret_elem_type,
on_elem, on_scalar,
) )
}, },
)?; )?;
ndarray.as_base_value().into() ndarray.as_base_value().into()
} }
_ => on_elem(generator, ctx, arg_ty, arg_val) _ => on_scalar(generator, ctx, arg_ty, arg_val)
.unwrap_or_else(|| unsupported_type(ctx, fn_name, &[arg_ty])), .unwrap_or_else(|| unsupported_type(ctx, fn_name, &[arg_ty])),
}; };
@ -1110,12 +1110,12 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
arg: (Type, BasicValueEnum<'ctx>), arg: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
let fn_name: &str = "abs"; const FN_NAME: &str = "abs";
helper_call_numpy( helper_call_numpy_unary_elementwise(
generator, generator,
ctx, ctx,
arg, arg,
fn_name, FN_NAME,
&|_ctx, elem_ty| elem_ty, &|_ctx, elem_ty| elem_ty,
&|_generator, ctx, val_ty, val| match val { &|_generator, ctx, val_ty, val| match val {
BasicValueEnum::IntValue(n) => Some({ BasicValueEnum::IntValue(n) => Some({
@ -1137,7 +1137,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
n, n,
ctx.ctx.bool_type().const_zero(), ctx.ctx.bool_type().const_zero(),
Some(fn_name), Some(FN_NAME),
) )
.into() .into()
} else { } else {
@ -1148,7 +1148,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::FloatValue(n) => Some({ BasicValueEnum::FloatValue(n) => Some({
debug_assert!(ctx.unifier.unioned(val_ty, ctx.primitives.float)); debug_assert!(ctx.unifier.unioned(val_ty, ctx.primitives.float));
llvm_intrinsics::call_float_fabs(ctx, n, Some(fn_name)).into() llvm_intrinsics::call_float_fabs(ctx, n, Some(FN_NAME)).into()
}), }),
_ => None, _ => None,
@ -1156,40 +1156,88 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
) )
} }
/// Helper macro. Used so we don't have to keep typing out the type signature of the function. /// Macro to conveniently generate numpy functions with [`helper_call_numpy_unary_elementwise`].
macro_rules! call_numpy { ///
($name:ident, $fn_name:literal, $get_ret_elem_type:expr, $on_elem:expr) => { /// Arguments:
/// * `$name:ident`: The identifier of the rust function to be generated.
/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`]
/// * `$get_ret_elem_type:expr`: To be passed to the `get_ret_elem_type` parameter of [`helper_call_numpy_unary_elementwise`].
/// But there is no need to make it a reference.
/// * `$on_scalar:expr`: To be passed to the `on_scalar` parameter of [`helper_call_numpy_unary_elementwise`].
/// But there is no need to make it a reference.
macro_rules! create_helper_call_numpy_unary_elementwise {
($name:ident, $fn_name:literal, $get_ret_elem_type:expr, $on_scalar:expr) => {
#[allow(clippy::redundant_closure_call)] #[allow(clippy::redundant_closure_call)]
pub fn $name<'ctx, G: CodeGenerator + ?Sized>( pub fn $name<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
arg: (Type, BasicValueEnum<'ctx>), arg: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
helper_call_numpy(generator, ctx, arg, $fn_name, &$get_ret_elem_type, &$on_elem) helper_call_numpy_unary_elementwise(
generator,
ctx,
arg,
$fn_name,
&$get_ret_elem_type,
&$on_scalar,
)
} }
}; };
} }
/// Helper macro, only used by [`call_numpy_isnan`] and [`call_numpy_isinf`] /// A specialized version of [`create_helper_call_numpy_unary_elementwise`] to generate functions that takes in float and returns boolean (as an `i8`) elementwise.
macro_rules! call_numpy_ret_bool { ///
($name:ident, $fn_name:literal, $elem_call:expr) => { /// Arguments:
call_numpy!($name, $fn_name, |ctx, _| ctx.primitives.bool, |generator, ctx, n_ty, val| { /// * `$name:ident`: The identifier of the rust function to be generated.
match val { /// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`].
BasicValueEnum::FloatValue(n) => { /// * `$on_scalar:expr`: The closure (see below for its type) that acts on float scalar values and returns
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); /// the boolean results of LLVM type `i1`. The returned `i1` value will be converted into an `i8`.
/// ```rust
/// // Type of `$on_scalar:expr`
/// fn on_scalar<'ctx, G: CodeGenerator + ?Sized>(
/// generator: &mut G,
/// ctx: &mut CodeGenContext<'ctx, '_>,
/// arg: FloatValue<'ctx>
/// ) -> IntValue<'ctx> // of LLVM type `i1`
/// ```
macro_rules! create_helper_call_numpy_unary_elementwise_float_to_bool {
($name:ident, $fn_name:literal, $on_scalar:expr) => {
create_helper_call_numpy_unary_elementwise!(
$name,
$fn_name,
|ctx, _| ctx.primitives.bool,
|generator, ctx, n_ty, val| {
match val {
BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
let ret = $elem_call(generator, ctx, n); let ret = $on_scalar(generator, ctx, n);
Some(generator.bool_to_i8(ctx, ret).into()) Some(generator.bool_to_i8(ctx, ret).into())
}
_ => None,
} }
_ => None,
} }
}); );
}; };
} }
macro_rules! call_numpy_float { /// A specialized version of [`create_helper_call_numpy_unary_elementwise`] to generate functions that takes in float and returns float elementwise.
///
/// Arguments:
/// * `$name:ident`: The identifier of the rust function to be generated.
/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`].
/// * `$on_scalar:expr`: The closure (see below for its type) that acts on float scalar values and returns float results.
/// ```rust
/// // Type of `$on_scalar:expr`
/// fn on_scalar<'ctx, G: CodeGenerator + ?Sized>(
/// generator: &mut G,
/// ctx: &mut CodeGenContext<'ctx, '_>,
/// arg: FloatValue<'ctx>
/// ) -> FloatValue<'ctx>
/// ```
macro_rules! create_helper_call_numpy_unary_elementwise_float_to_float {
($name:ident, $fn_name:literal, $elem_call:expr) => { ($name:ident, $fn_name:literal, $elem_call:expr) => {
call_numpy!( create_helper_call_numpy_unary_elementwise!(
$name, $name,
$fn_name, $fn_name,
|ctx, _| ctx.primitives.float, |ctx, _| ctx.primitives.float,
@ -1207,49 +1255,165 @@ macro_rules! call_numpy_float {
}; };
} }
call_numpy_ret_bool!(call_numpy_isnan, "np_isnan", irrt::call_isnan); create_helper_call_numpy_unary_elementwise_float_to_bool!(
call_numpy_ret_bool!(call_numpy_isinf, "np_isinf", irrt::call_isinf); call_numpy_isnan,
"np_isnan",
irrt::call_isnan
);
create_helper_call_numpy_unary_elementwise_float_to_bool!(
call_numpy_isinf,
"np_isinf",
irrt::call_isinf
);
call_numpy_float!(call_numpy_sin, "np_sin", llvm_intrinsics::call_float_sin); create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_float!(call_numpy_cos, "np_cos", llvm_intrinsics::call_float_cos); call_numpy_sin,
call_numpy_float!(call_numpy_tan, "np_tan", extern_fns::call_tan); "np_sin",
llvm_intrinsics::call_float_sin
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_cos,
"np_cos",
llvm_intrinsics::call_float_cos
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_tan,
"np_tan",
extern_fns::call_tan
);
call_numpy_float!(call_numpy_arcsin, "np_arcsin", extern_fns::call_asin); create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_float!(call_numpy_arccos, "np_arccos", extern_fns::call_acos); call_numpy_arcsin,
call_numpy_float!(call_numpy_arctan, "np_arctan", extern_fns::call_atan); "np_arcsin",
extern_fns::call_asin
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_arccos,
"np_arccos",
extern_fns::call_acos
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_arctan,
"np_arctan",
extern_fns::call_atan
);
call_numpy_float!(call_numpy_sinh, "np_sinh", extern_fns::call_sinh); create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_float!(call_numpy_cosh, "np_cosh", extern_fns::call_cosh); call_numpy_sinh,
call_numpy_float!(call_numpy_tanh, "np_tanh", extern_fns::call_tanh); "np_sinh",
extern_fns::call_sinh
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_cosh,
"np_cosh",
extern_fns::call_cosh
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_tanh,
"np_tanh",
extern_fns::call_tanh
);
call_numpy_float!(call_numpy_arcsinh, "np_arcsinh", extern_fns::call_asinh); create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_float!(call_numpy_arccosh, "np_arccosh", extern_fns::call_acosh); call_numpy_arcsinh,
call_numpy_float!(call_numpy_arctanh, "np_arctanh", extern_fns::call_atanh); "np_arcsinh",
extern_fns::call_asinh
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_arccosh,
"np_arccosh",
extern_fns::call_acosh
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_arctanh,
"np_arctanh",
extern_fns::call_atanh
);
call_numpy_float!(call_numpy_exp, "np_exp", llvm_intrinsics::call_float_exp); create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_float!(call_numpy_exp2, "np_exp2", llvm_intrinsics::call_float_exp2); call_numpy_exp,
call_numpy_float!(call_numpy_expm1, "np_expm1", extern_fns::call_expm1); "np_exp",
llvm_intrinsics::call_float_exp
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_exp2,
"np_exp2",
llvm_intrinsics::call_float_exp2
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_expm1,
"np_expm1",
extern_fns::call_expm1
);
call_numpy_float!(call_numpy_log, "np_log", llvm_intrinsics::call_float_log); create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_float!(call_numpy_log2, "np_log2", llvm_intrinsics::call_float_log2); call_numpy_log,
call_numpy_float!(call_numpy_log10, "np_log10", llvm_intrinsics::call_float_log10); "np_log",
llvm_intrinsics::call_float_log
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_log2,
"np_log2",
llvm_intrinsics::call_float_log2
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_log10,
"np_log10",
llvm_intrinsics::call_float_log10
);
call_numpy_float!(call_numpy_sqrt, "np_sqrt", llvm_intrinsics::call_float_sqrt); create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_float!(call_numpy_cbrt, "np_cbrt", extern_fns::call_cbrt); call_numpy_sqrt,
"np_sqrt",
llvm_intrinsics::call_float_sqrt
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_cbrt,
"np_cbrt",
extern_fns::call_cbrt
);
call_numpy_float!(call_numpy_fabs, "np_fabs", llvm_intrinsics::call_float_fabs); create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_float!(call_numpy_rint, "np_rint", llvm_intrinsics::call_float_roundeven); call_numpy_fabs,
"np_fabs",
llvm_intrinsics::call_float_fabs
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_rint,
"np_rint",
llvm_intrinsics::call_float_roundeven
);
call_numpy_float!(call_scipy_special_erf, "sp_spec_erf", extern_fns::call_erf); create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_float!(call_scipy_special_erfc, "sp_spec_erfc", extern_fns::call_erfc); call_scipy_special_erf,
call_numpy_float!(call_scipy_special_gamma, "sp_spec_gamma", |ctx, val, _| irrt::call_gamma( "sp_spec_erf",
ctx, val extern_fns::call_erf
)); );
call_numpy_float!(call_scipy_special_gammaln, "sp_spec_gammaln", |ctx, val, _| irrt::call_gammaln( create_helper_call_numpy_unary_elementwise_float_to_float!(
ctx, val call_scipy_special_erfc,
)); "sp_spec_erfc",
call_numpy_float!(call_scipy_special_j0, "sp_spec_j0", |ctx, val, _| irrt::call_j0(ctx, val)); extern_fns::call_erfc
call_numpy_float!(call_scipy_special_j1, "sp_spec_j1", extern_fns::call_j1); );
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_scipy_special_gamma,
"sp_spec_gamma",
|ctx, val, _| irrt::call_gamma(ctx, val)
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_scipy_special_gammaln,
"sp_spec_gammaln",
|ctx, val, _| irrt::call_gammaln(ctx, val)
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_scipy_special_j0,
"sp_spec_j0",
|ctx, val, _| irrt::call_j0(ctx, val)
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_scipy_special_j1,
"sp_spec_j1",
extern_fns::call_j1
);
/// Invokes the `np_arctan2` builtin function. /// Invokes the `np_arctan2` builtin function.
pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(