Compare commits

..

5 Commits

3 changed files with 116 additions and 77 deletions

View File

@ -433,11 +433,13 @@ pub fn list_slice_assignment<'ctx, 'a>(
ctx.builder.position_at_end(cont_bb); ctx.builder.position_at_end(cont_bb);
} }
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result. /// Generates a call to `isinf` in IR. Returns either an `i32` or `i1` representing the result,
/// depending on the value of `to_i1`.
pub fn call_isinf<'ctx, 'a>( pub fn call_isinf<'ctx, 'a>(
generator: &dyn CodeGenerator, generator: &dyn CodeGenerator,
ctx: &CodeGenContext<'ctx, 'a>, ctx: &CodeGenContext<'ctx, 'a>,
v: FloatValue<'ctx>, v: FloatValue<'ctx>,
to_i1: bool,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| { let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
@ -449,14 +451,20 @@ pub fn call_isinf<'ctx, 'a>(
.try_as_basic_value() .try_as_basic_value()
.unwrap_left() .unwrap_left()
.into_int_value(); .into_int_value();
generator.bool_to_i1(ctx, val) if to_i1 {
generator.bool_to_i1(ctx, val)
} else {
val
}
} }
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result. /// Generates a call to `isnan` in IR. Returns either an `i32` or `i1` representing the result,
/// depending on the value of `to_i1`.
pub fn call_isnan<'ctx, 'a>( pub fn call_isnan<'ctx, 'a>(
generator: &dyn CodeGenerator, generator: &dyn CodeGenerator,
ctx: &CodeGenContext<'ctx, 'a>, ctx: &CodeGenContext<'ctx, 'a>,
v: FloatValue<'ctx>, v: FloatValue<'ctx>,
to_i1: bool,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| { let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
@ -468,5 +476,9 @@ pub fn call_isnan<'ctx, 'a>(
.try_as_basic_value() .try_as_basic_value()
.unwrap_left() .unwrap_left()
.into_int_value(); .into_int_value();
generator.bool_to_i1(ctx, val) if to_i1 {
generator.bool_to_i1(ctx, val)
} else {
val
}
} }

View File

@ -1180,23 +1180,45 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))), )))),
loc: None, loc: None,
})), })),
create_fn_by_extern( create_fn_by_codegen(
primitives, primitives,
&var_map, &var_map,
"isnan", "isnan",
boolean, boolean,
&[(float, "x")], &[(float, "x")],
"isnan", Box::new(|ctx, _, fun, args, generator| {
&[], let float = ctx.primitives.float;
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone()
.to_basic_value_enum(ctx, generator, x_ty)?;
assert!(ctx.unifier.unioned(x_ty, float));
let val = call_isnan(generator, ctx, x_val.into_float_value(), true);
Ok(Some(val.into()))
}),
), ),
create_fn_by_extern( create_fn_by_codegen(
primitives, primitives,
&var_map, &var_map,
"isinf", "isinf",
boolean, boolean,
&[(float, "x")], &[(float, "x")],
"isinf", Box::new(|ctx, _, fun, args, generator| {
&[], let float = ctx.primitives.float;
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone()
.to_basic_value_enum(ctx, generator, x_ty)?;
assert!(ctx.unifier.unioned(x_ty, float));
let val = call_isinf(generator, ctx, x_val.into_float_value(), true);
Ok(Some(val.into()))
}),
), ),
create_fn_by_intrinsic( create_fn_by_intrinsic(
primitives, primitives,
@ -1498,31 +1520,30 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
// v = if isinf(v) || isnan(v) { f64::INFINITY } else { v } // Handles (4)-(5) // v = if isinf(v) || isnan(v) { f64::INFINITY } else { v } // Handles (4)-(5)
// v = if isinf(x) || isnan(x) { x } else { v } // Handles (1)-(3) // v = if isinf(x) || isnan(x) { x } else { v } // Handles (1)-(3)
// %isinf = call i32 @__nac3_isinf(f64 %0) // %v.isinf = call i32 @__nac3_isinf(f64 %0)
// %isinf_bool = icmp ne i32 %isinf, 0 let v_isinf = call_isinf(generator, ctx, call.into(), false);
let v_isinf = call_isinf(generator, ctx, call.into()); // %v.isnan = call i32 @__nac3_isnan(f64 %0)
// %isnan = call i32 @__nac3_isnan(f64 %0) let v_isnan = call_isnan(generator, ctx, call.into(), false);
// %isnan_bool = icmp ne i32 %isnan, 0
let v_isnan = call_isnan(generator, ctx, call.into());
// %or = or i1 %isinf_bool, %isinf_nan // %or = or i32 %v.isinf, %v.isnan
// %3 = select i1 %or, f64 inf, f64 %0 // %or.tobool = icmp ne i32 %or, 0
// %3 = select i1 %or.tobool, f64 inf, f64 %0
let v_is_nonnum = ctx.builder.build_or(v_isinf, v_isnan, "");
let val = ctx.builder.build_select( let val = ctx.builder.build_select(
ctx.builder.build_or(v_isinf, v_isnan, ""), generator.bool_to_i1(ctx, v_is_nonnum),
llvm_f64.const_float(f64::INFINITY).into(), llvm_f64.const_float(f64::INFINITY).into(),
call, call,
"", "",
).into_float_value(); ).into_float_value();
// %isinf = call i32 @__nac3_isinf(f64 %z) // %z.isinf = call i32 @__nac3_isinf(f64 %z)
// %isinf_bool = icmp ne i32 %isinf, 0 let z_isinf = call_isinf(generator, ctx, z_val.into_float_value(), false);
let z_isinf = call_isinf(generator, ctx, z_val.into_float_value()); // %z.isnan = call i32 @__nac3_isnan(f64 %z)
// %isnan = call i32 @__nac3_isnan(f64 %z) let z_isnan = call_isnan(generator, ctx, z_val.into_float_value(), false);
// %isnan_bool = icmp ne i32 %isnan, 0
let z_isnan = call_isnan(generator, ctx, z_val.into_float_value());
// %or = or i1 %isinf_bool, %isinf_nan // %or = or i32 %z.isinf, %z.isnan
// %4 = select i1 %or, f64 inf, f64 %0 // %or.tobool = icmp ne i32 %or, 0
// %val = select i1 %or.tobool, f64 %z, f64 %3
let val = ctx.builder.build_select( let val = ctx.builder.build_select(
ctx.builder.build_or(z_isinf, z_isnan, ""), ctx.builder.build_or(z_isinf, z_isnan, ""),
z_val.into_float_value(), z_val.into_float_value(),
@ -1581,7 +1602,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
// %tobool = icmp ne i32 %isinf, 0 // %tobool = icmp ne i32 %isinf, 0
// %val = select i1 %tobool, f64 %x, f64 %0 // %val = select i1 %tobool, f64 %x, f64 %0
let v = ctx.builder.build_select( let v = ctx.builder.build_select(
call_isinf(generator, ctx, x_val.into_float_value()), call_isinf(generator, ctx, x_val.into_float_value(), true),
x_val, x_val,
call.into(), call.into(),
"" ""
@ -1636,7 +1657,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
// %1 = call i32 @__nac3_isinf(f64 %x) // %1 = call i32 @__nac3_isinf(f64 %x)
// %2 = // %2 =
let arg_isinf = call_isinf(generator, ctx, x_val.into_float_value()); let arg_isinf = call_isinf(generator, ctx, x_val.into_float_value(), true);
// %val = select i1 %1, f64 nan, f64 %0 // %val = select i1 %1, f64 nan, f64 %0
let val = ctx.builder let val = ctx.builder

View File

@ -14,136 +14,142 @@ def dbl_nan() -> float:
def dbl_inf() -> float: def dbl_inf() -> float:
... ...
def dbl_pi() -> float:
return 3.1415926535897932384626433
def dbl_e() -> float:
return 2.71828182845904523536028747135266249775724709369995
def test_isnan(): def test_isnan():
for x in [dbl_nan(), 0.0, dbl_inf()]: for x in [dbl_nan(), 0.0, dbl_inf()]:
output_bool(isnan(x)) output_bool(isnan(x))
def test_isinf(): def test_isinf():
for x in [dbl_inf(), 0.0, dbl_nan()]: for x in [dbl_inf(), -dbl_inf(), 0.0, dbl_nan()]:
output_bool(isinf(x)) output_bool(isinf(x))
def test_sin(): def test_sin():
pi = 3.1415926535897932384626433 pi = dbl_pi()
for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi]: for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(sin(x)) output_float64(sin(x))
def test_cos(): def test_cos():
pi = 3.1415926535897932384626433 pi = dbl_pi()
for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi]: for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(cos(x)) output_float64(cos(x))
def test_exp(): def test_exp():
for x in [0.0, 1.0]: for x in [0.0, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(exp(x)) output_float64(exp(x))
def test_exp2(): def test_exp2():
for x in [0.0, 1.0]: for x in [0.0, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(exp2(x)) output_float64(exp2(x))
def test_log(): def test_log():
e = 2.71828182845904523536028747135266249775724709369995 e = dbl_e()
for x in [1.0, e]: for x in [1.0, e, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(log(x)) output_float64(log(x))
def test_log10(): def test_log10():
for x in [1.0, 10.0]: for x in [1.0, 10.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(log10(x)) output_float64(log10(x))
def test_log2(): def test_log2():
for x in [1.0, 2.0]: for x in [1.0, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(log2(x)) output_float64(log2(x))
def test_fabs(): def test_fabs():
for x in [-1.0, 0.0, 1.0]: for x in [-1.0, 0.0, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(fabs(x)) output_float64(fabs(x))
def test_floor(): def test_floor():
for x in [-1.5, -0.5, 0.5, 1.5]: for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(floor(x)) output_float64(floor(x))
def test_ceil(): def test_ceil():
for x in [-1.5, -0.5, 0.5, 1.5]: for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(ceil(x)) output_float64(ceil(x))
def test_trunc(): def test_trunc():
for x in [-1.5, -0.5, 0.5, 1.5]: for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(trunc(x)) output_float64(trunc(x))
def test_sqrt(): def test_sqrt():
for x in [1.0, 2.0, 4.0]: for x in [1.0, 2.0, 4.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(sqrt(x)) output_float64(sqrt(x))
def test_rint(): def test_rint():
for x in [-1.5, -0.5, 0.5, 1.5]: for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(rint(x)) output_float64(rint(x))
def test_tan(): def test_tan():
pi = 3.1415926535897932384626433 pi = dbl_pi()
for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi]: for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(tan(x)) output_float64(tan(x))
def test_arcsin(): def test_arcsin():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(arcsin(x)) output_float64(arcsin(x))
def test_arccos(): def test_arccos():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(arccos(x)) output_float64(arccos(x))
def test_arctan(): def test_arctan():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(arctan(x)) output_float64(arctan(x))
def test_sinh(): def test_sinh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(sinh(x)) output_float64(sinh(x))
def test_cosh(): def test_cosh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(cosh(x)) output_float64(cosh(x))
def test_tanh(): def test_tanh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(tanh(x)) output_float64(tanh(x))
def test_arcsinh(): def test_arcsinh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(arcsinh(x)) output_float64(arcsinh(x))
def test_arccosh(): def test_arccosh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(arccosh(x)) output_float64(arccosh(x))
def test_arctanh(): def test_arctanh():
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: for x in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(arctanh(x)) output_float64(arctanh(x))
def test_expm1(): def test_expm1():
for x in [0.0, 1.0]: for x in [0.0, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(expm1(x)) output_float64(expm1(x))
def test_cbrt(): def test_cbrt():
for x in [1.0, 8.0, 27.0]: for x in [1.0, 8.0, 27.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(expm1(x)) output_float64(expm1(x))
def test_erf(): def test_erf():
for x in [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]: for x in [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(erf(x)) output_float64(erf(x))
def test_erfc(): def test_erfc():
for x in [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]: for x in [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(erfc(x)) output_float64(erfc(x))
def test_gamma(): def test_gamma():
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]: for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(gamma(x)) output_float64(gamma(x))
def test_gammaln(): def test_gammaln():
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]: for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(gammaln(x)) output_float64(gammaln(x))
def test_j0(): def test_j0():
for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]: for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(j0(x)) output_float64(j0(x))
def test_j1(): def test_j1():
@ -151,38 +157,38 @@ def test_j1():
output_float64(j1(x)) output_float64(j1(x))
def test_arctan2(): def test_arctan2():
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]: for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]: for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(arctan2(x1, x2)) output_float64(arctan2(x1, x2))
def test_copysign(): def test_copysign():
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]: for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]: for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(copysign(x1, x2)) output_float64(copysign(x1, x2))
def test_fmax(): def test_fmax():
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]: for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]: for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(fmax(x1, x2)) output_float64(fmax(x1, x2))
def test_fmin(): def test_fmin():
for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]: for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]: for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(fmin(x1, x2)) output_float64(fmin(x1, x2))
def test_ldexp(): def test_ldexp():
for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]: for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
for x2 in [-2, -1, 0, 1, 2]: for x2 in [-2, -1, 0, 1, 2]:
output_float64(ldexp(x1, x2)) output_float64(ldexp(x1, x2))
def test_hypot(): def test_hypot():
for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]: for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
for x2 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]: for x2 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(hypot(x1, x2)) output_float64(hypot(x1, x2))
def test_nextafter(): def test_nextafter():
for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]: for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
for x2 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]: for x2 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(nextafter(x1, x2)) output_float64(nextafter(x1, x2))
def run() -> int32: def run() -> int32: