From 777077a72320d44c837971b3c435b4e44830d0b7 Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 20 Jun 2024 10:26:15 +0800 Subject: [PATCH] core: revise nac3core/codegen/builtin_fns.rs numpy macros --- nac3core/src/codegen/builtin_fns.rs | 298 +++++++++++++++++++++------- 1 file changed, 231 insertions(+), 67 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index a56da144..7ab695c4 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1044,25 +1044,25 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( }) } -/// Helper function to create a builtin that takes in either an ndarray or a value and returns a value of the same structure. -/// (e.g, `float` to `float`, `float` to `int`, `ndarray` to `ndarray`, or even `ndarray` to `ndarray`). +/// Helper function to create a built-in elementwise unary numpy function that takes in either an ndarray or a scalar. /// /// * `(arg_ty, arg_val)`: The [`Type`] and llvm value of the input argument. /// * `fn_name`: The name of the function, only used when throwing an error with [`unsupported_type`] -/// * `get_ret_elem_type`: A function that takes in the input element [`Type`], and returns the correct return [`Type`]. +/// * `get_ret_elem_type`: A function that takes in the input scalar [`Type`], and returns the function's return scalar [`Type`]. /// Return a constant [`Type`] here if the return type does not depend on the input type. -/// * `on_elem`: The function to be called when the input argument is not an ndarray. Returns [`Option::None`] -/// if the element type & value are faulty and should panic with [`unsupported_type`]. -fn helper_call_numpy<'ctx, OnElemFn, RetElemFn, G: CodeGenerator + ?Sized>( +/// * `on_scalar`: The function that acts on the scalars of the input. Returns [`Option::None`] +/// if the scalar type & value are faulty and should panic with [`unsupported_type`]. +fn helper_call_numpy_unary_elementwise<'ctx, OnScalarFn, RetElemFn, G>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, (arg_ty, arg_val): (Type, BasicValueEnum<'ctx>), fn_name: &str, get_ret_elem_type: &RetElemFn, - on_elem: &OnElemFn, + on_scalar: &OnScalarFn, ) -> Result, String> where - OnElemFn: Fn( + G: CodeGenerator + ?Sized, + OnScalarFn: Fn( &mut G, &mut CodeGenContext<'ctx, '_>, Type, @@ -1085,20 +1085,20 @@ where None, NDArrayValue::from_ptr_val(x, llvm_usize, None), |generator, ctx, elem_val| { - helper_call_numpy( + helper_call_numpy_unary_elementwise( generator, ctx, (arg_elem_ty, elem_val), fn_name, get_ret_elem_type, - on_elem, + on_scalar, ) }, )?; ndarray.as_base_value().into() } - _ => on_elem(generator, ctx, arg_ty, arg_val) + _ => on_scalar(generator, ctx, arg_ty, arg_val) .unwrap_or_else(|| unsupported_type(ctx, fn_name, &[arg_ty])), }; @@ -1110,12 +1110,12 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, arg: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - let fn_name: &str = "abs"; - helper_call_numpy( + const FN_NAME: &str = "abs"; + helper_call_numpy_unary_elementwise( generator, ctx, arg, - fn_name, + FN_NAME, &|_ctx, elem_ty| elem_ty, &|_generator, ctx, val_ty, val| match val { BasicValueEnum::IntValue(n) => Some({ @@ -1137,7 +1137,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( ctx, n, ctx.ctx.bool_type().const_zero(), - Some(fn_name), + Some(FN_NAME), ) .into() } else { @@ -1148,7 +1148,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::FloatValue(n) => Some({ debug_assert!(ctx.unifier.unioned(val_ty, ctx.primitives.float)); - llvm_intrinsics::call_float_fabs(ctx, n, Some(fn_name)).into() + llvm_intrinsics::call_float_fabs(ctx, n, Some(FN_NAME)).into() }), _ => None, @@ -1156,40 +1156,88 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( ) } -/// Helper macro. Used so we don't have to keep typing out the type signature of the function. -macro_rules! call_numpy { - ($name:ident, $fn_name:literal, $get_ret_elem_type:expr, $on_elem:expr) => { +/// Macro to conveniently generate numpy functions with [`helper_call_numpy_unary_elementwise`]. +/// +/// Arguments: +/// * `$name:ident`: The identifier of the rust function to be generated. +/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`] +/// * `$get_ret_elem_type:expr`: To be passed to the `get_ret_elem_type` parameter of [`helper_call_numpy_unary_elementwise`]. +/// But there is no need to make it a reference. +/// * `$on_scalar:expr`: To be passed to the `on_scalar` parameter of [`helper_call_numpy_unary_elementwise`]. +/// But there is no need to make it a reference. +macro_rules! create_helper_call_numpy_unary_elementwise { + ($name:ident, $fn_name:literal, $get_ret_elem_type:expr, $on_scalar:expr) => { #[allow(clippy::redundant_closure_call)] pub fn $name<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, arg: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - helper_call_numpy(generator, ctx, arg, $fn_name, &$get_ret_elem_type, &$on_elem) + helper_call_numpy_unary_elementwise( + generator, + ctx, + arg, + $fn_name, + &$get_ret_elem_type, + &$on_scalar, + ) } }; } -/// Helper macro, only used by [`call_numpy_isnan`] and [`call_numpy_isinf`] -macro_rules! call_numpy_ret_bool { - ($name:ident, $fn_name:literal, $elem_call:expr) => { - call_numpy!($name, $fn_name, |ctx, _| ctx.primitives.bool, |generator, ctx, n_ty, val| { - match val { - BasicValueEnum::FloatValue(n) => { - debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); +/// A specialized version of [`create_helper_call_numpy_unary_elementwise`] to generate functions that takes in float and returns boolean (as an `i8`) elementwise. +/// +/// Arguments: +/// * `$name:ident`: The identifier of the rust function to be generated. +/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`]. +/// * `$on_scalar:expr`: The closure (see below for its type) that acts on float scalar values and returns +/// the boolean results of LLVM type `i1`. The returned `i1` value will be converted into an `i8`. +/// ```rust +/// // Type of `$on_scalar:expr` +/// fn on_scalar<'ctx, G: CodeGenerator + ?Sized>( +/// generator: &mut G, +/// ctx: &mut CodeGenContext<'ctx, '_>, +/// arg: FloatValue<'ctx> +/// ) -> IntValue<'ctx> // of LLVM type `i1` +/// ``` +macro_rules! create_helper_call_numpy_unary_elementwise_float_to_bool { + ($name:ident, $fn_name:literal, $on_scalar:expr) => { + create_helper_call_numpy_unary_elementwise!( + $name, + $fn_name, + |ctx, _| ctx.primitives.bool, + |generator, ctx, n_ty, val| { + match val { + BasicValueEnum::FloatValue(n) => { + debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); - let ret = $elem_call(generator, ctx, n); - Some(generator.bool_to_i8(ctx, ret).into()) + let ret = $on_scalar(generator, ctx, n); + Some(generator.bool_to_i8(ctx, ret).into()) + } + _ => None, } - _ => None, } - }); + ); }; } -macro_rules! call_numpy_float { +/// A specialized version of [`create_helper_call_numpy_unary_elementwise`] to generate functions that takes in float and returns float elementwise. +/// +/// Arguments: +/// * `$name:ident`: The identifier of the rust function to be generated. +/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`]. +/// * `$on_scalar:expr`: The closure (see below for its type) that acts on float scalar values and returns float results. +/// ```rust +/// // Type of `$on_scalar:expr` +/// fn on_scalar<'ctx, G: CodeGenerator + ?Sized>( +/// generator: &mut G, +/// ctx: &mut CodeGenContext<'ctx, '_>, +/// arg: FloatValue<'ctx> +/// ) -> FloatValue<'ctx> +/// ``` +macro_rules! create_helper_call_numpy_unary_elementwise_float_to_float { ($name:ident, $fn_name:literal, $elem_call:expr) => { - call_numpy!( + create_helper_call_numpy_unary_elementwise!( $name, $fn_name, |ctx, _| ctx.primitives.float, @@ -1207,49 +1255,165 @@ macro_rules! call_numpy_float { }; } -call_numpy_ret_bool!(call_numpy_isnan, "np_isnan", irrt::call_isnan); -call_numpy_ret_bool!(call_numpy_isinf, "np_isinf", irrt::call_isinf); +create_helper_call_numpy_unary_elementwise_float_to_bool!( + call_numpy_isnan, + "np_isnan", + irrt::call_isnan +); +create_helper_call_numpy_unary_elementwise_float_to_bool!( + call_numpy_isinf, + "np_isinf", + irrt::call_isinf +); -call_numpy_float!(call_numpy_sin, "np_sin", llvm_intrinsics::call_float_sin); -call_numpy_float!(call_numpy_cos, "np_cos", llvm_intrinsics::call_float_cos); -call_numpy_float!(call_numpy_tan, "np_tan", extern_fns::call_tan); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_sin, + "np_sin", + llvm_intrinsics::call_float_sin +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_cos, + "np_cos", + llvm_intrinsics::call_float_cos +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_tan, + "np_tan", + extern_fns::call_tan +); -call_numpy_float!(call_numpy_arcsin, "np_arcsin", extern_fns::call_asin); -call_numpy_float!(call_numpy_arccos, "np_arccos", extern_fns::call_acos); -call_numpy_float!(call_numpy_arctan, "np_arctan", extern_fns::call_atan); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_arcsin, + "np_arcsin", + extern_fns::call_asin +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_arccos, + "np_arccos", + extern_fns::call_acos +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_arctan, + "np_arctan", + extern_fns::call_atan +); -call_numpy_float!(call_numpy_sinh, "np_sinh", extern_fns::call_sinh); -call_numpy_float!(call_numpy_cosh, "np_cosh", extern_fns::call_cosh); -call_numpy_float!(call_numpy_tanh, "np_tanh", extern_fns::call_tanh); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_sinh, + "np_sinh", + extern_fns::call_sinh +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_cosh, + "np_cosh", + extern_fns::call_cosh +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_tanh, + "np_tanh", + extern_fns::call_tanh +); -call_numpy_float!(call_numpy_arcsinh, "np_arcsinh", extern_fns::call_asinh); -call_numpy_float!(call_numpy_arccosh, "np_arccosh", extern_fns::call_acosh); -call_numpy_float!(call_numpy_arctanh, "np_arctanh", extern_fns::call_atanh); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_arcsinh, + "np_arcsinh", + extern_fns::call_asinh +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_arccosh, + "np_arccosh", + extern_fns::call_acosh +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_arctanh, + "np_arctanh", + extern_fns::call_atanh +); -call_numpy_float!(call_numpy_exp, "np_exp", llvm_intrinsics::call_float_exp); -call_numpy_float!(call_numpy_exp2, "np_exp2", llvm_intrinsics::call_float_exp2); -call_numpy_float!(call_numpy_expm1, "np_expm1", extern_fns::call_expm1); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_exp, + "np_exp", + llvm_intrinsics::call_float_exp +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_exp2, + "np_exp2", + llvm_intrinsics::call_float_exp2 +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_expm1, + "np_expm1", + extern_fns::call_expm1 +); -call_numpy_float!(call_numpy_log, "np_log", llvm_intrinsics::call_float_log); -call_numpy_float!(call_numpy_log2, "np_log2", llvm_intrinsics::call_float_log2); -call_numpy_float!(call_numpy_log10, "np_log10", llvm_intrinsics::call_float_log10); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_log, + "np_log", + llvm_intrinsics::call_float_log +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_log2, + "np_log2", + llvm_intrinsics::call_float_log2 +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_log10, + "np_log10", + llvm_intrinsics::call_float_log10 +); -call_numpy_float!(call_numpy_sqrt, "np_sqrt", llvm_intrinsics::call_float_sqrt); -call_numpy_float!(call_numpy_cbrt, "np_cbrt", extern_fns::call_cbrt); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_sqrt, + "np_sqrt", + llvm_intrinsics::call_float_sqrt +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_cbrt, + "np_cbrt", + extern_fns::call_cbrt +); -call_numpy_float!(call_numpy_fabs, "np_fabs", llvm_intrinsics::call_float_fabs); -call_numpy_float!(call_numpy_rint, "np_rint", llvm_intrinsics::call_float_roundeven); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_fabs, + "np_fabs", + llvm_intrinsics::call_float_fabs +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_numpy_rint, + "np_rint", + llvm_intrinsics::call_float_roundeven +); -call_numpy_float!(call_scipy_special_erf, "sp_spec_erf", extern_fns::call_erf); -call_numpy_float!(call_scipy_special_erfc, "sp_spec_erfc", extern_fns::call_erfc); -call_numpy_float!(call_scipy_special_gamma, "sp_spec_gamma", |ctx, val, _| irrt::call_gamma( - ctx, val -)); -call_numpy_float!(call_scipy_special_gammaln, "sp_spec_gammaln", |ctx, val, _| irrt::call_gammaln( - ctx, val -)); -call_numpy_float!(call_scipy_special_j0, "sp_spec_j0", |ctx, val, _| irrt::call_j0(ctx, val)); -call_numpy_float!(call_scipy_special_j1, "sp_spec_j1", extern_fns::call_j1); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_scipy_special_erf, + "sp_spec_erf", + extern_fns::call_erf +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_scipy_special_erfc, + "sp_spec_erfc", + extern_fns::call_erfc +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_scipy_special_gamma, + "sp_spec_gamma", + |ctx, val, _| irrt::call_gamma(ctx, val) +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_scipy_special_gammaln, + "sp_spec_gammaln", + |ctx, val, _| irrt::call_gammaln(ctx, val) +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_scipy_special_j0, + "sp_spec_j0", + |ctx, val, _| irrt::call_j0(ctx, val) +); +create_helper_call_numpy_unary_elementwise_float_to_float!( + call_scipy_special_j1, + "sp_spec_j1", + extern_fns::call_j1 +); /// Invokes the `np_arctan2` builtin function. pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(