core: reduce code duplication in codegen/builtin_fns #422
@ -1044,25 +1044,25 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper function to create a builtin that takes in either an ndarray or a value and returns a value of the same structure.
|
||||
/// (e.g, `float` to `float`, `float` to `int`, `ndarray<float>` to `ndarray<float>`, or even `ndarray<float>` to `ndarray<int>`).
|
||||
/// Helper function to create a built-in elementwise unary numpy function that takes in either an ndarray or a scalar.
|
||||
///
|
||||
/// * `(arg_ty, arg_val)`: The [`Type`] and llvm value of the input argument.
|
||||
/// * `fn_name`: The name of the function, only used when throwing an error with [`unsupported_type`]
|
||||
/// * `get_ret_elem_type`: A function that takes in the input element [`Type`], and returns the correct return [`Type`].
|
||||
/// * `get_ret_elem_type`: A function that takes in the input scalar [`Type`], and returns the function's return scalar [`Type`].
|
||||
/// Return a constant [`Type`] here if the return type does not depend on the input type.
|
||||
/// * `on_elem`: The function to be called when the input argument is not an ndarray. Returns [`Option::None`]
|
||||
/// if the element type & value are faulty and should panic with [`unsupported_type`].
|
||||
fn helper_call_numpy<'ctx, OnElemFn, RetElemFn, G: CodeGenerator + ?Sized>(
|
||||
/// * `on_scalar`: The function that acts on the scalars of the input. Returns [`Option::None`]
|
||||
/// if the scalar type & value are faulty and should panic with [`unsupported_type`].
|
||||
fn helper_call_numpy_unary_elementwise<'ctx, OnScalarFn, RetElemFn, G>(
|
||||
generator: &mut G,
|
||||
derppening marked this conversation as resolved
Outdated
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
(arg_ty, arg_val): (Type, BasicValueEnum<'ctx>),
|
||||
fn_name: &str,
|
||||
get_ret_elem_type: &RetElemFn,
|
||||
on_elem: &OnElemFn,
|
||||
on_scalar: &OnScalarFn,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>
|
||||
derppening marked this conversation as resolved
derppening
commented
`on_scalar`
|
||||
where
|
||||
OnElemFn: Fn(
|
||||
G: CodeGenerator + ?Sized,
|
||||
OnScalarFn: Fn(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, '_>,
|
||||
Type,
|
||||
@ -1085,20 +1085,20 @@ where
|
||||
None,
|
||||
NDArrayValue::from_ptr_val(x, llvm_usize, None),
|
||||
|generator, ctx, elem_val| {
|
||||
helper_call_numpy(
|
||||
helper_call_numpy_unary_elementwise(
|
||||
generator,
|
||||
ctx,
|
||||
(arg_elem_ty, elem_val),
|
||||
fn_name,
|
||||
get_ret_elem_type,
|
||||
on_elem,
|
||||
on_scalar,
|
||||
)
|
||||
},
|
||||
)?;
|
||||
ndarray.as_base_value().into()
|
||||
}
|
||||
|
||||
_ => on_elem(generator, ctx, arg_ty, arg_val)
|
||||
_ => on_scalar(generator, ctx, arg_ty, arg_val)
|
||||
.unwrap_or_else(|| unsupported_type(ctx, fn_name, &[arg_ty])),
|
||||
};
|
||||
|
||||
@ -1110,12 +1110,12 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
arg: (Type, BasicValueEnum<'ctx>),
|
||||
derppening
commented
Nit: Please keep this parameter as Nit: Please keep this parameter as `n`. I don't think this change is necessary.
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
let fn_name: &str = "abs";
|
||||
helper_call_numpy(
|
||||
const FN_NAME: &str = "abs";
|
||||
derppening marked this conversation as resolved
Outdated
derppening
commented
Is this change from Is this change from `const` to `let` necessary?
|
||||
helper_call_numpy_unary_elementwise(
|
||||
generator,
|
||||
ctx,
|
||||
arg,
|
||||
fn_name,
|
||||
FN_NAME,
|
||||
&|_ctx, elem_ty| elem_ty,
|
||||
&|_generator, ctx, val_ty, val| match val {
|
||||
BasicValueEnum::IntValue(n) => Some({
|
||||
@ -1137,7 +1137,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
||||
ctx,
|
||||
n,
|
||||
ctx.ctx.bool_type().const_zero(),
|
||||
Some(fn_name),
|
||||
Some(FN_NAME),
|
||||
)
|
||||
.into()
|
||||
} else {
|
||||
@ -1148,7 +1148,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
||||
BasicValueEnum::FloatValue(n) => Some({
|
||||
debug_assert!(ctx.unifier.unioned(val_ty, ctx.primitives.float));
|
||||
|
||||
llvm_intrinsics::call_float_fabs(ctx, n, Some(fn_name)).into()
|
||||
llvm_intrinsics::call_float_fabs(ctx, n, Some(FN_NAME)).into()
|
||||
}),
|
||||
|
||||
_ => None,
|
||||
@ -1156,40 +1156,88 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
||||
)
|
||||
}
|
||||
|
||||
/// Helper macro. Used so we don't have to keep typing out the type signature of the function.
|
||||
macro_rules! call_numpy {
|
||||
($name:ident, $fn_name:literal, $get_ret_elem_type:expr, $on_elem:expr) => {
|
||||
/// Macro to conveniently generate numpy functions with [`helper_call_numpy_unary_elementwise`].
|
||||
derppening marked this conversation as resolved
Outdated
derppening
commented
Documentation comments are to explain what this macro does.
Same for the macros below. Documentation comments are to explain what this macro does.
```rs
/// Macro for generating unary functions accepting `ndarray`-compatible values.
///
/// The only form of the macro accepts an identifier for the generated function, the name of the
/// corresponding Python function, a closure of type `Fn(&mut CodeGenContext<'ctx, '_>, Type) ->
/// Type` which obtains the elementwise return type, and a closure of type `Fn(&mut G,
/// &mut CodeGenContext<'ctx, '_>, Type, BasicValueEnum<'ctx>) -> Option<BasicValueEnum<'ctx>>`
/// which performs the operation on a single element.
```
Same for the macros below.
|
||||
///
|
||||
derppening marked this conversation as resolved
Outdated
derppening
commented
I don't think it's appropriate to name the macro Same for the macros below. I don't think it's appropriate to name the macro `call_numpy`, as it gives the impression that the macro actually invokes the function rather than generates a function which can in turn be used.
Same for the macros below.
|
||||
/// Arguments:
|
||||
/// * `$name:ident`: The identifier of the rust function to be generated.
|
||||
/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`]
|
||||
/// * `$get_ret_elem_type:expr`: To be passed to the `get_ret_elem_type` parameter of [`helper_call_numpy_unary_elementwise`].
|
||||
/// But there is no need to make it a reference.
|
||||
/// * `$on_scalar:expr`: To be passed to the `on_scalar` parameter of [`helper_call_numpy_unary_elementwise`].
|
||||
/// But there is no need to make it a reference.
|
||||
macro_rules! create_helper_call_numpy_unary_elementwise {
|
||||
($name:ident, $fn_name:literal, $get_ret_elem_type:expr, $on_scalar:expr) => {
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
pub fn $name<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
arg: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
helper_call_numpy(generator, ctx, arg, $fn_name, &$get_ret_elem_type, &$on_elem)
|
||||
helper_call_numpy_unary_elementwise(
|
||||
generator,
|
||||
ctx,
|
||||
arg,
|
||||
$fn_name,
|
||||
&$get_ret_elem_type,
|
||||
&$on_scalar,
|
||||
)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Helper macro, only used by [`call_numpy_isnan`] and [`call_numpy_isinf`]
|
||||
macro_rules! call_numpy_ret_bool {
|
||||
($name:ident, $fn_name:literal, $elem_call:expr) => {
|
||||
call_numpy!($name, $fn_name, |ctx, _| ctx.primitives.bool, |generator, ctx, n_ty, val| {
|
||||
/// A specialized version of [`create_helper_call_numpy_unary_elementwise`] to generate functions that takes in float and returns boolean (as an `i8`) elementwise.
|
||||
///
|
||||
/// Arguments:
|
||||
/// * `$name:ident`: The identifier of the rust function to be generated.
|
||||
/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`].
|
||||
/// * `$on_scalar:expr`: The closure (see below for its type) that acts on float scalar values and returns
|
||||
/// the boolean results of LLVM type `i1`. The returned `i1` value will be converted into an `i8`.
|
||||
/// ```rust
|
||||
derppening
commented
Nit: Please add a newline between this and the last line. Same for the macro below. Nit: Please add a newline between this and the last line.
Same for the macro below.
|
||||
/// // Type of `$on_scalar:expr`
|
||||
/// fn on_scalar<'ctx, G: CodeGenerator + ?Sized>(
|
||||
/// generator: &mut G,
|
||||
/// ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
/// arg: FloatValue<'ctx>
|
||||
/// ) -> IntValue<'ctx> // of LLVM type `i1`
|
||||
/// ```
|
||||
macro_rules! create_helper_call_numpy_unary_elementwise_float_to_bool {
|
||||
($name:ident, $fn_name:literal, $on_scalar:expr) => {
|
||||
create_helper_call_numpy_unary_elementwise!(
|
||||
$name,
|
||||
$fn_name,
|
||||
|ctx, _| ctx.primitives.bool,
|
||||
|generator, ctx, n_ty, val| {
|
||||
match val {
|
||||
BasicValueEnum::FloatValue(n) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
|
||||
|
||||
let ret = $elem_call(generator, ctx, n);
|
||||
let ret = $on_scalar(generator, ctx, n);
|
||||
Some(generator.bool_to_i8(ctx, ret).into())
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
});
|
||||
}
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! call_numpy_float {
|
||||
/// A specialized version of [`create_helper_call_numpy_unary_elementwise`] to generate functions that takes in float and returns float elementwise.
|
||||
///
|
||||
/// Arguments:
|
||||
/// * `$name:ident`: The identifier of the rust function to be generated.
|
||||
/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`].
|
||||
/// * `$on_scalar:expr`: The closure (see below for its type) that acts on float scalar values and returns float results.
|
||||
/// ```rust
|
||||
/// // Type of `$on_scalar:expr`
|
||||
/// fn on_scalar<'ctx, G: CodeGenerator + ?Sized>(
|
||||
/// generator: &mut G,
|
||||
/// ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
/// arg: FloatValue<'ctx>
|
||||
/// ) -> FloatValue<'ctx>
|
||||
/// ```
|
||||
macro_rules! create_helper_call_numpy_unary_elementwise_float_to_float {
|
||||
($name:ident, $fn_name:literal, $elem_call:expr) => {
|
||||
call_numpy!(
|
||||
create_helper_call_numpy_unary_elementwise!(
|
||||
$name,
|
||||
$fn_name,
|
||||
|ctx, _| ctx.primitives.float,
|
||||
@ -1207,49 +1255,165 @@ macro_rules! call_numpy_float {
|
||||
};
|
||||
}
|
||||
|
||||
call_numpy_ret_bool!(call_numpy_isnan, "np_isnan", irrt::call_isnan);
|
||||
call_numpy_ret_bool!(call_numpy_isinf, "np_isinf", irrt::call_isinf);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_bool!(
|
||||
call_numpy_isnan,
|
||||
"np_isnan",
|
||||
irrt::call_isnan
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_bool!(
|
||||
call_numpy_isinf,
|
||||
"np_isinf",
|
||||
irrt::call_isinf
|
||||
);
|
||||
|
||||
call_numpy_float!(call_numpy_sin, "np_sin", llvm_intrinsics::call_float_sin);
|
||||
call_numpy_float!(call_numpy_cos, "np_cos", llvm_intrinsics::call_float_cos);
|
||||
call_numpy_float!(call_numpy_tan, "np_tan", extern_fns::call_tan);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_sin,
|
||||
"np_sin",
|
||||
llvm_intrinsics::call_float_sin
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_cos,
|
||||
"np_cos",
|
||||
llvm_intrinsics::call_float_cos
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_tan,
|
||||
"np_tan",
|
||||
extern_fns::call_tan
|
||||
);
|
||||
|
||||
call_numpy_float!(call_numpy_arcsin, "np_arcsin", extern_fns::call_asin);
|
||||
call_numpy_float!(call_numpy_arccos, "np_arccos", extern_fns::call_acos);
|
||||
call_numpy_float!(call_numpy_arctan, "np_arctan", extern_fns::call_atan);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_arcsin,
|
||||
"np_arcsin",
|
||||
extern_fns::call_asin
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_arccos,
|
||||
"np_arccos",
|
||||
extern_fns::call_acos
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_arctan,
|
||||
"np_arctan",
|
||||
extern_fns::call_atan
|
||||
);
|
||||
|
||||
call_numpy_float!(call_numpy_sinh, "np_sinh", extern_fns::call_sinh);
|
||||
call_numpy_float!(call_numpy_cosh, "np_cosh", extern_fns::call_cosh);
|
||||
call_numpy_float!(call_numpy_tanh, "np_tanh", extern_fns::call_tanh);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_sinh,
|
||||
"np_sinh",
|
||||
extern_fns::call_sinh
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_cosh,
|
||||
"np_cosh",
|
||||
extern_fns::call_cosh
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_tanh,
|
||||
"np_tanh",
|
||||
extern_fns::call_tanh
|
||||
);
|
||||
|
||||
call_numpy_float!(call_numpy_arcsinh, "np_arcsinh", extern_fns::call_asinh);
|
||||
call_numpy_float!(call_numpy_arccosh, "np_arccosh", extern_fns::call_acosh);
|
||||
call_numpy_float!(call_numpy_arctanh, "np_arctanh", extern_fns::call_atanh);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_arcsinh,
|
||||
"np_arcsinh",
|
||||
extern_fns::call_asinh
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_arccosh,
|
||||
"np_arccosh",
|
||||
extern_fns::call_acosh
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_arctanh,
|
||||
"np_arctanh",
|
||||
extern_fns::call_atanh
|
||||
);
|
||||
|
||||
call_numpy_float!(call_numpy_exp, "np_exp", llvm_intrinsics::call_float_exp);
|
||||
call_numpy_float!(call_numpy_exp2, "np_exp2", llvm_intrinsics::call_float_exp2);
|
||||
call_numpy_float!(call_numpy_expm1, "np_expm1", extern_fns::call_expm1);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_exp,
|
||||
"np_exp",
|
||||
llvm_intrinsics::call_float_exp
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_exp2,
|
||||
"np_exp2",
|
||||
llvm_intrinsics::call_float_exp2
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_expm1,
|
||||
"np_expm1",
|
||||
extern_fns::call_expm1
|
||||
);
|
||||
|
||||
call_numpy_float!(call_numpy_log, "np_log", llvm_intrinsics::call_float_log);
|
||||
call_numpy_float!(call_numpy_log2, "np_log2", llvm_intrinsics::call_float_log2);
|
||||
call_numpy_float!(call_numpy_log10, "np_log10", llvm_intrinsics::call_float_log10);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_log,
|
||||
"np_log",
|
||||
llvm_intrinsics::call_float_log
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_log2,
|
||||
"np_log2",
|
||||
llvm_intrinsics::call_float_log2
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_log10,
|
||||
"np_log10",
|
||||
llvm_intrinsics::call_float_log10
|
||||
);
|
||||
|
||||
call_numpy_float!(call_numpy_sqrt, "np_sqrt", llvm_intrinsics::call_float_sqrt);
|
||||
call_numpy_float!(call_numpy_cbrt, "np_cbrt", extern_fns::call_cbrt);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_sqrt,
|
||||
"np_sqrt",
|
||||
llvm_intrinsics::call_float_sqrt
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_cbrt,
|
||||
"np_cbrt",
|
||||
extern_fns::call_cbrt
|
||||
);
|
||||
|
||||
call_numpy_float!(call_numpy_fabs, "np_fabs", llvm_intrinsics::call_float_fabs);
|
||||
call_numpy_float!(call_numpy_rint, "np_rint", llvm_intrinsics::call_float_roundeven);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_fabs,
|
||||
"np_fabs",
|
||||
llvm_intrinsics::call_float_fabs
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_numpy_rint,
|
||||
"np_rint",
|
||||
llvm_intrinsics::call_float_roundeven
|
||||
);
|
||||
|
||||
call_numpy_float!(call_scipy_special_erf, "sp_spec_erf", extern_fns::call_erf);
|
||||
call_numpy_float!(call_scipy_special_erfc, "sp_spec_erfc", extern_fns::call_erfc);
|
||||
call_numpy_float!(call_scipy_special_gamma, "sp_spec_gamma", |ctx, val, _| irrt::call_gamma(
|
||||
ctx, val
|
||||
));
|
||||
call_numpy_float!(call_scipy_special_gammaln, "sp_spec_gammaln", |ctx, val, _| irrt::call_gammaln(
|
||||
ctx, val
|
||||
));
|
||||
call_numpy_float!(call_scipy_special_j0, "sp_spec_j0", |ctx, val, _| irrt::call_j0(ctx, val));
|
||||
call_numpy_float!(call_scipy_special_j1, "sp_spec_j1", extern_fns::call_j1);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_scipy_special_erf,
|
||||
"sp_spec_erf",
|
||||
extern_fns::call_erf
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_scipy_special_erfc,
|
||||
"sp_spec_erfc",
|
||||
extern_fns::call_erfc
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_scipy_special_gamma,
|
||||
"sp_spec_gamma",
|
||||
|ctx, val, _| irrt::call_gamma(ctx, val)
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_scipy_special_gammaln,
|
||||
"sp_spec_gammaln",
|
||||
|ctx, val, _| irrt::call_gammaln(ctx, val)
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_scipy_special_j0,
|
||||
"sp_spec_j0",
|
||||
|ctx, val, _| irrt::call_j0(ctx, val)
|
||||
);
|
||||
create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
call_scipy_special_j1,
|
||||
"sp_spec_j1",
|
||||
extern_fns::call_j1
|
||||
);
|
||||
|
||||
/// Invokes the `np_arctan2` builtin function.
|
||||
pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
Loading…
Reference in New Issue
Block a user
G
into thewhere
clause.call_numpy
makes it seem like you can call any numpy function. Perhaps something likecall_numpy_unaryfunc
?