core: Rework gamma to match SciPy behavior
This commit is contained in:
parent
cae5a0e56d
commit
37fe208c33
|
@ -1,7 +1,9 @@
|
|||
use super::*;
|
||||
use crate::{
|
||||
codegen::{
|
||||
expr::destructure_range, irrt::calculate_len_for_slice_range, stmt::exn_constructor,
|
||||
expr::destructure_range,
|
||||
irrt::{calculate_len_for_slice_range, call_isinf, call_isnan},
|
||||
stmt::exn_constructor,
|
||||
},
|
||||
symbol_resolver::SymbolValue,
|
||||
};
|
||||
|
@ -1428,14 +1430,83 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
"erfc",
|
||||
&[],
|
||||
),
|
||||
create_fn_by_extern(
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
"gamma",
|
||||
float,
|
||||
&[(float, "z")],
|
||||
"tgamma",
|
||||
&[],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let float = ctx.primitives.float;
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
let x_ty = fun.0.args[0].ty;
|
||||
let x_val = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, x_ty)?;
|
||||
|
||||
assert!(ctx.unifier.unioned(x_ty, float));
|
||||
|
||||
let tgamma_fn = ctx.module.get_function("tgamma").unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
let func = ctx.module.add_function("tgamma", fn_type, None);
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
|
||||
);
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
// %0 = call f64 @gamma(f64 %x)
|
||||
let call = ctx.builder
|
||||
.build_call(tgamma_fn, &[x_val.into()], "gamma")
|
||||
.try_as_basic_value()
|
||||
.left()
|
||||
.unwrap()
|
||||
.into_float_value();
|
||||
|
||||
// Handling for denormals
|
||||
// | x | Python gamma(x) | C tgamma(x) |
|
||||
// | ----------------- | --------------- | ----------- |
|
||||
// | nan | nan | nan |
|
||||
// | -inf | -inf | inf |
|
||||
// | inf | inf | inf |
|
||||
// | 0.0 | inf | inf |
|
||||
// | {-1.0, -2.0, ...} | inf | nan |
|
||||
//
|
||||
// Therefore, we remap to Python's denorm handling by:
|
||||
//
|
||||
// let v = gamma(x);
|
||||
// v = if isinf(v) { if isinf(x) { x } else { f64::INFINITY } } else { v }
|
||||
// v = if isnan(v) { if isnan(x) { x } else { f64::INFINITY } } else { v }
|
||||
|
||||
// %1 = call i32 @isnan(f64 %x)
|
||||
let is_arg_nan = call_isnan(ctx, x_val.into_float_value());
|
||||
// %2 = call i32 @isnan(f64 %0)
|
||||
let is_val_nan = call_isnan(ctx, call);
|
||||
// %3 = call i32 @isinf(f64 %x)
|
||||
let is_arg_inf = call_isinf(ctx, x_val.into_float_value());
|
||||
// %4 = call i32 @isinf(f64 %0)
|
||||
let is_val_inf = call_isinf(ctx, call);
|
||||
|
||||
// %5 = select i1 %3, f64 %x, f64 infinity
|
||||
let val_if_inf = ctx.builder
|
||||
.build_select(generator.bool_to_i1(ctx, is_arg_inf), x_val, llvm_f64.const_float(f64::INFINITY).into(), "")
|
||||
.into_float_value();
|
||||
// %6 = select i1 %2, f64 %x, f64 infinity
|
||||
let val_if_nan = ctx.builder
|
||||
.build_select(generator.bool_to_i1(ctx, is_arg_nan), x_val, llvm_f64.const_float(f64::INFINITY).into(), "")
|
||||
.into_float_value();
|
||||
// %7 = select i1 %4, f64 %5, f64 %0
|
||||
let val_or_inf = ctx.builder
|
||||
.build_select(generator.bool_to_i1(ctx, is_val_inf), val_if_inf, call.into(), "")
|
||||
.into_float_value();
|
||||
// %8 = select i1 %1, f64 %6, f64 %7
|
||||
let val = ctx.builder
|
||||
.build_select(generator.bool_to_i1(ctx, is_val_nan), val_if_nan, val_or_inf, "");
|
||||
|
||||
Ok(val.into())
|
||||
}),
|
||||
),
|
||||
create_fn_by_extern(
|
||||
primitives,
|
||||
|
|
Loading…
Reference in New Issue