From ff1fed112cd12240cfa20a782c37209b4706750c Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 10 Oct 2023 18:19:36 +0800 Subject: [PATCH] core: Rework gamma/gammaln to match SciPy behavior Matches behavior for infinities and NaNs. --- nac3core/src/toplevel/builtins.rs | 190 ++++++++++++++++++++++++++++-- 1 file changed, 181 insertions(+), 9 deletions(-) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 13db0931..7b2f9b67 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1470,32 +1470,204 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { "erfc", &[], ), - create_fn_by_extern( + create_fn_by_codegen( primitives, &var_map, "gamma", float, &[(float, "z")], - "tgamma", - &[], + Box::new(|ctx, _, fun, args, generator| { + let float = ctx.primitives.float; + let llvm_f64 = ctx.ctx.f64_type(); + + let z_ty = fun.0.args[0].ty; + let z_val = args[0].1.clone() + .to_basic_value_enum(ctx, generator, z_ty)?; + + assert!(ctx.unifier.unioned(z_ty, float)); + + let tgamma_fn = ctx.module.get_function("tgamma").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + let func = ctx.module.add_function("tgamma", fn_type, None); + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) + ); + + func + }); + + // %0 = call f64 @tgamma(f64 %z) + let call = ctx.builder + .build_call(tgamma_fn, &[z_val.into()], "gamma") + .try_as_basic_value() + .left() + .unwrap() + .into_float_value(); + + // Handling for denormals + // | x | Python gamma(x) | C tgamma(x) | + // --- | ----------------- | --------------- | ----------- | + // (1) | nan | nan | nan | + // (2) | -inf | -inf | inf | + // (3) | inf | inf | inf | + // (4) | 0.0 | inf | inf | + // (5) | {-1.0, -2.0, ...} | inf | nan | + // + // Therefore, we remap to Python's denorm handling by: + // + // let v = tgamma(x); + // v = if isinf(v) || isnan(v) { f64::INFINITY } else { v } // Handles (4)-(5) + // v = if isinf(x) || isnan(x) { x } else { v } // Handles (1)-(3) + + // %v.isinf = call i32 @__nac3_isinf(f64 %0) + // %v.isinf.tobool = icmp ne i32 %v.isinf, 0 + let v_isinf = call_isinf(generator, ctx, call.into()); + // %v.isnan = call i32 @__nac3_isnan(f64 %0) + // %v.isnan.tobool = icmp ne i32 %v.isnan, 0 + let v_isnan = call_isnan(generator, ctx, call.into()); + + // %or = or i1 %v.isinf.tobool, %v.isnan.tobool + // %3 = select i1 %or, f64 inf, f64 %0 + let v_is_nonnum = ctx.builder.build_or(v_isinf, v_isnan, ""); + let val = ctx.builder.build_select( + v_is_nonnum, + llvm_f64.const_float(f64::INFINITY).into(), + call, + "", + ).into_float_value(); + + // %z.isinf = call i32 @__nac3_isinf(f64 %z) + // %z.isinf.tobool = icmp ne i32 %z.isinf, 0 + let z_isinf = call_isinf(generator, ctx, z_val.into_float_value()); + // %z.isnan = call i32 @__nac3_isnan(f64 %z) + // %z.isnan.tobool = icmp ne i32 %z.isnan, 0 + let z_isnan = call_isnan(generator, ctx, z_val.into_float_value()); + + // %or = or i1 %z.isinf.tobool, %z.isnan.tobool + // %val = select i1 %or, f64 %z, f64 %3 + let z_is_nonnum = ctx.builder.build_or(z_isinf, z_isnan, ""); + let val = ctx.builder.build_select( + z_is_nonnum, + z_val.into_float_value(), + val, + "", + ); + + Ok(val.into()) + }), ), - create_fn_by_extern( + create_fn_by_codegen( primitives, &var_map, "gammaln", float, &[(float, "x")], - "lgamma", - &[], + Box::new(|ctx, _, fun, args, generator| { + let float = ctx.primitives.float; + let llvm_f64 = ctx.ctx.f64_type(); + + let x_ty = fun.0.args[0].ty; + let x_val = args[0].1.clone() + .to_basic_value_enum(ctx, generator, x_ty)?; + + assert!(ctx.unifier.unioned(x_ty, float)); + + let tgamma_fn = ctx.module.get_function("lgamma").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + let func = ctx.module.add_function("lgamma", fn_type, None); + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) + ); + + func + }); + + // %0 = call f64 @gamma(f64 %x) + let call = ctx.builder + .build_call(tgamma_fn, &[x_val.into()], "gammaln") + .try_as_basic_value() + .left() + .unwrap() + .into_float_value(); + + // libm's handling of value overflows differs from scipy: + // - scipy: gammaln(-inf) -> -inf + // - libm : lgamma(-inf) -> inf + // + // Therefore we remap it by: + // + // let v = lgamma(x); + // v = if isinf(x) { x } else { v } + + // %isinf = call i32 @__nac3_isinf(f64 %x) + // %tobool = icmp ne i32 %isinf, 0 + // %val = select i1 %tobool, f64 %x, f64 %0 + let v = ctx.builder.build_select( + call_isinf(generator, ctx, x_val.into_float_value()), + x_val, + call.into(), + "" + ); + + Ok(v.into()) + }), ), - create_fn_by_extern( + create_fn_by_codegen( primitives, &var_map, "j0", float, &[(float, "x")], - "j0", - &[], + Box::new(|ctx, _, fun, args, generator| { + let float = ctx.primitives.float; + let llvm_f64 = ctx.ctx.f64_type(); + + let x_ty = fun.0.args[0].ty; + let x_val = args[0].1.clone() + .to_basic_value_enum(ctx, generator, x_ty)?; + + assert!(ctx.unifier.unioned(x_ty, float)); + + let tgamma_fn = ctx.module.get_function("j0").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + let func = ctx.module.add_function("j0", fn_type, None); + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) + ); + + func + }); + + // %0 = call f64 @j0(f64 %x) + let call = ctx.builder + .build_call(tgamma_fn, &[x_val.into()], "j0") + .try_as_basic_value() + .left() + .unwrap() + .into_float_value(); + + // libm's handling of value overflows differs from scipy: + // - scipy: j0(inf) -> nan + // - libm : j0(inf) -> 0.0 + // + // Therefore we remap it by: + // + // let v = j0(x); + // v = if isinf(x) { f64::NAN } else { v } + + // %1 = call i32 @__nac3_isinf(f64 %x) + // %tobool = icmp ne i32 %isinf, 0 + let arg_isinf = call_isinf(generator, ctx, x_val.into_float_value()); + + // %val = select i1 %tobool, f64 nan, f64 %0 + let val = ctx.builder + .build_select(arg_isinf, llvm_f64.const_float(f64::NAN), call, ""); + + Ok(val.into()) + }), ), create_fn_by_extern( primitives,