From e07057c224f929a3f02258427ec9ad8d38be0095 Mon Sep 17 00:00:00 2001 From: whitequark Date: Fri, 26 Jun 2015 18:53:20 +0300 Subject: [PATCH] Add range types. --- artiq/py2llvm/builtins.py | 27 ++++++- artiq/py2llvm/prelude.py | 2 +- artiq/py2llvm/typing.py | 79 ++++++++++++++----- .../py2llvm/typing/error_builtin_calls.py | 12 +++ lit-test/py2llvm/typing/error_iterable.py | 5 ++ 5 files changed, 100 insertions(+), 25 deletions(-) create mode 100644 lit-test/py2llvm/typing/error_builtin_calls.py create mode 100644 lit-test/py2llvm/typing/error_iterable.py diff --git a/artiq/py2llvm/builtins.py b/artiq/py2llvm/builtins.py index 13ffe745d..8875f64d7 100644 --- a/artiq/py2llvm/builtins.py +++ b/artiq/py2llvm/builtins.py @@ -31,6 +31,12 @@ class TList(types.TMono): elt = types.TVar() super().__init__("list", {"elt": elt}) +class TRange(types.TMono): + def __init__(self, elt=None): + if elt is None: + elt = types.TVar() + super().__init__("range", {"elt": elt}) + def fn_bool(): return types.TBuiltin("class bool") @@ -43,15 +49,15 @@ def fn_float(): def fn_list(): return types.TBuiltin("class list") +def fn_range(): + return types.TBuiltin("function range") + def fn_len(): return types.TBuiltin("function len") def fn_round(): return types.TBuiltin("function round") -def fn_range(): - return types.TBuiltin("function range") - def fn_syscall(): return types.TBuiltin("function syscall") @@ -87,6 +93,21 @@ def is_list(typ, elt=None): else: return types.is_mono(typ, "list") +def is_range(typ, elt=None): + if elt: + return types.is_mono(typ, "range", {"elt": elt}) + else: + return types.is_mono(typ, "range") + +def is_iterable(typ): + typ = typ.find() + return isinstance(typ, types.TMono) and \ + typ.name in ('list', 'range') + +def get_iterable_elt(typ): + if is_iterable(typ): + return typ.find()["elt"] + def is_collection(typ): typ = typ.find() return isinstance(typ, types.TTuple) or \ diff --git a/artiq/py2llvm/prelude.py b/artiq/py2llvm/prelude.py index 3ecd63a69..97a06a082 100644 --- a/artiq/py2llvm/prelude.py +++ b/artiq/py2llvm/prelude.py @@ -11,8 +11,8 @@ def globals(): "int": builtins.fn_int(), "float": builtins.fn_float(), "list": builtins.fn_list(), + "range": builtins.fn_range(), "len": builtins.fn_len(), "round": builtins.fn_round(), - "range": builtins.fn_range(), "syscall": builtins.fn_syscall(), } diff --git a/artiq/py2llvm/typing.py b/artiq/py2llvm/typing.py index 926bd0377..27ba7a1ac 100644 --- a/artiq/py2llvm/typing.py +++ b/artiq/py2llvm/typing.py @@ -454,14 +454,22 @@ class Inferencer(algorithm.Visitor): node.attr_loc, [node.value.loc]) self.engine.process(diag) - def _unify_collection(self, element, collection): - # TODO: support more than just lists - self._unify(builtins.TList(element.type), collection.type, - element.loc, collection.loc) + def _unify_iterable(self, element, collection): + if builtins.is_iterable(collection.type): + rhs_type = collection.type.find() + rhs_wrapped_lhs_type = types.TMono(rhs_type.name, {"elt": element.type}) + self._unify(rhs_wrapped_lhs_type, rhs_type, + element.loc, collection.loc) + elif not types.is_var(collection.type): + diag = diagnostic.Diagnostic("error", + "type {type} is not iterable", + {"type": types.TypePrinter().name(collection.type)}, + collection.loc, []) + self.engine.process(diag) def visit_SubscriptT(self, node): self.generic_visit(node) - self._unify_collection(element=node, collection=node.value) + self._unify_iterable(element=node, collection=node.value) def visit_IfExpT(self, node): self.generic_visit(node) @@ -678,7 +686,7 @@ class Inferencer(algorithm.Visitor): left.loc, right.loc) elif all(map(lambda op: isinstance(op, (ast.In, ast.NotIn)), node.ops)): for left, right in pairs: - self._unify_collection(element=left, collection=right) + self._unify_iterable(element=left, collection=right) else: # Eq, NotEq, Lt, LtE, Gt, GtE operands = [node.left] + node.comparators operand_types = [operand.type for operand in operands] @@ -713,7 +721,7 @@ class Inferencer(algorithm.Visitor): def visit_comprehension(self, node): self.generic_visit(node) - self._unify_collection(element=node.target, collection=node.iter) + self._unify_iterable(element=node.target, collection=node.iter) def visit_builtin_call(self, node): typ = node.func.type.find() @@ -770,6 +778,8 @@ class Inferencer(algorithm.Visitor): diag = diagnostic.Diagnostic("error", "the width argument of int() must be an integer literal", {}, node.keywords[0].loc) + self.engine.process(diag) + return self._unify(node.type, builtins.TInt(types.TValue(width.n)), node.loc, None) @@ -805,9 +815,36 @@ class Inferencer(algorithm.Visitor): pass # [] else: diagnose(valid_forms()) + elif builtins.is_builtin(typ, "function range"): + valid_forms = lambda: [ + valid_form("range(max:'a) -> range(elt='a)"), + valid_form("range(min:'a, max:'a) -> range(elt='a)"), + valid_form("range(min:'a, max:'a, step:'a) -> range(elt='a)"), + ] + + range_tvar = types.TVar() + self._unify(node.type, builtins.TRange(range_tvar), + node.loc, None) + + if len(node.args) in (1, 2, 3) and len(node.keywords) == 0: + for arg in node.args: + self._unify(arg.type, range_tvar, + arg.loc, None) + + if not builtins.is_numeric(arg.type): + note = diagnostic.Diagnostic("note", + "this expression has type {type}", + {"type": types.TypePrinter().name(arg.type)}, + arg.loc) + diag = diagnostic.Diagnostic("error", + "an argument of range() must be of a numeric type", {}, + node.func.loc, notes=[note]) + self.engine.process(diag) + else: + diagnose(valid_forms()) elif builtins.is_builtin(typ, "function len"): valid_forms = lambda: [ - valid_form("len(x:list(elt='a)) -> int(width='b)"), + valid_form("len(x:'a) -> int(width='b) where 'a is iterable"), ] # TODO: should be ssize_t-sized @@ -817,8 +854,17 @@ class Inferencer(algorithm.Visitor): if len(node.args) == 1 and len(node.keywords) == 0: arg, = node.args - self._unify(arg.type, builtins.TList(), - arg.loc, None) + if builtins.is_list(arg.type) or builtins.is_range(arg.type): + pass + else: + note = diagnostic.Diagnostic("note", + "this expression has type {type}", + {"type": types.TypePrinter().name(arg.type)}, + arg.loc) + diag = diagnostic.Diagnostic("error", + "the argument of len() must be of an iterable type", {}, + node.func.loc, notes=[note]) + self.engine.process(diag) else: diagnose(valid_forms()) elif builtins.is_builtin(typ, "function round"): @@ -836,13 +882,6 @@ class Inferencer(algorithm.Visitor): arg.loc, None) else: diagnose(valid_forms()) - # TODO: add when there are range types - # elif builtins.is_builtin(typ, "function range"): - # valid_forms = lambda: [ - # valid_form("range(max:'a) -> range(elt='a)"), - # valid_form("range(min:'a, max:'a) -> range(elt='a)"), - # valid_form("range(min:'a, max:'a, step:'a) -> range(elt='a)"), - # ] # TODO: add when it is clear what interface syscall() has # elif builtins.is_builtin(typ, "function syscall"): # valid_Forms = lambda: [ @@ -862,7 +901,7 @@ class Inferencer(algorithm.Visitor): if types.is_var(node.func.type): return # not enough info yet - elif types.is_mono(node.func.type) or types.is_builtin(node.func.type): + elif types.is_builtin(node.func.type): return self.visit_builtin_call(node) elif not types.is_function(node.func.type): diag = diagnostic.Diagnostic("error", @@ -988,9 +1027,7 @@ class Inferencer(algorithm.Visitor): old_in_loop, self.in_loop = self.in_loop, True self.generic_visit(node) self.in_loop = old_in_loop - # TODO: support more than just lists - self._unify(builtins.TList(node.target.type), node.iter.type, - node.target.loc, node.iter.loc) + self._unify_iterable(node.target, node.iter) def visit_While(self, node): old_in_loop, self.in_loop = self.in_loop, True diff --git a/lit-test/py2llvm/typing/error_builtin_calls.py b/lit-test/py2llvm/typing/error_builtin_calls.py new file mode 100644 index 000000000..86cab873d --- /dev/null +++ b/lit-test/py2llvm/typing/error_builtin_calls.py @@ -0,0 +1,12 @@ +# RUN: %python -m artiq.py2llvm.typing +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +a = 1 +# CHECK-L: ${LINE:+1}: error: the width argument of int() must be an integer literal +int(1.0, width=a) + +# CHECK-L: ${LINE:+1}: error: the argument of len() must be of an iterable type +len(1) + +# CHECK-L: ${LINE:+1}: error: an argument of range() must be of a numeric type +range([]) diff --git a/lit-test/py2llvm/typing/error_iterable.py b/lit-test/py2llvm/typing/error_iterable.py new file mode 100644 index 000000000..63382614f --- /dev/null +++ b/lit-test/py2llvm/typing/error_iterable.py @@ -0,0 +1,5 @@ +# RUN: %python -m artiq.py2llvm.typing +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: ${LINE:+1}: error: type int(width='a) is not iterable +for x in 1: pass