forked from M-Labs/nac3
core: Implement non-trivial builtin functions using IRRT
This commit is contained in:
parent
c2ab6b58ff
commit
08a5050f9a
@ -145,4 +145,55 @@ int32_t __nac3_isinf(double x) {
|
|||||||
|
|
||||||
int32_t __nac3_isnan(double x) {
|
int32_t __nac3_isnan(double x) {
|
||||||
return __builtin_isnan(x);
|
return __builtin_isnan(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
double tgamma(double arg);
|
||||||
|
|
||||||
|
double __nac3_gamma(double z) {
|
||||||
|
// Handling for denormals
|
||||||
|
// | x | Python gamma(x) | C tgamma(x) |
|
||||||
|
// --- | ----------------- | --------------- | ----------- |
|
||||||
|
// (1) | nan | nan | nan |
|
||||||
|
// (2) | -inf | -inf | inf |
|
||||||
|
// (3) | inf | inf | inf |
|
||||||
|
// (4) | 0.0 | inf | inf |
|
||||||
|
// (5) | {-1.0, -2.0, ...} | inf | nan |
|
||||||
|
|
||||||
|
// (1)-(3)
|
||||||
|
if (__builtin_isinf(z) || __builtin_isnan(z)) {
|
||||||
|
return z;
|
||||||
|
}
|
||||||
|
|
||||||
|
double v = tgamma(z);
|
||||||
|
|
||||||
|
// (4)-(5)
|
||||||
|
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
|
||||||
|
}
|
||||||
|
|
||||||
|
double lgamma(double arg);
|
||||||
|
|
||||||
|
double __nac3_gammaln(double x) {
|
||||||
|
// libm's handling of value overflows differs from scipy:
|
||||||
|
// - scipy: gammaln(-inf) -> -inf
|
||||||
|
// - libm : lgamma(-inf) -> inf
|
||||||
|
|
||||||
|
if (__builtin_isinf(x)) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
return lgamma(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
double j0(double x);
|
||||||
|
|
||||||
|
double __nac3_j0(double x) {
|
||||||
|
// libm's handling of value overflows differs from scipy:
|
||||||
|
// - scipy: j0(inf) -> nan
|
||||||
|
// - libm : j0(inf) -> 0.0
|
||||||
|
|
||||||
|
if (__builtin_isinf(x)) {
|
||||||
|
return __builtin_nan("");
|
||||||
|
}
|
||||||
|
|
||||||
|
return j0(x);
|
||||||
}
|
}
|
@ -472,3 +472,60 @@ pub fn call_isnan<'ctx, 'a>(
|
|||||||
|
|
||||||
generator.bool_to_i1(ctx, ret)
|
generator.bool_to_i1(ctx, ret)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
|
||||||
|
pub fn call_gamma<'ctx, 'a>(
|
||||||
|
ctx: &CodeGenContext<'ctx, 'a>,
|
||||||
|
v: FloatValue<'ctx>,
|
||||||
|
) -> FloatValue<'ctx> {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
ctx.module.add_function("__nac3_gamma", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(intrinsic_fn, &[v.into()], "gamma")
|
||||||
|
.try_as_basic_value()
|
||||||
|
.unwrap_left()
|
||||||
|
.into_float_value()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
|
||||||
|
pub fn call_gammaln<'ctx, 'a>(
|
||||||
|
ctx: &CodeGenContext<'ctx, 'a>,
|
||||||
|
v: FloatValue<'ctx>,
|
||||||
|
) -> FloatValue<'ctx> {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
ctx.module.add_function("__nac3_gammaln", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(intrinsic_fn, &[v.into()], "gammaln")
|
||||||
|
.try_as_basic_value()
|
||||||
|
.unwrap_left()
|
||||||
|
.into_float_value()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
|
||||||
|
pub fn call_j0<'ctx, 'a>(
|
||||||
|
ctx: &CodeGenContext<'ctx, 'a>,
|
||||||
|
v: FloatValue<'ctx>,
|
||||||
|
) -> FloatValue<'ctx> {
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
|
||||||
|
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||||
|
ctx.module.add_function("__nac3_j0", fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(intrinsic_fn, &[v.into()], "j0")
|
||||||
|
.try_as_basic_value()
|
||||||
|
.unwrap_left()
|
||||||
|
.into_float_value()
|
||||||
|
}
|
||||||
|
@ -2,7 +2,14 @@ use super::*;
|
|||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
expr::destructure_range,
|
expr::destructure_range,
|
||||||
irrt::{calculate_len_for_slice_range, call_isinf, call_isnan},
|
irrt::{
|
||||||
|
calculate_len_for_slice_range,
|
||||||
|
call_gamma,
|
||||||
|
call_gammaln,
|
||||||
|
call_isinf,
|
||||||
|
call_isnan,
|
||||||
|
call_j0,
|
||||||
|
},
|
||||||
stmt::exn_constructor,
|
stmt::exn_constructor,
|
||||||
},
|
},
|
||||||
symbol_resolver::SymbolValue,
|
symbol_resolver::SymbolValue,
|
||||||
@ -1675,7 +1682,6 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||||||
&[(float, "z")],
|
&[(float, "z")],
|
||||||
Box::new(|ctx, _, fun, args, generator| {
|
Box::new(|ctx, _, fun, args, generator| {
|
||||||
let float = ctx.primitives.float;
|
let float = ctx.primitives.float;
|
||||||
let llvm_f64 = ctx.ctx.f64_type();
|
|
||||||
|
|
||||||
let z_ty = fun.0.args[0].ty;
|
let z_ty = fun.0.args[0].ty;
|
||||||
let z_val = args[0].1.clone()
|
let z_val = args[0].1.clone()
|
||||||
@ -1683,77 +1689,9 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||||||
|
|
||||||
assert!(ctx.unifier.unioned(z_ty, float));
|
assert!(ctx.unifier.unioned(z_ty, float));
|
||||||
|
|
||||||
let tgamma_fn = ctx.module.get_function("tgamma").unwrap_or_else(|| {
|
Ok(Some(call_gamma(ctx, z_val.into_float_value()).into()))
|
||||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
}
|
||||||
let func = ctx.module.add_function("tgamma", fn_type, None);
|
)),
|
||||||
func.add_attribute(
|
|
||||||
AttributeLoc::Function,
|
|
||||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
|
|
||||||
);
|
|
||||||
|
|
||||||
func
|
|
||||||
});
|
|
||||||
|
|
||||||
// %0 = call f64 @tgamma(f64 %z)
|
|
||||||
let call = ctx.builder
|
|
||||||
.build_call(tgamma_fn, &[z_val.into()], "gamma")
|
|
||||||
.try_as_basic_value()
|
|
||||||
.left()
|
|
||||||
.unwrap()
|
|
||||||
.into_float_value();
|
|
||||||
|
|
||||||
// Handling for denormals
|
|
||||||
// | x | Python gamma(x) | C tgamma(x) |
|
|
||||||
// --- | ----------------- | --------------- | ----------- |
|
|
||||||
// (1) | nan | nan | nan |
|
|
||||||
// (2) | -inf | -inf | inf |
|
|
||||||
// (3) | inf | inf | inf |
|
|
||||||
// (4) | 0.0 | inf | inf |
|
|
||||||
// (5) | {-1.0, -2.0, ...} | inf | nan |
|
|
||||||
//
|
|
||||||
// Therefore, we remap to Python's denorm handling by:
|
|
||||||
//
|
|
||||||
// let v = tgamma(x);
|
|
||||||
// v = if isinf(v) || isnan(v) { f64::INFINITY } else { v } // Handles (4)-(5)
|
|
||||||
// v = if isinf(x) || isnan(x) { x } else { v } // Handles (1)-(3)
|
|
||||||
|
|
||||||
// %v.isinf = call i32 @__nac3_isinf(f64 %0)
|
|
||||||
// %v.isinf.tobool = icmp ne i32 %v.isinf, 0
|
|
||||||
let v_isinf = call_isinf(generator, ctx, call.into());
|
|
||||||
// %v.isnan = call i32 @__nac3_isnan(f64 %0)
|
|
||||||
// %v.isnan.tobool = icmp ne i32 %v.isnan, 0
|
|
||||||
let v_isnan = call_isnan(generator, ctx, call.into());
|
|
||||||
|
|
||||||
// %or = or i1 %v.isinf.tobool, %v.isnan.tobool
|
|
||||||
// %3 = select i1 %or, f64 inf, f64 %0
|
|
||||||
let v_is_nonnum = ctx.builder.build_or(v_isinf, v_isnan, "");
|
|
||||||
let val = ctx.builder.build_select(
|
|
||||||
v_is_nonnum,
|
|
||||||
llvm_f64.const_float(f64::INFINITY).into(),
|
|
||||||
call,
|
|
||||||
"",
|
|
||||||
).into_float_value();
|
|
||||||
|
|
||||||
// %z.isinf = call i32 @__nac3_isinf(f64 %z)
|
|
||||||
// %z.isinf.tobool = icmp ne i32 %z.isinf, 0
|
|
||||||
let z_isinf = call_isinf(generator, ctx, z_val.into_float_value());
|
|
||||||
// %z.isnan = call i32 @__nac3_isnan(f64 %z)
|
|
||||||
// %z.isnan.tobool = icmp ne i32 %z.isnan, 0
|
|
||||||
let z_isnan = call_isnan(generator, ctx, z_val.into_float_value());
|
|
||||||
|
|
||||||
// %or = or i1 %z.isinf.tobool, %z.isnan.tobool
|
|
||||||
// %val = select i1 %or, f64 %z, f64 %3
|
|
||||||
let z_is_nonnum = ctx.builder.build_or(z_isinf, z_isnan, "");
|
|
||||||
let val = ctx.builder.build_select(
|
|
||||||
z_is_nonnum,
|
|
||||||
z_val.into_float_value(),
|
|
||||||
val,
|
|
||||||
"",
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(val.into())
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
create_fn_by_codegen(
|
create_fn_by_codegen(
|
||||||
primitives,
|
primitives,
|
||||||
&var_map,
|
&var_map,
|
||||||
@ -1762,53 +1700,14 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||||||
&[(float, "x")],
|
&[(float, "x")],
|
||||||
Box::new(|ctx, _, fun, args, generator| {
|
Box::new(|ctx, _, fun, args, generator| {
|
||||||
let float = ctx.primitives.float;
|
let float = ctx.primitives.float;
|
||||||
let llvm_f64 = ctx.ctx.f64_type();
|
|
||||||
|
|
||||||
let x_ty = fun.0.args[0].ty;
|
let z_ty = fun.0.args[0].ty;
|
||||||
let x_val = args[0].1.clone()
|
let z_val = args[0].1.clone()
|
||||||
.to_basic_value_enum(ctx, generator, x_ty)?;
|
.to_basic_value_enum(ctx, generator, z_ty)?;
|
||||||
|
|
||||||
assert!(ctx.unifier.unioned(x_ty, float));
|
assert!(ctx.unifier.unioned(z_ty, float));
|
||||||
|
|
||||||
let tgamma_fn = ctx.module.get_function("lgamma").unwrap_or_else(|| {
|
Ok(Some(call_gammaln(ctx, z_val.into_float_value()).into()))
|
||||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
|
||||||
let func = ctx.module.add_function("lgamma", fn_type, None);
|
|
||||||
func.add_attribute(
|
|
||||||
AttributeLoc::Function,
|
|
||||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
|
|
||||||
);
|
|
||||||
|
|
||||||
func
|
|
||||||
});
|
|
||||||
|
|
||||||
// %0 = call f64 @gamma(f64 %x)
|
|
||||||
let call = ctx.builder
|
|
||||||
.build_call(tgamma_fn, &[x_val.into()], "gammaln")
|
|
||||||
.try_as_basic_value()
|
|
||||||
.left()
|
|
||||||
.unwrap()
|
|
||||||
.into_float_value();
|
|
||||||
|
|
||||||
// libm's handling of value overflows differs from scipy:
|
|
||||||
// - scipy: gammaln(-inf) -> -inf
|
|
||||||
// - libm : lgamma(-inf) -> inf
|
|
||||||
//
|
|
||||||
// Therefore we remap it by:
|
|
||||||
//
|
|
||||||
// let v = lgamma(x);
|
|
||||||
// v = if isinf(x) { x } else { v }
|
|
||||||
|
|
||||||
// %isinf = call i32 @__nac3_isinf(f64 %x)
|
|
||||||
// %tobool = icmp ne i32 %isinf, 0
|
|
||||||
// %val = select i1 %tobool, f64 %x, f64 %0
|
|
||||||
let v = ctx.builder.build_select(
|
|
||||||
call_isinf(generator, ctx, x_val.into_float_value()),
|
|
||||||
x_val,
|
|
||||||
call.into(),
|
|
||||||
""
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(v.into())
|
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
create_fn_by_codegen(
|
create_fn_by_codegen(
|
||||||
@ -1819,51 +1718,14 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||||||
&[(float, "x")],
|
&[(float, "x")],
|
||||||
Box::new(|ctx, _, fun, args, generator| {
|
Box::new(|ctx, _, fun, args, generator| {
|
||||||
let float = ctx.primitives.float;
|
let float = ctx.primitives.float;
|
||||||
let llvm_f64 = ctx.ctx.f64_type();
|
|
||||||
|
|
||||||
let x_ty = fun.0.args[0].ty;
|
let z_ty = fun.0.args[0].ty;
|
||||||
let x_val = args[0].1.clone()
|
let z_val = args[0].1.clone()
|
||||||
.to_basic_value_enum(ctx, generator, x_ty)?;
|
.to_basic_value_enum(ctx, generator, z_ty)?;
|
||||||
|
|
||||||
assert!(ctx.unifier.unioned(x_ty, float));
|
assert!(ctx.unifier.unioned(z_ty, float));
|
||||||
|
|
||||||
let tgamma_fn = ctx.module.get_function("j0").unwrap_or_else(|| {
|
Ok(Some(call_j0(ctx, z_val.into_float_value()).into()))
|
||||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
|
||||||
let func = ctx.module.add_function("j0", fn_type, None);
|
|
||||||
func.add_attribute(
|
|
||||||
AttributeLoc::Function,
|
|
||||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
|
|
||||||
);
|
|
||||||
|
|
||||||
func
|
|
||||||
});
|
|
||||||
|
|
||||||
// %0 = call f64 @j0(f64 %x)
|
|
||||||
let call = ctx.builder
|
|
||||||
.build_call(tgamma_fn, &[x_val.into()], "j0")
|
|
||||||
.try_as_basic_value()
|
|
||||||
.left()
|
|
||||||
.unwrap()
|
|
||||||
.into_float_value();
|
|
||||||
|
|
||||||
// libm's handling of value overflows differs from scipy:
|
|
||||||
// - scipy: j0(inf) -> nan
|
|
||||||
// - libm : j0(inf) -> 0.0
|
|
||||||
//
|
|
||||||
// Therefore we remap it by:
|
|
||||||
//
|
|
||||||
// let v = j0(x);
|
|
||||||
// v = if isinf(x) { f64::NAN } else { v }
|
|
||||||
|
|
||||||
// %1 = call i32 @__nac3_isinf(f64 %x)
|
|
||||||
// %tobool = icmp ne i32 %isinf, 0
|
|
||||||
let arg_isinf = call_isinf(generator, ctx, x_val.into_float_value());
|
|
||||||
|
|
||||||
// %val = select i1 %tobool, f64 nan, f64 %0
|
|
||||||
let val = ctx.builder
|
|
||||||
.build_select(arg_isinf, llvm_f64.const_float(f64::NAN), call, "");
|
|
||||||
|
|
||||||
Ok(val.into())
|
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
create_fn_by_extern(
|
create_fn_by_extern(
|
||||||
|
Loading…
Reference in New Issue
Block a user