core: reduce code duplication in codegen/builtin_fns #422
@ -1044,25 +1044,25 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper function to create a builtin that takes in either an ndarray or a value and returns a value of the same structure.
|
||||
/// (e.g, `float` to `float`, `float` to `int`, `ndarray<float>` to `ndarray<float>`, or even `ndarray<float>` to `ndarray<int>`).
|
||||
/// Helper function to create a built-in elementwise unary numpy function that takes in either an ndarray or a scalar.
|
||||
///
|
||||
/// * `(arg_ty, arg_val)`: The [`Type`] and llvm value of the input argument.
|
||||
/// * `fn_name`: The name of the function, only used when throwing an error with [`unsupported_type`]
|
||||
/// * `get_ret_elem_type`: A function that takes in the input element [`Type`], and returns the correct return [`Type`].
|
||||
/// * `get_ret_elem_type`: A function that takes in the input scalar [`Type`], and returns the function's return scalar [`Type`].
|
||||
/// Return a constant [`Type`] here if the return type does not depend on the input type.
|
||||
/// * `on_elem`: The function to be called when the input argument is not an ndarray. Returns [`Option::None`]
|
||||
/// if the element type & value are faulty and should panic with [`unsupported_type`].
|
||||
fn helper_call_numpy<'ctx, OnElemFn, RetElemFn, G: CodeGenerator + ?Sized>(
|
||||
/// * `on_scalar`: The function that acts on the scalars of the input. Returns [`Option::None`]
|
||||
/// if the scalar type & value are faulty and should panic with [`unsupported_type`].
|
||||
fn helper_call_numpy_unary_elementwise<'ctx, OnScalarFn, RetElemFn, G>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
(arg_ty, arg_val): (Type, BasicValueEnum<'ctx>),
|
||||
fn_name: &str,
|
||||
get_ret_elem_type: &RetElemFn,
|
||||
on_elem: &OnElemFn,
|
||||
on_scalar: &OnScalarFn,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>
|
||||
derppening marked this conversation as resolved
|
||||
where
|
||||
OnElemFn: Fn(
|
||||
G: CodeGenerator + ?Sized,
|
||||
OnScalarFn: Fn(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, '_>,
|
||||
Type,
|
||||
@ -1085,20 +1085,20 @@ where
|
||||
None,
|
||||
NDArrayValue::from_ptr_val(x, llvm_usize, None),
|
||||
|generator, ctx, elem_val| {
|
||||
helper_call_numpy(
|
||||
helper_call_numpy_unary_elementwise(
|
||||
generator,
|
||||
ctx,
|
||||
(arg_elem_ty, elem_val),
|
||||
fn_name,
|
||||
get_ret_elem_type,
|
||||
on_elem,
|
||||
on_scalar,
|
||||
)
|
||||
},
|
||||
)?;
|
||||
ndarray.as_base_value().into()
|
||||
}
|
||||
|
||||
_ => on_elem(generator, ctx, arg_ty, arg_val)
|
||||
_ => on_scalar(generator, ctx, arg_ty, arg_val)
|
||||
.unwrap_or_else(|| unsupported_type(ctx, fn_name, &[arg_ty])),
|
||||
};
|
||||
|
||||
@ -1110,12 +1110,12 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
arg: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
let fn_name: &str = "abs";
|
||||
helper_call_numpy(
|
||||
const FN_NAME: &str = "abs";
|
||||
helper_call_numpy_unary_elementwise(
|
||||
generator,
|
||||
ctx,
|
||||
arg,
|
||||
fn_name,
|
||||
FN_NAME,
|
||||
&|_ctx, elem_ty| elem_ty,
|
||||
&|_generator, ctx, val_ty, val| match val {
|
||||
BasicValueEnum::IntValue(n) => Some({
|
||||
@ -1137,7 +1137,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
||||
ctx,
|
||||
n,
|
||||
ctx.ctx.bool_type().const_zero(),
|
||||
Some(fn_name),
|
||||
Some(FN_NAME),
|
||||
)
|
||||
.into()
|
||||
} else {
|
||||
@ -1148,7 +1148,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
||||
BasicValueEnum::FloatValue(n) => Some({
|
||||
debug_assert!(ctx.unifier.unioned(val_ty, ctx.primitives.float));
|
||||
|
||||
llvm_intrinsics::call_float_fabs(ctx, n, Some(fn_name)).into()
|
||||
llvm_intrinsics::call_float_fabs(ctx, n, Some(FN_NAME)).into()
|
||||
}),
|
||||
|
||||
_ => None,
|
||||
@ -1156,40 +1156,88 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
||||
)
|
||||
}
|
||||
|
||||
/// Helper macro. Used so we don't have to keep typing out the type signature of the function.
|
||||
macro_rules! call_numpy {
|
||||
($name:ident, $fn_name:literal, $get_ret_elem_type:expr, $on_elem:expr) => {
|
||||
/// Macro to conveniently generate numpy functions with [`helper_call_numpy_unary_elementwise`].
|
||||
///
|
||||
/// Arguments:
|
||||
/// * `$name:ident`: The identifier of the rust function to be generated.
|
||||
/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`]
|
||||
/// * `$get_ret_elem_type:expr`: To be passed to the `get_ret_elem_type` parameter of [`helper_call_numpy_unary_elementwise`].
|
||||
/// But there is no need to make it a reference.
|
||||
/// * `$on_scalar:expr`: To be passed to the `on_scalar` parameter of [`helper_call_numpy_unary_elementwise`].
|
||||
/// But there is no need to make it a reference.
|
||||
macro_rules! create_helper_call_numpy_unary_elementwise {
|
||||
($name:ident, $fn_name:literal, $get_ret_elem_type:expr, $on_scalar:expr) => {
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
pub fn $name<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
arg: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
helper_call_numpy(generator, ctx, arg, $fn_name, &$get_ret_elem_type, &$on_elem)
|
||||
helper_call_numpy_unary_elementwise(
|
||||
generator,
|
||||
ctx,
|
||||
arg,
|
||||
$fn_name,
|
||||
&$get_ret_elem_type,
|
||||
&$on_scalar,
|
||||
)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Helper macro, only used by [`call_numpy_isnan`] and [`call_numpy_isinf`]
|
||||
macro_rules! call_numpy_ret_bool {
|
||||
($name:ident, $fn_name:literal, $elem_call:expr) => {
|
||||
call_numpy!($name, $fn_name, |ctx, _| ctx.primitives.bool, |generator, ctx, n_ty, val| {
|
||||
match val {
|
||||
BasicValueEnum::FloatValue(n) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
|
||||
/// A specialized version of [`create_helper_call_numpy_unary_elementwise`] to generate functions that takes in float and returns boolean (as an `i8`) elementwise.
|
||||
///
|
||||
/// Arguments:
|
||||
/// * `$name:ident`: The identifier of the rust function to be generated.
|
||||
/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`].
|
||||
/// * `$on_scalar:expr`: The closure (see below for its type) that acts on float scalar values and returns
|
||||
/// the boolean results of LLVM type `i1`. The returned `i1` value will be converted into an `i8`.
|
||||
/// ```rust
|
||||
/// // Type of `$on_scalar:expr`
|
||||
/// fn on_scalar<'ctx, G: CodeGenerator + ?Sized>(
|
||||
/// generator: &mut G,
|
||||
/// ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
/// arg: FloatValue<'ctx>
|
||||
/// ) -> IntValue<'ctx> // of LLVM type `i1`
|
||||
/// ```
|
||||
macro_rules! create_helper_call_numpy_unary_elementwise_float_to_bool {
|
||||
($name:ident, $fn_name:literal, $on_scalar:expr) => {
|
||||
create_helper_call_numpy_unary_elementwise!(
|
||||
$name,
|
||||
$fn_name,
|
||||
|ctx, _| ctx.primitives.bool,
|
||||
|generator, ctx, n_ty, val| {
|
||||
match val {
|
||||
BasicValueEnum::FloatValue(n) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
|
||||
|
||||
let ret = $elem_call(generator, ctx, n);
|
||||
Some(generator.bool_to_i8(ctx, ret).into())
|
||||
let ret = $on_scalar(generator, ctx, n);
|
||||
Some(generator.bool_to_i8(ctx, ret).into())
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
});
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! call_numpy_float {
|
||||
/// A specialized version of [`create_helper_call_numpy_unary_elementwise`] to generate functions that takes in float and returns float elementwise.
|
||||
///
|
||||
/// Arguments:
|
||||
/// * `$name:ident`: The identifier of the rust function to be generated.
|
||||
/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`].
|
||||
/// * `$on_scalar:expr`: The closure (see below for its type) that acts on float scalar values and returns float results.
|
||||
/// ```rust
|
||||
/// // Type of `$on_scalar:expr`
|
||||
/// fn on_scalar<'ctx, G: CodeGenerator + ?Sized>(
|
||||
/// generator: &mut G,
|
||||
/// ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
/// arg: FloatValue<'ctx>
|
||||
/// ) -> FloatValue<'ctx>
|
||||
/// ```
|
||||
macro_rules! create_helper_call_numpy_unary_elementwise_float_to_float {
|
||||
($name:ident, $fn_name:literal, $elem_call:expr) => {
|
||||
call_numpy!(
|
||||
create_helper_call_numpy_unary_elementwise!(
|
||||
$name,
|
||||
$fn_name,
|
||||
|ctx, _| ctx.primitives.float,
|
||||
@ -1207,49 +1255,165 @@ macro_rules! call_numpy_float {
|
||||
};
|
||||
}
|
||||
|
||||
call_numpy_ret_bool!(call_numpy_isnan, "np_isnan", irrt::call_isnan);
|
||||
call_numpy_ret_bool!(call_numpy_isinf, "np_isinf", irrt::call_isinf);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_bool!(
|
||||
call_numpy_isnan,
|
||||
"np_isnan",
|
||||
irrt::call_isnan
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_bool!(
|
||||
call_numpy_isinf,
|
||||
"np_isinf",
|
||||
irrt::call_isinf
|
||||
);
|
||||
|
||||
call_numpy_float!(call_numpy_sin, "np_sin", llvm_intrinsics::call_float_sin);
|
||||
call_numpy_float!(call_numpy_cos, "np_cos", llvm_intrinsics::call_float_cos);
|
||||
call_numpy_float!(call_numpy_tan, "np_tan", extern_fns::call_tan);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_sin,
|
||||
"np_sin",
|
||||
llvm_intrinsics::call_float_sin
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_cos,
|
||||
"np_cos",
|
||||
llvm_intrinsics::call_float_cos
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_tan,
|
||||
"np_tan",
|
||||
extern_fns::call_tan
|
||||
);
|
||||
|
||||
call_numpy_float!(call_numpy_arcsin, "np_arcsin", extern_fns::call_asin);
|
||||
call_numpy_float!(call_numpy_arccos, "np_arccos", extern_fns::call_acos);
|
||||
call_numpy_float!(call_numpy_arctan, "np_arctan", extern_fns::call_atan);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_arcsin,
|
||||
"np_arcsin",
|
||||
extern_fns::call_asin
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_arccos,
|
||||
"np_arccos",
|
||||
extern_fns::call_acos
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_arctan,
|
||||
"np_arctan",
|
||||
extern_fns::call_atan
|
||||
);
|
||||
|
||||
call_numpy_float!(call_numpy_sinh, "np_sinh", extern_fns::call_sinh);
|
||||
call_numpy_float!(call_numpy_cosh, "np_cosh", extern_fns::call_cosh);
|
||||
call_numpy_float!(call_numpy_tanh, "np_tanh", extern_fns::call_tanh);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_sinh,
|
||||
"np_sinh",
|
||||
extern_fns::call_sinh
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_cosh,
|
||||
"np_cosh",
|
||||
extern_fns::call_cosh
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_tanh,
|
||||
"np_tanh",
|
||||
extern_fns::call_tanh
|
||||
);
|
||||
|
||||
call_numpy_float!(call_numpy_arcsinh, "np_arcsinh", extern_fns::call_asinh);
|
||||
call_numpy_float!(call_numpy_arccosh, "np_arccosh", extern_fns::call_acosh);
|
||||
call_numpy_float!(call_numpy_arctanh, "np_arctanh", extern_fns::call_atanh);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_arcsinh,
|
||||
"np_arcsinh",
|
||||
extern_fns::call_asinh
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_arccosh,
|
||||
"np_arccosh",
|
||||
extern_fns::call_acosh
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_arctanh,
|
||||
"np_arctanh",
|
||||
extern_fns::call_atanh
|
||||
);
|
||||
|
||||
call_numpy_float!(call_numpy_exp, "np_exp", llvm_intrinsics::call_float_exp);
|
||||
call_numpy_float!(call_numpy_exp2, "np_exp2", llvm_intrinsics::call_float_exp2);
|
||||
call_numpy_float!(call_numpy_expm1, "np_expm1", extern_fns::call_expm1);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_exp,
|
||||
"np_exp",
|
||||
llvm_intrinsics::call_float_exp
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_exp2,
|
||||
"np_exp2",
|
||||
llvm_intrinsics::call_float_exp2
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_expm1,
|
||||
"np_expm1",
|
||||
extern_fns::call_expm1
|
||||
);
|
||||
|
||||
call_numpy_float!(call_numpy_log, "np_log", llvm_intrinsics::call_float_log);
|
||||
call_numpy_float!(call_numpy_log2, "np_log2", llvm_intrinsics::call_float_log2);
|
||||
call_numpy_float!(call_numpy_log10, "np_log10", llvm_intrinsics::call_float_log10);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_log,
|
||||
"np_log",
|
||||
llvm_intrinsics::call_float_log
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_log2,
|
||||
"np_log2",
|
||||
llvm_intrinsics::call_float_log2
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_log10,
|
||||
"np_log10",
|
||||
llvm_intrinsics::call_float_log10
|
||||
);
|
||||
|
||||
call_numpy_float!(call_numpy_sqrt, "np_sqrt", llvm_intrinsics::call_float_sqrt);
|
||||
call_numpy_float!(call_numpy_cbrt, "np_cbrt", extern_fns::call_cbrt);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_sqrt,
|
||||
"np_sqrt",
|
||||
llvm_intrinsics::call_float_sqrt
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_cbrt,
|
||||
"np_cbrt",
|
||||
extern_fns::call_cbrt
|
||||
);
|
||||
|
||||
call_numpy_float!(call_numpy_fabs, "np_fabs", llvm_intrinsics::call_float_fabs);
|
||||
call_numpy_float!(call_numpy_rint, "np_rint", llvm_intrinsics::call_float_roundeven);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_fabs,
|
||||
"np_fabs",
|
||||
llvm_intrinsics::call_float_fabs
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_rint,
|
||||
"np_rint",
|
||||
llvm_intrinsics::call_float_roundeven
|
||||
);
|
||||
|
||||
call_numpy_float!(call_scipy_special_erf, "sp_spec_erf", extern_fns::call_erf);
|
||||
call_numpy_float!(call_scipy_special_erfc, "sp_spec_erfc", extern_fns::call_erfc);
|
||||
call_numpy_float!(call_scipy_special_gamma, "sp_spec_gamma", |ctx, val, _| irrt::call_gamma(
|
||||
ctx, val
|
||||
));
|
||||
call_numpy_float!(call_scipy_special_gammaln, "sp_spec_gammaln", |ctx, val, _| irrt::call_gammaln(
|
||||
ctx, val
|
||||
));
|
||||
call_numpy_float!(call_scipy_special_j0, "sp_spec_j0", |ctx, val, _| irrt::call_j0(ctx, val));
|
||||
call_numpy_float!(call_scipy_special_j1, "sp_spec_j1", extern_fns::call_j1);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_scipy_special_erf,
|
||||
"sp_spec_erf",
|
||||
extern_fns::call_erf
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_scipy_special_erfc,
|
||||
"sp_spec_erfc",
|
||||
extern_fns::call_erfc
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_scipy_special_gamma,
|
||||
"sp_spec_gamma",
|
||||
|ctx, val, _| irrt::call_gamma(ctx, val)
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_scipy_special_gammaln,
|
||||
"sp_spec_gammaln",
|
||||
|ctx, val, _| irrt::call_gammaln(ctx, val)
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_scipy_special_j0,
|
||||
"sp_spec_j0",
|
||||
|ctx, val, _| irrt::call_j0(ctx, val)
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_scipy_special_j1,
|
||||
"sp_spec_j1",
|
||||
extern_fns::call_j1
|
||||
);
|
||||
|
||||
/// Invokes the `np_arctan2` builtin function.
|
||||
pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
Loading…
Reference in New Issue
Block a user
on_scalar