From 4dbe07a0c0a4720e04c9162b63655383f692dab3 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 2 Nov 2023 14:56:35 +0800 Subject: [PATCH] core: Revert breaking changes to round-family functions These functions should return ints as the math.* functions do instead of following the convention of numpy.* functions. --- nac3core/src/toplevel/builtins.rs | 184 ++++++++++++++++++++++++-- nac3standalone/demo/interpret_demo.py | 17 ++- nac3standalone/demo/src/math.py | 36 ++++- 3 files changed, 223 insertions(+), 14 deletions(-) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 319ae207..a36939a2 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -814,6 +814,66 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { )))), loc: None, })), + create_fn_by_codegen( + primitives, + &var_map, + "round", + int32, + &[(float, "n")], + Box::new(|ctx, _, _, args, generator| { + let llvm_f64 = ctx.ctx.f64_type(); + let llvm_i32 = ctx.ctx.i32_type(); + + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + + ctx.module.add_function("llvm.round.f64", fn_type, None) + }); + + let val = ctx + .builder + .build_call(intrinsic_fn, &[arg.into()], "") + .try_as_basic_value() + .left() + .unwrap(); + let val_toint = ctx.builder + .build_float_to_signed_int(val.into_float_value(), llvm_i32, "round"); + Ok(Some(val_toint.into())) + }), + ), + create_fn_by_codegen( + primitives, + &var_map, + "round64", + int64, + &[(float, "n")], + Box::new(|ctx, _, _, args, generator| { + let llvm_f64 = ctx.ctx.f64_type(); + let llvm_i64 = ctx.ctx.i64_type(); + + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + + ctx.module.add_function("llvm.round.f64", fn_type, None) + }); + + let val = ctx + .builder + .build_call(intrinsic_fn, &[arg.into()], "") + .try_as_basic_value() + .left() + .unwrap(); + let val_toint = ctx.builder + .build_float_to_signed_int(val.into_float_value(), llvm_i64, "round"); + Ok(Some(val_toint.into())) + }), + ), Arc::new(RwLock::new(TopLevelDef::Function { name: "range".into(), simple_name: "range".into(), @@ -996,21 +1056,125 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { )))), loc: None, })), - create_fn_by_intrinsic( + create_fn_by_codegen( primitives, &var_map, "floor", - float, - &[(float, "x")], - "llvm.floor.f64", + int32, + &[(float, "n")], + Box::new(|ctx, _, _, args, generator| { + let llvm_f64 = ctx.ctx.f64_type(); + let llvm_i32 = ctx.ctx.i32_type(); + + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + + ctx.module.add_function("llvm.floor.f64", fn_type, None) + }); + + let val = ctx + .builder + .build_call(intrinsic_fn, &[arg.into()], "") + .try_as_basic_value() + .left() + .unwrap(); + let val_toint = ctx.builder + .build_float_to_signed_int(val.into_float_value(), llvm_i32, "floor"); + Ok(Some(val_toint.into())) + }), ), - create_fn_by_intrinsic( + create_fn_by_codegen( + primitives, + &var_map, + "floor64", + int64, + &[(float, "n")], + Box::new(|ctx, _, _, args, generator| { + let llvm_f64 = ctx.ctx.f64_type(); + let llvm_i64 = ctx.ctx.i64_type(); + + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + + ctx.module.add_function("llvm.floor.f64", fn_type, None) + }); + + let val = ctx + .builder + .build_call(intrinsic_fn, &[arg.into()], "") + .try_as_basic_value() + .left() + .unwrap(); + let val_toint = ctx.builder + .build_float_to_signed_int(val.into_float_value(), llvm_i64, "floor"); + Ok(Some(val_toint.into())) + }), + ), + create_fn_by_codegen( primitives, &var_map, "ceil", - float, - &[(float, "x")], - "llvm.ceil.f64", + int32, + &[(float, "n")], + Box::new(|ctx, _, _, args, generator| { + let llvm_f64 = ctx.ctx.f64_type(); + let llvm_i32 = ctx.ctx.i32_type(); + + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + + ctx.module.add_function("llvm.ceil.f64", fn_type, None) + }); + + let val = ctx + .builder + .build_call(intrinsic_fn, &[arg.into()], "") + .try_as_basic_value() + .left() + .unwrap(); + let val_toint = ctx.builder + .build_float_to_signed_int(val.into_float_value(), llvm_i32, "ceil"); + Ok(Some(val_toint.into())) + }), + ), + create_fn_by_codegen( + primitives, + &var_map, + "ceil64", + int64, + &[(float, "n")], + Box::new(|ctx, _, _, args, generator| { + let llvm_f64 = ctx.ctx.f64_type(); + let llvm_i64 = ctx.ctx.i64_type(); + + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + + ctx.module.add_function("llvm.ceil.f64", fn_type, None) + }); + + let val = ctx + .builder + .build_call(intrinsic_fn, &[arg.into()], "") + .try_as_basic_value() + .left() + .unwrap(); + let val_toint = ctx.builder + .build_float_to_signed_int(val.into_float_value(), llvm_i64, "ceil"); + Ok(Some(val_toint.into())) + }), ), Arc::new(RwLock::new({ let list_var = primitives.1.get_fresh_var(Some("L".into()), None); @@ -1815,11 +1979,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { "uint32", "uint64", "float", + "round", + "round64", "range", "str", "bool", "floor", + "floor64", "ceil", + "ceil64", "len", "min", "max", diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 03392065..e6ccba18 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -3,6 +3,7 @@ import sys import importlib.util import importlib.machinery +import math import numpy as np import pathlib import scipy @@ -43,6 +44,12 @@ def Some(v: T) -> Option[T]: none = Option(None) +def round_away_zero(x): + if x >= 0.0: + return math.floor(x + 0.5) + else: + return math.ceil(x - 0.5) + def patch(module): def dbl_nan(): return np.nan @@ -98,6 +105,14 @@ def patch(module): module.Some = Some module.none = none + # Builtin Math functions + module.round = round_away_zero + module.round64 = round_away_zero + module.floor = math.floor + module.floor64 = math.floor + module.ceil = math.ceil + module.ceil64 = math.ceil + # NumPy Math functions module.isnan = np.isnan module.isinf = np.isinf @@ -109,8 +124,6 @@ def patch(module): module.log10 = np.log10 module.log2 = np.log2 module.fabs = np.fabs - module.floor = np.floor - module.ceil = np.ceil module.trunc = np.trunc module.sqrt = np.sqrt module.rint = np.rint diff --git a/nac3standalone/demo/src/math.py b/nac3standalone/demo/src/math.py index 9fa39514..dfbd4cae 100644 --- a/nac3standalone/demo/src/math.py +++ b/nac3standalone/demo/src/math.py @@ -2,6 +2,14 @@ def output_bool(x: bool): ... +@extern +def output_int32(x: int32): + ... + +@extern +def output_int64(x: int64): + ... + @extern def output_float64(x: float): ... @@ -20,6 +28,14 @@ def dbl_pi() -> float: def dbl_e() -> float: return 2.71828182845904523536028747135266249775724709369995 +def test_round(): + for x in [-1.5, -0.5, 0.5, 1.5]: + output_int32(round(x)) + +def test_round64(): + for x in [-1.5, -0.5, 0.5, 1.5]: + output_int64(round64(x)) + def test_isnan(): for x in [dbl_nan(), 0.0, dbl_inf()]: output_bool(isnan(x)) @@ -64,12 +80,20 @@ def test_fabs(): output_float64(fabs(x)) def test_floor(): - for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]: - output_float64(floor(x)) + for x in [-1.5, -0.5, 0.5, 1.5]: + output_int32(floor(x)) + +def test_floor64(): + for x in [-1.5, -0.5, 0.5, 1.5]: + output_int64(floor64(x)) def test_ceil(): - for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]: - output_float64(ceil(x)) + for x in [-1.5, -0.5, 0.5, 1.5]: + output_int32(ceil(x)) + +def test_ceil64(): + for x in [-1.5, -0.5, 0.5, 1.5]: + output_int64(ceil64(x)) def test_trunc(): for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]: @@ -192,6 +216,8 @@ def test_nextafter(): output_float64(nextafter(x1, x2)) def run() -> int32: + test_round() + test_round64() test_isnan() test_isinf() test_sin() @@ -203,7 +229,9 @@ def run() -> int32: test_log2() test_fabs() test_floor() + test_floor64() test_ceil() + test_ceil64() test_trunc() test_sqrt() test_rint()