diff --git a/artiq/py2llvm/builtins.py b/artiq/py2llvm/builtins.py index 569b85cb2..48621ce10 100644 --- a/artiq/py2llvm/builtins.py +++ b/artiq/py2llvm/builtins.py @@ -44,7 +44,7 @@ def is_int(typ, width=None): def get_int_width(typ): if is_int(typ): - return types.get_value(typ["width"]) + return types.get_value(typ.find()["width"]) def is_float(typ): return types.is_mono(typ, "float") diff --git a/artiq/py2llvm/typing.py b/artiq/py2llvm/typing.py index 5e2c498fc..5d2b140a9 100644 --- a/artiq/py2llvm/typing.py +++ b/artiq/py2llvm/typing.py @@ -273,6 +273,22 @@ class ASTTypedRewriter(algorithm.Transformer): if_loc=node.if_loc, else_loc=node.else_loc, loc=node.loc) return self.visit(node) + def visit_ListComp(self, node): + extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) + extractor.visit(node) + + node = asttyped.ListCompT( + typing_env=extractor.typing_env, globals_in_scope=extractor.global_, + type=types.TVar(), + elt=node.elt, generators=node.generators, + begin_loc=node.begin_loc, end_loc=node.end_loc, loc=node.loc) + + try: + self.env_stack.append(node.typing_env) + return self.generic_visit(node) + finally: + self.env_stack.pop() + def visit_Raise(self, node): node = self.generic_visit(node) if node.cause: @@ -297,7 +313,6 @@ class ASTTypedRewriter(algorithm.Transformer): visit_Ellipsis = visit_unsupported visit_GeneratorExp = visit_unsupported visit_Lambda = visit_unsupported - visit_ListComp = visit_unsupported visit_Set = visit_unsupported visit_SetComp = visit_unsupported visit_Str = visit_unsupported @@ -646,6 +661,14 @@ class Inferencer(algorithm.Visitor): [self._coerce_one(typ, operand, other_node) for operand in operands] node.type.unify(builtins.TBool()) + def visit_ListCompT(self, node): + self.generic_visit(node) + node.type.unify(builtins.TList(node.elt.type)) # should never fail + + def visit_comprehension(self, node): + self.generic_visit(node) + self._unify_collection(element=node.target, collection=node.iter) + def visit_Assign(self, node): self.generic_visit(node) if len(node.targets) > 1: diff --git a/lit-test/py2llvm/typing/unify.py b/lit-test/py2llvm/typing/unify.py index 64dfb20b7..94207cc1f 100644 --- a/lit-test/py2llvm/typing/unify.py +++ b/lit-test/py2llvm/typing/unify.py @@ -47,3 +47,6 @@ True and False not 1 # CHECK-L: 1:int(width='i):bool + +[x for x in [1]] +# CHECK-L: [x:int(width='j) for x:int(width='j) in [1:int(width='j)]:list(elt=int(width='j))]:list(elt=int(width='j))