diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 7b2f9b67..8c0cd1c4 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -737,6 +737,66 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { )))), loc: None, })), + create_fn_by_codegen( + primitives, + &var_map, + "round", + int32, + &[(float, "n")], + Box::new(|ctx, _, _, args, generator| { + let llvm_f64 = ctx.ctx.f64_type(); + let llvm_i32 = ctx.ctx.i32_type(); + + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + + ctx.module.add_function("llvm.round.f64", fn_type, None) + }); + + let val = ctx + .builder + .build_call(intrinsic_fn, &[arg.into()], "") + .try_as_basic_value() + .left() + .unwrap(); + let val_toint = ctx.builder + .build_float_to_signed_int(val.into_float_value(), llvm_i32, "round"); + Ok(Some(val_toint.into())) + }), + ), + create_fn_by_codegen( + primitives, + &var_map, + "round64", + int64, + &[(float, "n")], + Box::new(|ctx, _, _, args, generator| { + let llvm_f64 = ctx.ctx.f64_type(); + let llvm_i64 = ctx.ctx.i64_type(); + + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + + ctx.module.add_function("llvm.round.f64", fn_type, None) + }); + + let val = ctx + .builder + .build_call(intrinsic_fn, &[arg.into()], "") + .try_as_basic_value() + .left() + .unwrap(); + let val_toint = ctx.builder + .build_float_to_signed_int(val.into_float_value(), llvm_i64, "round"); + Ok(Some(val_toint.into())) + }), + ), Arc::new(RwLock::new(TopLevelDef::Function { name: "range".into(), simple_name: "range".into(), @@ -919,21 +979,125 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { )))), loc: None, })), - create_fn_by_intrinsic( + create_fn_by_codegen( primitives, &var_map, "floor", - float, - &[(float, "x")], - "llvm.floor.f64", + int32, + &[(float, "n")], + Box::new(|ctx, _, _, args, generator| { + let llvm_f64 = ctx.ctx.f64_type(); + let llvm_i32 = ctx.ctx.i32_type(); + + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + + ctx.module.add_function("llvm.floor.f64", fn_type, None) + }); + + let val = ctx + .builder + .build_call(intrinsic_fn, &[arg.into()], "") + .try_as_basic_value() + .left() + .unwrap(); + let val_toint = ctx.builder + .build_float_to_signed_int(val.into_float_value(), llvm_i32, "floor"); + Ok(Some(val_toint.into())) + }), ), - create_fn_by_intrinsic( + create_fn_by_codegen( + primitives, + &var_map, + "floor64", + int64, + &[(float, "n")], + Box::new(|ctx, _, _, args, generator| { + let llvm_f64 = ctx.ctx.f64_type(); + let llvm_i64 = ctx.ctx.i64_type(); + + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + + ctx.module.add_function("llvm.floor.f64", fn_type, None) + }); + + let val = ctx + .builder + .build_call(intrinsic_fn, &[arg.into()], "") + .try_as_basic_value() + .left() + .unwrap(); + let val_toint = ctx.builder + .build_float_to_signed_int(val.into_float_value(), llvm_i64, "floor"); + Ok(Some(val_toint.into())) + }), + ), + create_fn_by_codegen( primitives, &var_map, "ceil", - float, - &[(float, "x")], - "llvm.ceil.f64", + int32, + &[(float, "n")], + Box::new(|ctx, _, _, args, generator| { + let llvm_f64 = ctx.ctx.f64_type(); + let llvm_i32 = ctx.ctx.i32_type(); + + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + + ctx.module.add_function("llvm.ceil.f64", fn_type, None) + }); + + let val = ctx + .builder + .build_call(intrinsic_fn, &[arg.into()], "") + .try_as_basic_value() + .left() + .unwrap(); + let val_toint = ctx.builder + .build_float_to_signed_int(val.into_float_value(), llvm_i32, "ceil"); + Ok(Some(val_toint.into())) + }), + ), + create_fn_by_codegen( + primitives, + &var_map, + "ceil64", + int64, + &[(float, "n")], + Box::new(|ctx, _, _, args, generator| { + let llvm_f64 = ctx.ctx.f64_type(); + let llvm_i64 = ctx.ctx.i64_type(); + + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + + ctx.module.add_function("llvm.ceil.f64", fn_type, None) + }); + + let val = ctx + .builder + .build_call(intrinsic_fn, &[arg.into()], "") + .try_as_basic_value() + .left() + .unwrap(); + let val_toint = ctx.builder + .build_float_to_signed_int(val.into_float_value(), llvm_i64, "ceil"); + Ok(Some(val_toint.into())) + }), ), Arc::new(RwLock::new({ let list_var = primitives.1.get_fresh_var(Some("L".into()), None); @@ -1284,14 +1448,6 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { &[(float, "x")], "llvm.fabs.f64", ), - create_fn_by_intrinsic( - primitives, - &var_map, - "trunc", - float, - &[(float, "x")], - "llvm.trunc.f64", - ), create_fn_by_intrinsic( primitives, &var_map, @@ -1300,49 +1456,13 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { &[(float, "x")], "llvm.sqrt.f64", ), - create_fn_by_codegen( + create_fn_by_intrinsic( primitives, &var_map, "rint", float, &[(float, "x")], - Box::new(|ctx, _, fun, args, generator| { - let float = ctx.primitives.float; - let llvm_f64 = ctx.ctx.f64_type(); - - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - assert!(ctx.unifier.unioned(x_ty, float)); - - let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - - ctx.module.add_function("llvm.round.f64", fn_type, None) - }); - - // rint(x) == round(x * 0.5) * 2.0 - - // %0 = fmul f64 %x, 0.5 - let x_half = ctx.builder - .build_float_mul(x_val.into_float_value(), llvm_f64.const_float(0.5), ""); - // %1 = call f64 @llvm.round.f64(f64 %0) - let round = ctx.builder - .build_call( - intrinsic_fn, - &vec![x_half.into()], - "", - ) - .try_as_basic_value() - .left() - .unwrap(); - // %2 = fmul f64 %1, 2.0 - let val = ctx.builder - .build_float_mul(round.into_float_value(), llvm_f64.const_float(2.0).into(), "rint"); - - Ok(Some(val.into())) - }), + "llvm.roundeven.f64", ), create_fn_by_extern( primitives, @@ -1774,11 +1894,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { "uint32", "uint64", "float", + "round", + "round64", "range", "str", "bool", "floor", + "floor64", "ceil", + "ceil64", "len", "min", "max", @@ -1793,7 +1917,6 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { "log10", "log2", "fabs", - "trunc", "sqrt", "rint", "tan", diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 03392065..e6ccba18 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -3,6 +3,7 @@ import sys import importlib.util import importlib.machinery +import math import numpy as np import pathlib import scipy @@ -43,6 +44,12 @@ def Some(v: T) -> Option[T]: none = Option(None) +def round_away_zero(x): + if x >= 0.0: + return math.floor(x + 0.5) + else: + return math.ceil(x - 0.5) + def patch(module): def dbl_nan(): return np.nan @@ -98,6 +105,14 @@ def patch(module): module.Some = Some module.none = none + # Builtin Math functions + module.round = round_away_zero + module.round64 = round_away_zero + module.floor = math.floor + module.floor64 = math.floor + module.ceil = math.ceil + module.ceil64 = math.ceil + # NumPy Math functions module.isnan = np.isnan module.isinf = np.isinf @@ -109,8 +124,6 @@ def patch(module): module.log10 = np.log10 module.log2 = np.log2 module.fabs = np.fabs - module.floor = np.floor - module.ceil = np.ceil module.trunc = np.trunc module.sqrt = np.sqrt module.rint = np.rint diff --git a/nac3standalone/demo/src/math.py b/nac3standalone/demo/src/math.py index 9fa39514..52d62749 100644 --- a/nac3standalone/demo/src/math.py +++ b/nac3standalone/demo/src/math.py @@ -2,6 +2,14 @@ def output_bool(x: bool): ... +@extern +def output_int32(x: int32): + ... + +@extern +def output_int64(x: int64): + ... + @extern def output_float64(x: float): ... @@ -20,6 +28,14 @@ def dbl_pi() -> float: def dbl_e() -> float: return 2.71828182845904523536028747135266249775724709369995 +def test_round(): + for x in [-1.5, -0.5, 0.5, 1.5]: + output_int32(round(x)) + +def test_round64(): + for x in [-1.5, -0.5, 0.5, 1.5]: + output_int64(round64(x)) + def test_isnan(): for x in [dbl_nan(), 0.0, dbl_inf()]: output_bool(isnan(x)) @@ -64,16 +80,20 @@ def test_fabs(): output_float64(fabs(x)) def test_floor(): - for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]: - output_float64(floor(x)) + for x in [-1.5, -0.5, 0.5, 1.5]: + output_int32(floor(x)) + +def test_floor64(): + for x in [-1.5, -0.5, 0.5, 1.5]: + output_int64(floor64(x)) def test_ceil(): - for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]: - output_float64(ceil(x)) + for x in [-1.5, -0.5, 0.5, 1.5]: + output_int32(ceil(x)) -def test_trunc(): - for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]: - output_float64(trunc(x)) +def test_ceil64(): + for x in [-1.5, -0.5, 0.5, 1.5]: + output_int64(ceil64(x)) def test_sqrt(): for x in [1.0, 2.0, 4.0, dbl_inf(), -dbl_inf(), dbl_nan()]: @@ -192,6 +212,8 @@ def test_nextafter(): output_float64(nextafter(x1, x2)) def run() -> int32: + test_round() + test_round64() test_isnan() test_isinf() test_sin() @@ -203,8 +225,9 @@ def run() -> int32: test_log2() test_fabs() test_floor() + test_floor64() test_ceil() - test_trunc() + test_ceil64() test_sqrt() test_rint() test_tan()