diff --git a/artiq/py2llvm/typing.py b/artiq/py2llvm/typing.py index da793f6d3..e93af3225 100644 --- a/artiq/py2llvm/typing.py +++ b/artiq/py2llvm/typing.py @@ -130,6 +130,7 @@ class Inferencer(algorithm.Transformer): def __init__(self, engine): self.engine = engine self.env_stack = [{}] + self.function = None # currently visited function def _unify(self, typea, typeb, loca, locb, kind): try: @@ -137,21 +138,32 @@ class Inferencer(algorithm.Transformer): except types.UnificationError as e: printer = types.TypePrinter() - if kind == "generic": - note1 = diagnostic.Diagnostic("note", - "expression of type {typea}", - {"typea": printer.name(typea)}, - loca) - elif kind == "expects": + if kind == "expects": note1 = diagnostic.Diagnostic("note", "expression expecting an operand of type {typea}", {"typea": printer.name(typea)}, loca) + elif kind == "return_type" or kind == "return_type_none": + note1 = diagnostic.Diagnostic("note", + "function with return type {typea}", + {"typea": printer.name(typea)}, + loca) + else: + note1 = diagnostic.Diagnostic("note", + "expression of type {typea}", + {"typea": printer.name(typea)}, + loca) - note2 = diagnostic.Diagnostic("note", - "expression of type {typeb}", - {"typeb": printer.name(typeb)}, - locb) + if kind == "return_type_none": + note2 = diagnostic.Diagnostic("note", + "implied expression of type {typeb}", + {"typeb": printer.name(typeb)}, + locb) + else: + note2 = diagnostic.Diagnostic("note", + "expression of type {typeb}", + {"typeb": printer.name(typeb)}, + locb) if e.typea.find() == typea.find() and e.typeb.find() == typeb.find(): diag = diagnostic.Diagnostic("fatal", @@ -166,23 +178,6 @@ class Inferencer(algorithm.Transformer): loca, [locb], notes=[note1, note2]) self.engine.process(diag) - def visit_FunctionDef(self, node): - extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) - extractor.visit(node) - - self.env_stack.append(extractor.typing_env) - node = asttyped.FunctionDefT( - typing_env=extractor.typing_env, globals_in_scope=extractor.global_, - name=node.name, args=self.visit(node.args), returns=self.visit(node.returns), - body=[self.visit(x) for x in node.body], decorator_list=node.decorator_list, - keyword_loc=node.keyword_loc, name_loc=node.name_loc, - arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs, - loc=node.loc) - self.generic_visit(node) - self.env_stack.pop() - - return node - def _find_name(self, name, loc): for typing_env in reversed(self.env_stack): if name in typing_env: @@ -198,6 +193,39 @@ class Inferencer(algorithm.Transformer): arg=node.arg, annotation=self.visit(node.annotation), arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc) + def visit_FunctionDef(self, node): + extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) + extractor.visit(node) + + self.env_stack.append(extractor.typing_env) + + node = asttyped.FunctionDefT( + typing_env=extractor.typing_env, globals_in_scope=extractor.global_, + return_type=types.TVar(), + + name=node.name, args=node.args, returns=node.returns, + body=node.body, decorator_list=node.decorator_list, + keyword_loc=node.keyword_loc, name_loc=node.name_loc, + arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs, + loc=node.loc) + + old_function, self.function = self.function, node + self.generic_visit(node) + self.function = old_function + + self.env_stack.pop() + + return node + + def visit_Return(self, node): + node = self.generic_visit(node) + if node.value is None: + self._unify(self.function.return_type, types.TNone(), + self.function.name_loc, node.value.loc, kind="return_type_none") + else: + self._unify(self.function.return_type, node.value.type, + self.function.name_loc, node.value.loc, kind="return_type") + def visit_Num(self, node): if isinstance(node.n, int): typ = types.TInt() @@ -307,6 +335,12 @@ class Printer(algorithm.Visitor): def rewrite(self): return self.rewriter.rewrite() + def visit_FunctionDefT(self, node): + self.rewriter.insert_before(node.colon_loc, + "->{}".format(self.type_printer.name(node.return_type))) + + super().generic_visit(node) + def generic_visit(self, node): if hasattr(node, "type"): self.rewriter.insert_after(node.loc,