From 5c5620692f0695b05357e0f1beb8d33afb0d6c16 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 23 Nov 2023 13:45:07 +0800 Subject: [PATCH] core: Add np_{round,floor,ceil} These functions are NumPy variants of round/floor/ceil, which returns floats instead of ints. --- nac3core/src/toplevel/builtins.rs | 84 +++++++++++++++++++++++++++ nac3standalone/demo/interpret_demo.py | 3 + nac3standalone/demo/src/math.py | 15 +++++ 3 files changed, 102 insertions(+) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 84b2cd5..0b28bea 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -881,6 +881,33 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { Ok(Some(val_toint.into())) }), ), + create_fn_by_codegen( + primitives, + &var_map, + "np_round", + float, + &[(float, "n")], + Box::new(|ctx, _, _, args, generator| { + let llvm_f64 = ctx.ctx.f64_type(); + + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + let intrinsic_fn = ctx.module.get_function("llvm.roundeven.f64").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + + ctx.module.add_function("llvm.roundeven.f64", fn_type, None) + }); + + let val = ctx + .builder + .build_call(intrinsic_fn, &[arg.into()], "") + .try_as_basic_value() + .left() + .unwrap(); + Ok(Some(val.into())) + }), + ), Arc::new(RwLock::new(TopLevelDef::Function { name: "range".into(), simple_name: "range".into(), @@ -1123,6 +1150,33 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { Ok(Some(val_toint.into())) }), ), + create_fn_by_codegen( + primitives, + &var_map, + "np_floor", + float, + &[(float, "n")], + Box::new(|ctx, _, _, args, generator| { + let llvm_f64 = ctx.ctx.f64_type(); + + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + + ctx.module.add_function("llvm.floor.f64", fn_type, None) + }); + + let val = ctx + .builder + .build_call(intrinsic_fn, &[arg.into()], "") + .try_as_basic_value() + .left() + .unwrap(); + Ok(Some(val.into())) + }), + ), create_fn_by_codegen( primitives, &var_map, @@ -1183,6 +1237,33 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { Ok(Some(val_toint.into())) }), ), + create_fn_by_codegen( + primitives, + &var_map, + "np_ceil", + float, + &[(float, "n")], + Box::new(|ctx, _, _, args, generator| { + let llvm_f64 = ctx.ctx.f64_type(); + + let arg = args[0].1.clone() + .to_basic_value_enum(ctx, generator, ctx.primitives.float)?; + + let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + + ctx.module.add_function("llvm.ceil.f64", fn_type, None) + }); + + let val = ctx + .builder + .build_call(intrinsic_fn, &[arg.into()], "") + .try_as_basic_value() + .left() + .unwrap(); + Ok(Some(val.into())) + }), + ), Arc::new(RwLock::new({ let list_var = primitives.1.get_fresh_var(Some("L".into()), None); let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 }); @@ -1835,13 +1916,16 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { "float", "round", "round64", + "np_round", "range", "str", "bool", "floor", "floor64", + "np_floor", "ceil", "ceil64", + "np_ceil", "len", "min", "max", diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index e52544e..9a753d0 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -108,10 +108,13 @@ def patch(module): # Builtin Math functions module.round = round_away_zero module.round64 = round_away_zero + module.np_round = np.round module.floor = math.floor module.floor64 = math.floor + module.np_floor = np.floor module.ceil = math.ceil module.ceil64 = math.ceil + module.np_ceil = np.ceil # NumPy Math functions module.np_isnan = np.isnan diff --git a/nac3standalone/demo/src/math.py b/nac3standalone/demo/src/math.py index 5efe1e8..1cd3e3d 100644 --- a/nac3standalone/demo/src/math.py +++ b/nac3standalone/demo/src/math.py @@ -36,6 +36,10 @@ def test_round64(): for x in [-1.5, -0.5, 0.5, 1.5]: output_int64(round64(x)) +def test_np_round(): + for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]: + output_float64(np_round(x)) + def test_np_isnan(): for x in [dbl_nan(), 0.0, dbl_inf()]: output_bool(np_isnan(x)) @@ -87,6 +91,10 @@ def test_floor64(): for x in [-1.5, -0.5, 0.5, 1.5]: output_int64(floor64(x)) +def test_np_floor(): + for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]: + output_float64(np_floor(x)) + def test_ceil(): for x in [-1.5, -0.5, 0.5, 1.5]: output_int32(ceil(x)) @@ -95,6 +103,10 @@ def test_ceil64(): for x in [-1.5, -0.5, 0.5, 1.5]: output_int64(ceil64(x)) +def test_np_ceil(): + for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]: + output_float64(np_ceil(x)) + def test_np_sqrt(): for x in [1.0, 2.0, 4.0, dbl_inf(), -dbl_inf(), dbl_nan()]: output_float64(np_sqrt(x)) @@ -214,6 +226,7 @@ def test_np_nextafter(): def run() -> int32: test_round() test_round64() + test_np_round() test_np_isnan() test_np_isinf() test_np_sin() @@ -226,8 +239,10 @@ def run() -> int32: test_np_fabs() test_floor() test_floor64() + test_np_floor() test_ceil() test_ceil64() + test_np_ceil() test_np_sqrt() test_np_rint() test_np_tan()