From 2b635a0b974adb5cdd72d9da26e725bbd620bd39 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 6 Oct 2023 17:48:31 +0800 Subject: [PATCH] core: Implement numpy and scipy functions --- nac3core/src/toplevel/builtins.rs | 463 ++++++++++++++++++++++---- nac3standalone/demo/interpret_demo.py | 42 +++ nac3standalone/demo/src/math.py | 205 ++++++++++++ 3 files changed, 646 insertions(+), 64 deletions(-) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 6f3488de..13db0931 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -919,70 +919,22 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { )))), loc: None, })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "floor".into(), - simple_name: "floor".into(), - signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: float, default_value: None }], - ret: int32, - vars: Default::default(), - })), - var_id: Default::default(), - instance_to_symbol: Default::default(), - instance_to_stmt: Default::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, _, args, generator| { - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, ctx.primitives.float)?; - let floor_intrinsic = - ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| { - let float = ctx.ctx.f64_type(); - let fn_type = float.fn_type(&[float.into()], false); - ctx.module.add_function("llvm.floor.f64", fn_type, None) - }); - let val = ctx - .builder - .build_call(floor_intrinsic, &[arg.into()], "floor") - .try_as_basic_value() - .left() - .unwrap(); - Ok(val.into()) - }, - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "ceil".into(), - simple_name: "ceil".into(), - signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: float, default_value: None }], - ret: int32, - vars: Default::default(), - })), - var_id: Default::default(), - instance_to_symbol: Default::default(), - instance_to_stmt: Default::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, _, args, generator| { - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, ctx.primitives.float)?; - let ceil_intrinsic = - ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| { - let float = ctx.ctx.f64_type(); - let fn_type = float.fn_type(&[float.into()], false); - ctx.module.add_function("llvm.ceil.f64", fn_type, None) - }); - let val = ctx - .builder - .build_call(ceil_intrinsic, &[arg.into()], "ceil") - .try_as_basic_value() - .left() - .unwrap(); - Ok(val.into()) - }, - )))), - loc: None, - })), + create_fn_by_intrinsic( + primitives, + &var_map, + "floor", + float, + &[(float, "x")], + "llvm.floor.f64", + ), + create_fn_by_intrinsic( + primitives, + &var_map, + "ceil", + float, + &[(float, "x")], + "llvm.ceil.f64", + ), Arc::new(RwLock::new({ let list_var = primitives.1.get_fresh_var(Some("L".into()), None); let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 }); @@ -1268,6 +1220,353 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { Ok(Some(val.into())) }), ), + create_fn_by_intrinsic( + primitives, + &var_map, + "sin", + float, + &[(float, "x")], + "llvm.sin.f64", + ), + create_fn_by_intrinsic( + primitives, + &var_map, + "cos", + float, + &[(float, "x")], + "llvm.cos.f64", + ), + create_fn_by_intrinsic( + primitives, + &var_map, + "exp", + float, + &[(float, "x")], + "llvm.exp.f64", + ), + create_fn_by_intrinsic( + primitives, + &var_map, + "exp2", + float, + &[(float, "x")], + "llvm.exp2.f64", + ), + create_fn_by_intrinsic( + primitives, + &var_map, + "log", + float, + &[(float, "x")], + "llvm.log.f64", + ), + create_fn_by_intrinsic( + primitives, + &var_map, + "log10", + float, + &[(float, "x")], + "llvm.log10.f64", + ), + create_fn_by_intrinsic( + primitives, + &var_map, + "log2", + float, + &[(float, "x")], + "llvm.log2.f64", + ), + create_fn_by_intrinsic( + primitives, + &var_map, + "fabs", + float, + &[(float, "x")], + "llvm.fabs.f64", + ), + create_fn_by_intrinsic( + primitives, + &var_map, + "trunc", + float, + &[(float, "x")], + "llvm.trunc.f64", + ), + create_fn_by_intrinsic( + primitives, + &var_map, + "sqrt", + float, + &[(float, "x")], + "llvm.sqrt.f64", + ), + create_fn_by_codegen( + primitives, + &var_map, + "rint", + float, + &[(float, "x")], + Box::new(|ctx, _, fun, args, generator| { + let float = ctx.primitives.float; + let llvm_f64 = ctx.ctx.f64_type(); + + let x_ty = fun.0.args[0].ty; + let x_val = args[0].1.clone() + .to_basic_value_enum(ctx, generator, x_ty)?; + + assert!(ctx.unifier.unioned(x_ty, float)); + + let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + + ctx.module.add_function("llvm.round.f64", fn_type, None) + }); + + // rint(x) == round(x * 0.5) * 2.0 + + // %0 = fmul f64 %x, 0.5 + let x_half = ctx.builder + .build_float_mul(x_val.into_float_value(), llvm_f64.const_float(0.5), ""); + // %1 = call f64 @llvm.round.f64(f64 %0) + let round = ctx.builder + .build_call( + intrinsic_fn, + &vec![x_half.into()], + "", + ) + .try_as_basic_value() + .left() + .unwrap(); + // %2 = fmul f64 %1, 2.0 + let val = ctx.builder + .build_float_mul(round.into_float_value(), llvm_f64.const_float(2.0).into(), "rint"); + + Ok(Some(val.into())) + }), + ), + create_fn_by_extern( + primitives, + &var_map, + "tan", + float, + &[(float, "x")], + "tan", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "arcsin", + float, + &[(float, "x")], + "asin", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "arccos", + float, + &[(float, "x")], + "acos", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "arctan", + float, + &[(float, "x")], + "atan", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "sinh", + float, + &[(float, "x")], + "sinh", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "cosh", + float, + &[(float, "x")], + "cosh", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "tanh", + float, + &[(float, "x")], + "tanh", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "arcsinh", + float, + &[(float, "x")], + "asinh", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "arccosh", + float, + &[(float, "x")], + "acosh", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "arctanh", + float, + &[(float, "x")], + "atanh", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "expm1", + float, + &[(float, "x")], + "expm1", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "cbrt", + float, + &[(float, "x")], + "cbrt", + &["readnone", "willreturn"], + ), + create_fn_by_extern( + primitives, + &var_map, + "erf", + float, + &[(float, "z")], + "erf", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "erfc", + float, + &[(float, "x")], + "erfc", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "gamma", + float, + &[(float, "z")], + "tgamma", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "gammaln", + float, + &[(float, "x")], + "lgamma", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "j0", + float, + &[(float, "x")], + "j0", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "j1", + float, + &[(float, "x")], + "j1", + &[], + ), + // Not mapped: jv/yv, libm only supports integer orders. + create_fn_by_extern( + primitives, + &var_map, + "arctan2", + float, + &[(float, "x1"), (float, "x2")], + "atan2", + &[], + ), + create_fn_by_intrinsic( + primitives, + &var_map, + "copysign", + float, + &[(float, "x1"), (float, "x2")], + "llvm.copysign.f64", + ), + create_fn_by_intrinsic( + primitives, + &var_map, + "fmax", + float, + &[(float, "x1"), (float, "x2")], + "llvm.maxnum.f64", + ), + create_fn_by_intrinsic( + primitives, + &var_map, + "fmin", + float, + &[(float, "x1"), (float, "x2")], + "llvm.minnum.f64", + ), + create_fn_by_extern( + primitives, + &var_map, + "ldexp", + float, + &[(float, "x1"), (int32, "x2")], + "ldexp", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "hypot", + float, + &[(float, "x1"), (float, "x2")], + "hypot", + &[], + ), + create_fn_by_extern( + primitives, + &var_map, + "nextafter", + float, + &[(float, "x1"), (float, "x2")], + "nextafter", + &[], + ), Arc::new(RwLock::new(TopLevelDef::Function { name: "Some".into(), simple_name: "Some".into(), @@ -1314,6 +1613,42 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { "abs", "isnan", "isinf", + "sin", + "cos", + "exp", + "exp2", + "log", + "log10", + "log2", + "fabs", + "trunc", + "sqrt", + "rint", + "tan", + "arcsin", + "arccos", + "arctan", + "sinh", + "cosh", + "tanh", + "arcsinh", + "arccosh", + "arctanh", + "expm1", + "cbrt", + "erf", + "erfc", + "gamma", + "gammaln", + "j0", + "j1", + "arctan2", + "copysign", + "fmax", + "fmin", + "ldexp", + "hypot", + "nextafter", "Some", ], ) diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 9c62e588..03392065 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -5,6 +5,7 @@ import importlib.util import importlib.machinery import numpy as np import pathlib +import scipy from numpy import int32, int64, uint32, uint64 from typing import TypeVar, Generic @@ -97,8 +98,49 @@ def patch(module): module.Some = Some module.none = none + # NumPy Math functions module.isnan = np.isnan module.isinf = np.isinf + module.sin = np.sin + module.cos = np.cos + module.exp = np.exp + module.exp2 = np.exp2 + module.log = np.log + module.log10 = np.log10 + module.log2 = np.log2 + module.fabs = np.fabs + module.floor = np.floor + module.ceil = np.ceil + module.trunc = np.trunc + module.sqrt = np.sqrt + module.rint = np.rint + module.tan = np.tan + module.arcsin = np.arcsin + module.arccos = np.arccos + module.arctan = np.arctan + module.sinh = np.sinh + module.cosh = np.cosh + module.tanh = np.tanh + module.arcsinh = np.arcsinh + module.arccosh = np.arccosh + module.arctanh = np.arctanh + module.expm1 = np.expm1 + module.cbrt = np.cbrt + module.arctan2 = np.arctan2 + module.copysign = np.copysign + module.fmax = np.fmax + module.fmin = np.fmin + module.ldexp = np.ldexp + module.hypot = np.hypot + module.nextafter = np.nextafter + + # SciPy Math Functions + module.erf = scipy.special.erf + module.erfc = scipy.special.erfc + module.gamma = scipy.special.gamma + module.gammaln = scipy.special.gammaln + module.j0 = scipy.special.j0 + module.j1 = scipy.special.j1 def file_import(filename, prefix="file_import_"): diff --git a/nac3standalone/demo/src/math.py b/nac3standalone/demo/src/math.py index 3d7b6807..a55e3de9 100644 --- a/nac3standalone/demo/src/math.py +++ b/nac3standalone/demo/src/math.py @@ -2,6 +2,10 @@ def output_bool(x: bool): ... +@extern +def output_float64(x: float): + ... + @extern def dbl_nan() -> float: ... @@ -18,8 +22,209 @@ def test_isinf(): for x in [dbl_inf(), 0.0, dbl_nan()]: output_bool(isinf(x)) +def test_sin(): + pi = 3.1415926535897932384626433 + for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi]: + output_float64(sin(x)) + +def test_cos(): + pi = 3.1415926535897932384626433 + for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi]: + output_float64(cos(x)) + +def test_exp(): + for x in [0.0, 1.0]: + output_float64(exp(x)) + +def test_exp2(): + for x in [0.0, 1.0]: + output_float64(exp2(x)) + +def test_log(): + e = 2.71828182845904523536028747135266249775724709369995 + for x in [1.0, e]: + output_float64(log(x)) + +def test_log10(): + for x in [1.0, 10.0]: + output_float64(log10(x)) + +def test_log2(): + for x in [1.0, 2.0]: + output_float64(log2(x)) + +def test_fabs(): + for x in [-1.0, 0.0, 1.0]: + output_float64(fabs(x)) + +def test_floor(): + for x in [-1.5, -0.5, 0.5, 1.5]: + output_float64(floor(x)) + +def test_ceil(): + for x in [-1.5, -0.5, 0.5, 1.5]: + output_float64(ceil(x)) + +def test_trunc(): + for x in [-1.5, -0.5, 0.5, 1.5]: + output_float64(trunc(x)) + +def test_sqrt(): + for x in [1.0, 2.0, 4.0]: + output_float64(sqrt(x)) + +def test_rint(): + for x in [-1.5, -0.5, 0.5, 1.5]: + output_float64(rint(x)) + +def test_tan(): + pi = 3.1415926535897932384626433 + for x in [-pi, -pi / 2.0, -pi / 4.0, 0.0, pi / 4.0, pi / 2.0, pi]: + output_float64(tan(x)) + +def test_arcsin(): + for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: + output_float64(arcsin(x)) + +def test_arccos(): + for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: + output_float64(arccos(x)) + +def test_arctan(): + for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: + output_float64(arctan(x)) + +def test_sinh(): + for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: + output_float64(sinh(x)) + +def test_cosh(): + for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: + output_float64(cosh(x)) + +def test_tanh(): + for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: + output_float64(tanh(x)) + +def test_arcsinh(): + for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: + output_float64(arcsinh(x)) + +def test_arccosh(): + for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: + output_float64(arccosh(x)) + +def test_arctanh(): + for x in [-1.0, -0.5, 0.0, 0.5, 1.0]: + output_float64(arctanh(x)) + +def test_expm1(): + for x in [0.0, 1.0]: + output_float64(expm1(x)) + +def test_cbrt(): + for x in [1.0, 8.0, 27.0]: + output_float64(expm1(x)) + +def test_erf(): + for x in [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]: + output_float64(erf(x)) + +def test_erfc(): + for x in [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]: + output_float64(erfc(x)) + +def test_gamma(): + for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]: + output_float64(gamma(x)) + +def test_gammaln(): + for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]: + output_float64(gammaln(x)) + +def test_j0(): + for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]: + output_float64(j0(x)) + +def test_j1(): + for x in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]: + output_float64(j1(x)) + +def test_arctan2(): + for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]: + for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]: + output_float64(arctan2(x1, x2)) + +def test_copysign(): + for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]: + for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]: + output_float64(copysign(x1, x2)) + +def test_fmax(): + for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]: + for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]: + output_float64(fmax(x1, x2)) + +def test_fmin(): + for x1 in [-1.0, -0.5, 0.0, 0.5, 1.0]: + for x2 in [-1.0, -0.5, 0.0, 0.5, 1.0]: + output_float64(fmin(x1, x2)) + +def test_ldexp(): + for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]: + for x2 in [-2, -1, 0, 1, 2]: + output_float64(ldexp(x1, x2)) + +def test_hypot(): + for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]: + for x2 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]: + output_float64(hypot(x1, x2)) + +def test_nextafter(): + for x1 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]: + for x2 in [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]: + output_float64(nextafter(x1, x2)) + def run() -> int32: test_isnan() test_isinf() + test_sin() + test_cos() + test_exp() + test_exp2() + test_log() + test_log10() + test_log2() + test_fabs() + test_floor() + test_ceil() + test_trunc() + test_sqrt() + test_rint() + test_tan() + test_arcsin() + test_arccos() + test_arctan() + test_sinh() + test_cosh() + test_tanh() + test_arcsinh() + test_arccosh() + test_arctanh() + test_expm1() + test_cbrt() + test_erf() + test_erfc() + test_gamma() + test_gammaln() + test_j0() + test_j1() + test_arctan2() + test_copysign() + test_fmax() + test_fmin() + test_ldexp() + test_hypot() + test_nextafter() return 0 \ No newline at end of file