From cdaf554736302594613bda6850fda25192f61e0c Mon Sep 17 00:00:00 2001 From: David Nadlinger Date: Sat, 13 Apr 2019 00:43:45 +0100 Subject: [PATCH] compiler: Implement abs() for scalars GitHub: Fixes #1303. --- artiq/compiler/builtins.py | 3 +++ artiq/compiler/prelude.py | 1 + .../compiler/transforms/artiq_ir_generator.py | 10 +++++++++ artiq/compiler/transforms/inferencer.py | 22 +++++++++++++++++++ artiq/test/lit/inferencer/builtin_calls.py | 6 +++++ .../lit/inferencer/error_builtin_calls.py | 3 +++ artiq/test/lit/integration/abs.py | 7 ++++++ 7 files changed, 52 insertions(+) create mode 100644 artiq/test/lit/integration/abs.py diff --git a/artiq/compiler/builtins.py b/artiq/compiler/builtins.py index d60db0840..ce2e23e27 100644 --- a/artiq/compiler/builtins.py +++ b/artiq/compiler/builtins.py @@ -181,6 +181,9 @@ def fn_len(): def fn_round(): return types.TBuiltinFunction("round") +def fn_abs(): + return types.TBuiltinFunction("abs") + def fn_min(): return types.TBuiltinFunction("min") diff --git a/artiq/compiler/prelude.py b/artiq/compiler/prelude.py index 24a7bd1fa..ee47168ac 100644 --- a/artiq/compiler/prelude.py +++ b/artiq/compiler/prelude.py @@ -29,6 +29,7 @@ def globals(): # Built-in Python functions "len": builtins.fn_len(), "round": builtins.fn_round(), + "abs": builtins.fn_abs(), "min": builtins.fn_min(), "max": builtins.fn_max(), "print": builtins.fn_print(), diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index 1d45d690c..4e3c3394b 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -1700,6 +1700,16 @@ class ARTIQIRGenerator(algorithm.Visitor): return self.append(ir.Builtin("round", [arg], node.type)) else: assert False + elif types.is_builtin(typ, "abs"): + if len(node.args) == 1 and len(node.keywords) == 0: + arg = self.visit(node.args[0]) + neg = self.append( + ir.Arith(ast.Sub(loc=None), ir.Constant(0, arg.type), arg)) + cond = self.append( + ir.Compare(ast.Lt(loc=None), arg, ir.Constant(0, arg.type))) + return self.append(ir.Select(cond, neg, arg)) + else: + assert False elif types.is_builtin(typ, "min"): if len(node.args) == 2 and len(node.keywords) == 0: arg0, arg1 = map(self.visit, node.args) diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 9142658b3..2f5d1800c 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -811,6 +811,28 @@ class Inferencer(algorithm.Visitor): arg.loc, None) else: diagnose(valid_forms()) + elif types.is_builtin(typ, "abs"): + fn = typ.name + + valid_forms = lambda: [ + valid_form("abs(x:numpy.int?) -> numpy.int?"), + valid_form("abs(x:float) -> float") + ] + + if len(node.args) == 1 and len(node.keywords) == 0: + (arg,) = node.args + if builtins.is_int(arg.type) or builtins.is_float(arg.type): + self._unify(arg.type, node.type, + arg.loc, node.loc) + elif types.is_var(arg.type): + pass # undetermined yet + else: + diag = diagnostic.Diagnostic("error", + "the arguments of abs() must be of a numeric type", {}, + node.func.loc) + self.engine.process(diag) + else: + diagnose(valid_forms()) elif types.is_builtin(typ, "min") or types.is_builtin(typ, "max"): fn = typ.name diff --git a/artiq/test/lit/inferencer/builtin_calls.py b/artiq/test/lit/inferencer/builtin_calls.py index be1b797d9..a4b2f81fe 100644 --- a/artiq/test/lit/inferencer/builtin_calls.py +++ b/artiq/test/lit/inferencer/builtin_calls.py @@ -30,3 +30,9 @@ len([]) # CHECK-L: round:(1.0:float):numpy.int? round(1.0) + +# CHECK-L: abs:(1:numpy.int?):numpy.int? +abs(1) + +# CHECK-L: abs:(1.0:float):float +abs(1.0) diff --git a/artiq/test/lit/inferencer/error_builtin_calls.py b/artiq/test/lit/inferencer/error_builtin_calls.py index bb1b8ca04..643011f2d 100644 --- a/artiq/test/lit/inferencer/error_builtin_calls.py +++ b/artiq/test/lit/inferencer/error_builtin_calls.py @@ -10,5 +10,8 @@ list(1) # CHECK-L: ${LINE:+1}: error: the arguments of min() must be of a numeric type min([1], [1]) +# CHECK-L: ${LINE:+1}: error: the arguments of abs() must be of a numeric type +abs([1.0]) + # CHECK-L: ${LINE:+1}: error: strings currently cannot be constructed str(1) diff --git a/artiq/test/lit/integration/abs.py b/artiq/test/lit/integration/abs.py new file mode 100644 index 000000000..877cf6c9a --- /dev/null +++ b/artiq/test/lit/integration/abs.py @@ -0,0 +1,7 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s + +assert abs(1234) == 1234 +assert abs(-1234) == 1234 +assert abs(1234.0) == 1234.0 +assert abs(-1234.0) == 1234