1
0
forked from M-Labs/nac3

core: Implement and expose {isinf,isnan}

This commit is contained in:
David Mak 2023-10-10 16:56:38 +08:00
parent 316f0824d8
commit 60ad100fbb
6 changed files with 142 additions and 3 deletions

View File

@ -138,3 +138,11 @@ int32_t __nac3_list_slice_assign_var_size(
}
return dest_arr_len;
}
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}

View File

@ -7,7 +7,7 @@ use inkwell::{
memory_buffer::MemoryBuffer,
module::Module,
types::BasicTypeEnum,
values::{IntValue, PointerValue},
values::{FloatValue, IntValue, PointerValue},
AddressSpace, IntPredicate,
};
use nac3parser::ast::Expr;
@ -432,3 +432,43 @@ pub fn list_slice_assignment<'ctx, 'a>(
ctx.builder.build_unconditional_branch(cont_bb);
ctx.builder.position_at_end(cont_bb);
}
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
pub fn call_isinf<'ctx, 'a>(
generator: &mut dyn CodeGenerator,
ctx: &CodeGenContext<'ctx, 'a>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isinf", fn_type, None)
});
let ret = ctx.builder
.build_call(intrinsic_fn, &[v.into()], "isinf")
.try_as_basic_value()
.unwrap_left()
.into_int_value();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
pub fn call_isnan<'ctx, 'a>(
generator: &mut dyn CodeGenerator,
ctx: &CodeGenContext<'ctx, 'a>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isnan", fn_type, None)
});
let ret = ctx.builder
.build_call(intrinsic_fn, &[v.into()], "isnan")
.try_as_basic_value()
.unwrap_left()
.into_int_value();
generator.bool_to_i1(ctx, ret)
}

View File

@ -1,7 +1,9 @@
use super::*;
use crate::{
codegen::{
expr::destructure_range, irrt::calculate_len_for_slice_range, stmt::exn_constructor,
expr::destructure_range,
irrt::{calculate_len_for_slice_range, call_isinf, call_isnan},
stmt::exn_constructor,
},
symbol_resolver::SymbolValue,
};
@ -1226,6 +1228,46 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))),
loc: None,
})),
create_fn_by_codegen(
primitives,
&var_map,
"isnan",
boolean,
&[(float, "x")],
Box::new(|ctx, _, fun, args, generator| {
let float = ctx.primitives.float;
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone()
.to_basic_value_enum(ctx, generator, x_ty)?;
assert!(ctx.unifier.unioned(x_ty, float));
let val = call_isnan(generator, ctx, x_val.into_float_value());
Ok(Some(val.into()))
}),
),
create_fn_by_codegen(
primitives,
&var_map,
"isinf",
boolean,
&[(float, "x")],
Box::new(|ctx, _, fun, args, generator| {
let float = ctx.primitives.float;
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone()
.to_basic_value_enum(ctx, generator, x_ty)?;
assert!(ctx.unifier.unioned(x_ty, float));
let val = call_isinf(generator, ctx, x_val.into_float_value());
Ok(Some(val.into()))
}),
),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "Some".into(),
simple_name: "Some".into(),
@ -1270,6 +1312,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
"min",
"max",
"abs",
"isnan",
"isinf",
"Some",
],
)

View File

@ -13,6 +13,14 @@
#error "Unsupported platform - Platform is not 32-bit or 64-bit"
#endif
double dbl_nan(void) {
return NAN;
}
double dbl_inf(void) {
return INFINITY;
}
void output_bool(bool x) {
puts(x ? "True" : "False");
}

View File

@ -3,6 +3,7 @@
import sys
import importlib.util
import importlib.machinery
import numpy as np
import pathlib
from numpy import int32, int64, uint32, uint64
@ -42,6 +43,12 @@ def Some(v: T) -> Option[T]:
none = Option(None)
def patch(module):
def dbl_nan():
return np.nan
def dbl_inf():
return np.inf
def output_asciiart(x):
if x < 0:
sys.stdout.write("\n")
@ -56,7 +63,11 @@ def patch(module):
def extern(fun):
name = fun.__name__
if name == "output_asciiart":
if name == "dbl_nan":
return dbl_nan
elif name == "dbl_inf":
return dbl_inf
elif name == "output_asciiart":
return output_asciiart
elif name == "output_float64":
return output_float
@ -86,6 +97,9 @@ def patch(module):
module.Some = Some
module.none = none
module.isnan = np.isnan
module.isinf = np.isinf
def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename)

View File

@ -0,0 +1,25 @@
@extern
def output_bool(x: bool):
...
@extern
def dbl_nan() -> float:
...
@extern
def dbl_inf() -> float:
...
def test_isnan():
for x in [dbl_nan(), 0.0, dbl_inf()]:
output_bool(isnan(x))
def test_isinf():
for x in [dbl_inf(), 0.0, dbl_nan()]:
output_bool(isinf(x))
def run() -> int32:
test_isnan()
test_isinf()
return 0