From 08a5050f9a272179cb5d590cb3cc55c944703a25 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 6 Nov 2023 12:57:23 +0800 Subject: [PATCH] core: Implement non-trivial builtin functions using IRRT --- nac3core/src/codegen/irrt/irrt.c | 51 +++++++++ nac3core/src/codegen/irrt/mod.rs | 57 ++++++++++ nac3core/src/toplevel/builtins.rs | 180 ++++-------------------------- 3 files changed, 129 insertions(+), 159 deletions(-) diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c index 410813d..d68b344 100644 --- a/nac3core/src/codegen/irrt/irrt.c +++ b/nac3core/src/codegen/irrt/irrt.c @@ -145,4 +145,55 @@ int32_t __nac3_isinf(double x) { int32_t __nac3_isnan(double x) { return __builtin_isnan(x); +} + +double tgamma(double arg); + +double __nac3_gamma(double z) { + // Handling for denormals + // | x | Python gamma(x) | C tgamma(x) | + // --- | ----------------- | --------------- | ----------- | + // (1) | nan | nan | nan | + // (2) | -inf | -inf | inf | + // (3) | inf | inf | inf | + // (4) | 0.0 | inf | inf | + // (5) | {-1.0, -2.0, ...} | inf | nan | + + // (1)-(3) + if (__builtin_isinf(z) || __builtin_isnan(z)) { + return z; + } + + double v = tgamma(z); + + // (4)-(5) + return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v; +} + +double lgamma(double arg); + +double __nac3_gammaln(double x) { + // libm's handling of value overflows differs from scipy: + // - scipy: gammaln(-inf) -> -inf + // - libm : lgamma(-inf) -> inf + + if (__builtin_isinf(x)) { + return x; + } + + return lgamma(x); +} + +double j0(double x); + +double __nac3_j0(double x) { + // libm's handling of value overflows differs from scipy: + // - scipy: j0(inf) -> nan + // - libm : j0(inf) -> 0.0 + + if (__builtin_isinf(x)) { + return __builtin_nan(""); + } + + return j0(x); } \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 527700f..aa4c10e 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -472,3 +472,60 @@ pub fn call_isnan<'ctx, 'a>( generator.bool_to_i1(ctx, ret) } + +/// Generates a call to `gamma` in IR. Returns an `f64` representing the result. +pub fn call_gamma<'ctx, 'a>( + ctx: &CodeGenContext<'ctx, 'a>, + v: FloatValue<'ctx>, +) -> FloatValue<'ctx> { + let llvm_f64 = ctx.ctx.f64_type(); + + let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + ctx.module.add_function("__nac3_gamma", fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[v.into()], "gamma") + .try_as_basic_value() + .unwrap_left() + .into_float_value() +} + +/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result. +pub fn call_gammaln<'ctx, 'a>( + ctx: &CodeGenContext<'ctx, 'a>, + v: FloatValue<'ctx>, +) -> FloatValue<'ctx> { + let llvm_f64 = ctx.ctx.f64_type(); + + let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + ctx.module.add_function("__nac3_gammaln", fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[v.into()], "gammaln") + .try_as_basic_value() + .unwrap_left() + .into_float_value() +} + +/// Generates a call to `j0` in IR. Returns an `f64` representing the result. +pub fn call_j0<'ctx, 'a>( + ctx: &CodeGenContext<'ctx, 'a>, + v: FloatValue<'ctx>, +) -> FloatValue<'ctx> { + let llvm_f64 = ctx.ctx.f64_type(); + + let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + ctx.module.add_function("__nac3_j0", fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[v.into()], "j0") + .try_as_basic_value() + .unwrap_left() + .into_float_value() +} diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 514664c..3448b60 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -2,7 +2,14 @@ use super::*; use crate::{ codegen::{ expr::destructure_range, - irrt::{calculate_len_for_slice_range, call_isinf, call_isnan}, + irrt::{ + calculate_len_for_slice_range, + call_gamma, + call_gammaln, + call_isinf, + call_isnan, + call_j0, + }, stmt::exn_constructor, }, symbol_resolver::SymbolValue, @@ -1675,7 +1682,6 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { &[(float, "z")], Box::new(|ctx, _, fun, args, generator| { let float = ctx.primitives.float; - let llvm_f64 = ctx.ctx.f64_type(); let z_ty = fun.0.args[0].ty; let z_val = args[0].1.clone() @@ -1683,77 +1689,9 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { assert!(ctx.unifier.unioned(z_ty, float)); - let tgamma_fn = ctx.module.get_function("tgamma").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - let func = ctx.module.add_function("tgamma", fn_type, None); - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) - ); - - func - }); - - // %0 = call f64 @tgamma(f64 %z) - let call = ctx.builder - .build_call(tgamma_fn, &[z_val.into()], "gamma") - .try_as_basic_value() - .left() - .unwrap() - .into_float_value(); - - // Handling for denormals - // | x | Python gamma(x) | C tgamma(x) | - // --- | ----------------- | --------------- | ----------- | - // (1) | nan | nan | nan | - // (2) | -inf | -inf | inf | - // (3) | inf | inf | inf | - // (4) | 0.0 | inf | inf | - // (5) | {-1.0, -2.0, ...} | inf | nan | - // - // Therefore, we remap to Python's denorm handling by: - // - // let v = tgamma(x); - // v = if isinf(v) || isnan(v) { f64::INFINITY } else { v } // Handles (4)-(5) - // v = if isinf(x) || isnan(x) { x } else { v } // Handles (1)-(3) - - // %v.isinf = call i32 @__nac3_isinf(f64 %0) - // %v.isinf.tobool = icmp ne i32 %v.isinf, 0 - let v_isinf = call_isinf(generator, ctx, call.into()); - // %v.isnan = call i32 @__nac3_isnan(f64 %0) - // %v.isnan.tobool = icmp ne i32 %v.isnan, 0 - let v_isnan = call_isnan(generator, ctx, call.into()); - - // %or = or i1 %v.isinf.tobool, %v.isnan.tobool - // %3 = select i1 %or, f64 inf, f64 %0 - let v_is_nonnum = ctx.builder.build_or(v_isinf, v_isnan, ""); - let val = ctx.builder.build_select( - v_is_nonnum, - llvm_f64.const_float(f64::INFINITY).into(), - call, - "", - ).into_float_value(); - - // %z.isinf = call i32 @__nac3_isinf(f64 %z) - // %z.isinf.tobool = icmp ne i32 %z.isinf, 0 - let z_isinf = call_isinf(generator, ctx, z_val.into_float_value()); - // %z.isnan = call i32 @__nac3_isnan(f64 %z) - // %z.isnan.tobool = icmp ne i32 %z.isnan, 0 - let z_isnan = call_isnan(generator, ctx, z_val.into_float_value()); - - // %or = or i1 %z.isinf.tobool, %z.isnan.tobool - // %val = select i1 %or, f64 %z, f64 %3 - let z_is_nonnum = ctx.builder.build_or(z_isinf, z_isnan, ""); - let val = ctx.builder.build_select( - z_is_nonnum, - z_val.into_float_value(), - val, - "", - ); - - Ok(val.into()) - }), - ), + Ok(Some(call_gamma(ctx, z_val.into_float_value()).into())) + } + )), create_fn_by_codegen( primitives, &var_map, @@ -1762,53 +1700,14 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { &[(float, "x")], Box::new(|ctx, _, fun, args, generator| { let float = ctx.primitives.float; - let llvm_f64 = ctx.ctx.f64_type(); - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let z_ty = fun.0.args[0].ty; + let z_val = args[0].1.clone() + .to_basic_value_enum(ctx, generator, z_ty)?; - assert!(ctx.unifier.unioned(x_ty, float)); + assert!(ctx.unifier.unioned(z_ty, float)); - let tgamma_fn = ctx.module.get_function("lgamma").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - let func = ctx.module.add_function("lgamma", fn_type, None); - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) - ); - - func - }); - - // %0 = call f64 @gamma(f64 %x) - let call = ctx.builder - .build_call(tgamma_fn, &[x_val.into()], "gammaln") - .try_as_basic_value() - .left() - .unwrap() - .into_float_value(); - - // libm's handling of value overflows differs from scipy: - // - scipy: gammaln(-inf) -> -inf - // - libm : lgamma(-inf) -> inf - // - // Therefore we remap it by: - // - // let v = lgamma(x); - // v = if isinf(x) { x } else { v } - - // %isinf = call i32 @__nac3_isinf(f64 %x) - // %tobool = icmp ne i32 %isinf, 0 - // %val = select i1 %tobool, f64 %x, f64 %0 - let v = ctx.builder.build_select( - call_isinf(generator, ctx, x_val.into_float_value()), - x_val, - call.into(), - "" - ); - - Ok(v.into()) + Ok(Some(call_gammaln(ctx, z_val.into_float_value()).into())) }), ), create_fn_by_codegen( @@ -1819,51 +1718,14 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { &[(float, "x")], Box::new(|ctx, _, fun, args, generator| { let float = ctx.primitives.float; - let llvm_f64 = ctx.ctx.f64_type(); - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let z_ty = fun.0.args[0].ty; + let z_val = args[0].1.clone() + .to_basic_value_enum(ctx, generator, z_ty)?; - assert!(ctx.unifier.unioned(x_ty, float)); + assert!(ctx.unifier.unioned(z_ty, float)); - let tgamma_fn = ctx.module.get_function("j0").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - let func = ctx.module.add_function("j0", fn_type, None); - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) - ); - - func - }); - - // %0 = call f64 @j0(f64 %x) - let call = ctx.builder - .build_call(tgamma_fn, &[x_val.into()], "j0") - .try_as_basic_value() - .left() - .unwrap() - .into_float_value(); - - // libm's handling of value overflows differs from scipy: - // - scipy: j0(inf) -> nan - // - libm : j0(inf) -> 0.0 - // - // Therefore we remap it by: - // - // let v = j0(x); - // v = if isinf(x) { f64::NAN } else { v } - - // %1 = call i32 @__nac3_isinf(f64 %x) - // %tobool = icmp ne i32 %isinf, 0 - let arg_isinf = call_isinf(generator, ctx, x_val.into_float_value()); - - // %val = select i1 %tobool, f64 nan, f64 %0 - let val = ctx.builder - .build_select(arg_isinf, llvm_f64.const_float(f64::NAN), call, ""); - - Ok(val.into()) + Ok(Some(call_j0(ctx, z_val.into_float_value()).into())) }), ), create_fn_by_extern(