core: Rework gamma to match SciPy behavior

David Mak 2023-10-10 18:19:36 +08:00
parent cae5a0e56d
commit 37fe208c33
1 changed files with 75 additions and 4 deletions

View File

@ -1,7 +1,9 @@
use super::*;
use crate::{
codegen::{
expr::destructure_range, irrt::calculate_len_for_slice_range, stmt::exn_constructor,
expr::destructure_range,
irrt::{calculate_len_for_slice_range, call_isinf, call_isnan},
stmt::exn_constructor,
},
symbol_resolver::SymbolValue,
};
@ -1428,14 +1430,83 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
"erfc",
&[],
),
create_fn_by_extern(
create_fn_by_codegen(
primitives,
&var_map,
"gamma",
float,
&[(float, "z")],
"tgamma",
&[],
Box::new(|ctx, _, fun, args, generator| {
let float = ctx.primitives.float;
let llvm_f64 = ctx.ctx.f64_type();
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone()
.to_basic_value_enum(ctx, generator, x_ty)?;
assert!(ctx.unifier.unioned(x_ty, float));
let tgamma_fn = ctx.module.get_function("tgamma").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function("tgamma", fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
);
func
});
// %0 = call f64 @gamma(f64 %x)
let call = ctx.builder
.build_call(tgamma_fn, &[x_val.into()], "gamma")
.try_as_basic_value()
.left()
.unwrap()
.into_float_value();
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// | ----------------- | --------------- | ----------- |
// | nan | nan | nan |
// | -inf | -inf | inf |
// | inf | inf | inf |
// | 0.0 | inf | inf |
// | {-1.0, -2.0, ...} | inf | nan |
//
// Therefore, we remap to Python's denorm handling by:
//
// let v = gamma(x);
// v = if isinf(v) { if isinf(x) { x } else { f64::INFINITY } } else { v }
// v = if isnan(v) { if isnan(x) { x } else { f64::INFINITY } } else { v }
// %1 = call i32 @isnan(f64 %x)
let is_arg_nan = call_isnan(ctx, x_val.into_float_value());
// %2 = call i32 @isnan(f64 %0)
let is_val_nan = call_isnan(ctx, call);
// %3 = call i32 @isinf(f64 %x)
let is_arg_inf = call_isinf(ctx, x_val.into_float_value());
// %4 = call i32 @isinf(f64 %0)
let is_val_inf = call_isinf(ctx, call);
// %5 = select i1 %3, f64 %x, f64 infinity
let val_if_inf = ctx.builder
.build_select(generator.bool_to_i1(ctx, is_arg_inf), x_val, llvm_f64.const_float(f64::INFINITY).into(), "")
.into_float_value();
// %6 = select i1 %2, f64 %x, f64 infinity
let val_if_nan = ctx.builder
.build_select(generator.bool_to_i1(ctx, is_arg_nan), x_val, llvm_f64.const_float(f64::INFINITY).into(), "")
.into_float_value();
// %7 = select i1 %4, f64 %5, f64 %0
let val_or_inf = ctx.builder
.build_select(generator.bool_to_i1(ctx, is_val_inf), val_if_inf, call.into(), "")
.into_float_value();
// %8 = select i1 %1, f64 %6, f64 %7
let val = ctx.builder
.build_select(generator.bool_to_i1(ctx, is_val_nan), val_if_nan, val_or_inf, "");
Ok(val.into())
}),
),
create_fn_by_extern(
primitives,