From 4c95647162ee119a87ab7181a062848d5e88a4b8 Mon Sep 17 00:00:00 2001 From: whitequark Date: Sat, 13 Jun 2015 11:03:33 +0300 Subject: [PATCH] Split ASTTypedRewriter off Inferencer. --- artiq/py2llvm/typing.py | 209 +++++++++++++++++++--------------------- 1 file changed, 101 insertions(+), 108 deletions(-) diff --git a/artiq/py2llvm/typing.py b/artiq/py2llvm/typing.py index 70201a04f..2d15ae80c 100644 --- a/artiq/py2llvm/typing.py +++ b/artiq/py2llvm/typing.py @@ -126,61 +126,10 @@ class LocalExtractor(algorithm.Visitor): self.visit(stmt) -class Inferencer(algorithm.Transformer): +class ASTTypedRewriter(algorithm.Transformer): def __init__(self, engine): self.engine = engine self.env_stack = [] - self.function = None # currently visited function - - def _unify(self, typea, typeb, loca, locb, makenotes=None): - try: - typea.unify(typeb) - except types.UnificationError as e: - printer = types.TypePrinter() - - if makenotes: - notes = makenotes(printer, typea, typeb, loca, locb) - else: - notes = [ - diagnostic.Diagnostic("note", - "expression of type {typea}", - {"typea": printer.name(typea)}, - loca), - diagnostic.Diagnostic("note", - "expression of type {typeb}", - {"typeb": printer.name(typeb)}, - locb) - ] - - highlights = [locb] if locb else [] - if e.typea.find() == typea.find() and e.typeb.find() == typeb.find(): - diag = diagnostic.Diagnostic("error", - "cannot unify {typea} with {typeb}", - {"typea": printer.name(typea), "typeb": printer.name(typeb)}, - loca, highlights, notes) - else: # give more detail - diag = diagnostic.Diagnostic("error", - "cannot unify {typea} with {typeb}: {fraga} is incompatible with {fragb}", - {"typea": printer.name(typea), "typeb": printer.name(typeb), - "fraga": printer.name(e.typea), "fragb": printer.name(e.typeb)}, - loca, highlights, notes) - self.engine.process(diag) - - # makenotes for the case where types of multiple elements are unified - # with the type of parent expression - def _makenotes_elts(self, elts, kind): - def makenotes(printer, typea, typeb, loca, locb): - return [ - diagnostic.Diagnostic("note", - "{kind} of type {typea}", - {"kind": kind, "typea": printer.name(elts[0].type)}, - elts[0].loc), - diagnostic.Diagnostic("note", - "{kind} of type {typeb}", - {"kind": kind, "typeb": printer.name(typeb)}, - locb) - ] - return makenotes def _find_name(self, name, loc): for typing_env in reversed(self.env_stack): @@ -199,7 +148,12 @@ class Inferencer(algorithm.Transformer): node = asttyped.ModuleT( typing_env=extractor.typing_env, globals_in_scope=extractor.global_, body=node.body, loc=node.loc) - return self.visit(node) + + try: + self.env_stack.append(node.typing_env) + return self.generic_visit(node) + finally: + self.env_stack.pop() def visit_FunctionDef(self, node): extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) @@ -213,7 +167,12 @@ class Inferencer(algorithm.Transformer): keyword_loc=node.keyword_loc, name_loc=node.name_loc, arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs, loc=node.loc) - return self.visit(node) + + try: + self.env_stack.append(node.typing_env) + return self.generic_visit(node) + finally: + self.env_stack.pop() def visit_arg(self, node): return asttyped.argT(type=self._find_name(node.arg, node.loc), @@ -297,31 +256,107 @@ class Inferencer(algorithm.Transformer): loc=node.loc) return self.visit(node) - # Visitors that just unify types + # Unsupported visitors # + def visit_unsupported(self, node): + diag = diagnostic.Diagnostic("fatal", + "this syntax is not supported", {}, + node.loc) + self.engine.process(diag) + + # expr + visit_Attribute = visit_unsupported + visit_BinOp = visit_unsupported + visit_Call = visit_unsupported + visit_Compare = visit_unsupported + visit_Dict = visit_unsupported + visit_DictComp = visit_unsupported + visit_Ellipsis = visit_unsupported + visit_GeneratorExp = visit_unsupported + visit_Lambda = visit_unsupported + visit_ListComp = visit_unsupported + visit_Set = visit_unsupported + visit_SetComp = visit_unsupported + visit_Str = visit_unsupported + visit_Starred = visit_unsupported + visit_Yield = visit_unsupported + visit_YieldFrom = visit_unsupported + +class Inferencer(algorithm.Visitor): + def __init__(self, engine): + self.engine = engine + # currently visited function, for Return inference + self.function = None + + def _unify(self, typea, typeb, loca, locb, makenotes=None): + try: + typea.unify(typeb) + except types.UnificationError as e: + printer = types.TypePrinter() + + if makenotes: + notes = makenotes(printer, typea, typeb, loca, locb) + else: + notes = [ + diagnostic.Diagnostic("note", + "expression of type {typea}", + {"typea": printer.name(typea)}, + loca), + diagnostic.Diagnostic("note", + "expression of type {typeb}", + {"typeb": printer.name(typeb)}, + locb) + ] + + highlights = [locb] if locb else [] + if e.typea.find() == typea.find() and e.typeb.find() == typeb.find(): + diag = diagnostic.Diagnostic("error", + "cannot unify {typea} with {typeb}", + {"typea": printer.name(typea), "typeb": printer.name(typeb)}, + loca, highlights, notes) + else: # give more detail + diag = diagnostic.Diagnostic("error", + "cannot unify {typea} with {typeb}: {fraga} is incompatible with {fragb}", + {"typea": printer.name(typea), "typeb": printer.name(typeb), + "fraga": printer.name(e.typea), "fragb": printer.name(e.typeb)}, + loca, highlights, notes) + self.engine.process(diag) + + # makenotes for the case where types of multiple elements are unified + # with the type of parent expression + def _makenotes_elts(self, elts, kind): + def makenotes(printer, typea, typeb, loca, locb): + return [ + diagnostic.Diagnostic("note", + "{kind} of type {typea}", + {"kind": kind, "typea": printer.name(elts[0].type)}, + elts[0].loc), + diagnostic.Diagnostic("note", + "{kind} of type {typeb}", + {"kind": kind, "typeb": printer.name(typeb)}, + locb) + ] + return makenotes + def visit_ListT(self, node): for elt in node.elts: self._unify(node.type["elt"], elt.type, node.loc, elt.loc, self._makenotes_elts(node.elts, "a list element")) - return node def visit_SubscriptT(self, node): # TODO: support more than just lists self._unify(builtins.TList(node.type), node.value.type, node.loc, node.value.loc) - return node def visit_IfExpT(self, node): self._unify(node.body.type, node.orelse.type, node.body.loc, node.orelse.loc) node.type = node.body.type - return node def visit_BoolOpT(self, node): for value in node.values: self._unify(node.type, value.type, node.loc, value.loc, self._makenotes_elts(node.values, "an operand")) - return node def visit_UnaryOpT(self, node): if isinstance(node.op, ast.Not): @@ -336,53 +371,34 @@ class Inferencer(algorithm.Transformer): {"type": types.TypePrinter().name(operand_type)}, node.operand.loc) self.engine.process(diag) - return node - - def visit_ModuleT(self, node): - self.env_stack.append(node.typing_env) - - node = self.generic_visit(node) - - self.env_stack.pop() - - return node def visit_FunctionDefT(self, node): - self.env_stack.append(node.typing_env) old_function, self.function = self.function, node - node = self.generic_visit(node) - self.function = old_function - self.env_stack.pop() - - return node def visit_Assign(self, node): - node = self.generic_visit(node) + self.generic_visit(node) if len(node.targets) > 1: self._unify(builtins.TTuple([x.type for x in node.targets]), node.value.type, node.targets[0].loc.join(node.targets[-1].loc), node.value.loc) else: self._unify(node.targets[0].type, node.value.type, node.targets[0].loc, node.value.loc) - return node def visit_AugAssign(self, node): - node = self.generic_visit(node) + self.generic_visit(node) self._unify(node.target.type, node.value.type, node.target.loc, node.value.loc) - return node def visit_For(self, node): - node = self.generic_visit(node) + self.generic_visit(node) # TODO: support more than just lists self._unify(builtins.TList(node.target.type), node.iter.type, node.target.loc, node.iter.loc) - return node def visit_Return(self, node): - node = self.generic_visit(node) + self.generic_visit(node) def makenotes(printer, typea, typeb, loca, locb): return [ diagnostic.Diagnostic("note", @@ -401,31 +417,6 @@ class Inferencer(algorithm.Transformer): self._unify(self.function.return_type, node.value.type, self.function.name_loc, node.value.loc, makenotes) - # Unsupported visitors - # - def visit_unsupported(self, node): - diag = diagnostic.Diagnostic("fatal", - "this syntax is not supported", {}, - node.loc) - self.engine.process(diag) - - visit_Attribute = visit_unsupported - visit_BinOp = visit_unsupported - visit_Call = visit_unsupported - visit_Compare = visit_unsupported - visit_Dict = visit_unsupported - visit_DictComp = visit_unsupported - visit_Ellipsis = visit_unsupported - visit_GeneratorExp = visit_unsupported - visit_Lambda = visit_unsupported - visit_ListComp = visit_unsupported - visit_Set = visit_unsupported - visit_SetComp = visit_unsupported - visit_Str = visit_unsupported - visit_Starred = visit_unsupported - visit_Yield = visit_unsupported - visit_YieldFrom = visit_unsupported - class Printer(algorithm.Visitor): def __init__(self, buf): self.rewriter = source.Rewriter(buf) @@ -467,7 +458,9 @@ def main(): buf = source.Buffer("".join(fileinput.input()), os.path.basename(fileinput.filename())) parsed, comments = parse_buffer(buf, engine=engine) - typed = Inferencer(engine=engine).visit(parsed) + typed = ASTTypedRewriter(engine=engine).visit(parsed) + Inferencer(engine=engine).visit(typed) + printer = Printer(buf) printer.visit(typed) for comment in comments: