1
0
forked from M-Labs/nac3

core: Rework gamma/gammaln to match SciPy behavior

Matches behavior for infinities and NaNs.
This commit is contained in:
David Mak 2023-10-10 18:19:36 +08:00
parent 36a6a7b8cd
commit ff1fed112c

View File

@ -1470,32 +1470,204 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
"erfc", "erfc",
&[], &[],
), ),
create_fn_by_extern( create_fn_by_codegen(
primitives, primitives,
&var_map, &var_map,
"gamma", "gamma",
float, float,
&[(float, "z")], &[(float, "z")],
"tgamma", Box::new(|ctx, _, fun, args, generator| {
&[], let float = ctx.primitives.float;
let llvm_f64 = ctx.ctx.f64_type();
let z_ty = fun.0.args[0].ty;
let z_val = args[0].1.clone()
.to_basic_value_enum(ctx, generator, z_ty)?;
assert!(ctx.unifier.unioned(z_ty, float));
let tgamma_fn = ctx.module.get_function("tgamma").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function("tgamma", fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
);
func
});
// %0 = call f64 @tgamma(f64 %z)
let call = ctx.builder
.build_call(tgamma_fn, &[z_val.into()], "gamma")
.try_as_basic_value()
.left()
.unwrap()
.into_float_value();
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
//
// Therefore, we remap to Python's denorm handling by:
//
// let v = tgamma(x);
// v = if isinf(v) || isnan(v) { f64::INFINITY } else { v } // Handles (4)-(5)
// v = if isinf(x) || isnan(x) { x } else { v } // Handles (1)-(3)
// %v.isinf = call i32 @__nac3_isinf(f64 %0)
// %v.isinf.tobool = icmp ne i32 %v.isinf, 0
let v_isinf = call_isinf(generator, ctx, call.into());
// %v.isnan = call i32 @__nac3_isnan(f64 %0)
// %v.isnan.tobool = icmp ne i32 %v.isnan, 0
let v_isnan = call_isnan(generator, ctx, call.into());
// %or = or i1 %v.isinf.tobool, %v.isnan.tobool
// %3 = select i1 %or, f64 inf, f64 %0
let v_is_nonnum = ctx.builder.build_or(v_isinf, v_isnan, "");
let val = ctx.builder.build_select(
v_is_nonnum,
llvm_f64.const_float(f64::INFINITY).into(),
call,
"",
).into_float_value();
// %z.isinf = call i32 @__nac3_isinf(f64 %z)
// %z.isinf.tobool = icmp ne i32 %z.isinf, 0
let z_isinf = call_isinf(generator, ctx, z_val.into_float_value());
// %z.isnan = call i32 @__nac3_isnan(f64 %z)
// %z.isnan.tobool = icmp ne i32 %z.isnan, 0
let z_isnan = call_isnan(generator, ctx, z_val.into_float_value());
// %or = or i1 %z.isinf.tobool, %z.isnan.tobool
// %val = select i1 %or, f64 %z, f64 %3
let z_is_nonnum = ctx.builder.build_or(z_isinf, z_isnan, "");
let val = ctx.builder.build_select(
z_is_nonnum,
z_val.into_float_value(),
val,
"",
);
Ok(val.into())
}),
), ),
create_fn_by_extern( create_fn_by_codegen(
primitives, primitives,
&var_map, &var_map,
"gammaln", "gammaln",
float, float,
&[(float, "x")], &[(float, "x")],
"lgamma", Box::new(|ctx, _, fun, args, generator| {
&[], let float = ctx.primitives.float;
let llvm_f64 = ctx.ctx.f64_type();
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone()
.to_basic_value_enum(ctx, generator, x_ty)?;
assert!(ctx.unifier.unioned(x_ty, float));
let tgamma_fn = ctx.module.get_function("lgamma").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function("lgamma", fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
);
func
});
// %0 = call f64 @gamma(f64 %x)
let call = ctx.builder
.build_call(tgamma_fn, &[x_val.into()], "gammaln")
.try_as_basic_value()
.left()
.unwrap()
.into_float_value();
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
//
// Therefore we remap it by:
//
// let v = lgamma(x);
// v = if isinf(x) { x } else { v }
// %isinf = call i32 @__nac3_isinf(f64 %x)
// %tobool = icmp ne i32 %isinf, 0
// %val = select i1 %tobool, f64 %x, f64 %0
let v = ctx.builder.build_select(
call_isinf(generator, ctx, x_val.into_float_value()),
x_val,
call.into(),
""
);
Ok(v.into())
}),
), ),
create_fn_by_extern( create_fn_by_codegen(
primitives, primitives,
&var_map, &var_map,
"j0", "j0",
float, float,
&[(float, "x")], &[(float, "x")],
"j0", Box::new(|ctx, _, fun, args, generator| {
&[], let float = ctx.primitives.float;
let llvm_f64 = ctx.ctx.f64_type();
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone()
.to_basic_value_enum(ctx, generator, x_ty)?;
assert!(ctx.unifier.unioned(x_ty, float));
let tgamma_fn = ctx.module.get_function("j0").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function("j0", fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
);
func
});
// %0 = call f64 @j0(f64 %x)
let call = ctx.builder
.build_call(tgamma_fn, &[x_val.into()], "j0")
.try_as_basic_value()
.left()
.unwrap()
.into_float_value();
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
//
// Therefore we remap it by:
//
// let v = j0(x);
// v = if isinf(x) { f64::NAN } else { v }
// %1 = call i32 @__nac3_isinf(f64 %x)
// %tobool = icmp ne i32 %isinf, 0
let arg_isinf = call_isinf(generator, ctx, x_val.into_float_value());
// %val = select i1 %tobool, f64 nan, f64 %0
let val = ctx.builder
.build_select(arg_isinf, llvm_f64.const_float(f64::NAN), call, "");
Ok(val.into())
}),
), ),
create_fn_by_extern( create_fn_by_extern(
primitives, primitives,