From 33e8e59cc743a174cedf671d64c0ebff3d8c3da6 Mon Sep 17 00:00:00 2001 From: whitequark Date: Wed, 22 Jun 2016 01:09:41 +0000 Subject: [PATCH] compiler: implement min()/max() as builtins. Fixes #239. --- artiq/compiler/builtins.py | 6 ++++ artiq/compiler/prelude.py | 2 ++ .../compiler/transforms/artiq_ir_generator.py | 14 +++++++++ artiq/compiler/transforms/inferencer.py | 31 +++++++++++++++++++ .../lit/inferencer/error_builtin_calls.py | 3 ++ artiq/test/lit/integration/minmax.py | 7 +++++ 6 files changed, 63 insertions(+) create mode 100644 artiq/test/lit/integration/minmax.py diff --git a/artiq/compiler/builtins.py b/artiq/compiler/builtins.py index c06f600d3..0d091e74f 100644 --- a/artiq/compiler/builtins.py +++ b/artiq/compiler/builtins.py @@ -138,6 +138,12 @@ def fn_len(): def fn_round(): return types.TBuiltinFunction("round") +def fn_min(): + return types.TBuiltinFunction("min") + +def fn_max(): + return types.TBuiltinFunction("max") + def fn_print(): return types.TBuiltinFunction("print") diff --git a/artiq/compiler/prelude.py b/artiq/compiler/prelude.py index 98dd02ff0..8809f7612 100644 --- a/artiq/compiler/prelude.py +++ b/artiq/compiler/prelude.py @@ -23,6 +23,8 @@ def globals(): # Built-in Python functions "len": builtins.fn_len(), "round": builtins.fn_round(), + "min": builtins.fn_min(), + "max": builtins.fn_max(), "print": builtins.fn_print(), # ARTIQ decorators diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index 35aa25616..5db5809e5 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -1667,6 +1667,20 @@ class ARTIQIRGenerator(algorithm.Visitor): return self.append(ir.Builtin("round", [arg], node.type)) else: assert False + elif types.is_builtin(typ, "min"): + if len(node.args) == 2 and len(node.keywords) == 0: + arg0, arg1 = map(self.visit, node.args) + cond = self.append(ir.Compare(ast.Lt(loc=None), arg0, arg1)) + return self.append(ir.Select(cond, arg0, arg1)) + else: + assert False + elif types.is_builtin(typ, "max"): + if len(node.args) == 2 and len(node.keywords) == 0: + arg0, arg1 = map(self.visit, node.args) + cond = self.append(ir.Compare(ast.Gt(loc=None), arg0, arg1)) + return self.append(ir.Select(cond, arg0, arg1)) + else: + assert False elif types.is_builtin(typ, "print"): self.polymorphic_print([self.visit(arg) for arg in node.args], separator=" ", suffix="\n") diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 329508812..59d215c84 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -789,6 +789,37 @@ class Inferencer(algorithm.Visitor): node.loc, None) else: diagnose(valid_forms()) + elif types.is_builtin(typ, "min") or types.is_builtin(typ, "max"): + fn = typ.name + + valid_forms = lambda: [ + valid_form("{}(x:int(width='a), y:int(width='a)) -> int(width='a)".format(fn)), + valid_form("{}(x:float, y:float) -> float".format(fn)) + ] + + if len(node.args) == 2 and len(node.keywords) == 0: + arg0, arg1 = node.args + + self._unify(arg0.type, arg1.type, + arg0.loc, arg1.loc) + + if builtins.is_int(arg0.type) or builtins.is_float(arg0.type): + self._unify(arg0.type, node.type, + arg0.loc, node.loc) + elif types.is_var(arg0.type): + pass # undetermined yet + else: + note = diagnostic.Diagnostic("note", + "this expression has type {type}", + {"type": types.TypePrinter().name(arg0.type)}, + arg0.loc) + diag = diagnostic.Diagnostic("error", + "the arguments of {fn}() must be of a numeric type", + {"fn": fn}, + node.func.loc, notes=[note]) + self.engine.process(diag) + else: + diagnose(valid_forms()) elif types.is_builtin(typ, "print"): valid_forms = lambda: [ valid_form("print(args...) -> None"), diff --git a/artiq/test/lit/inferencer/error_builtin_calls.py b/artiq/test/lit/inferencer/error_builtin_calls.py index aae74c5cd..172f01e8f 100644 --- a/artiq/test/lit/inferencer/error_builtin_calls.py +++ b/artiq/test/lit/inferencer/error_builtin_calls.py @@ -10,3 +10,6 @@ len(1) # CHECK-L: ${LINE:+1}: error: the argument of list() must be of an iterable type list(1) + +# CHECK-L: ${LINE:+1}: error: the arguments of min() must be of a numeric type +min([1], [1]) diff --git a/artiq/test/lit/integration/minmax.py b/artiq/test/lit/integration/minmax.py new file mode 100644 index 000000000..5f325be7b --- /dev/null +++ b/artiq/test/lit/integration/minmax.py @@ -0,0 +1,7 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s + +assert min(1, 2) == 1 +assert max(1, 2) == 2 +assert min(1.0, 2.0) == 1.0 +assert max(1.0, 2.0) == 2.0