From 3e2d104014a9276bcc5611ba70d44dd4db655c69 Mon Sep 17 00:00:00 2001 From: whitequark Date: Sat, 13 Jun 2015 09:28:40 +0300 Subject: [PATCH] Make typing.Inferencer idempotent. --- artiq/py2llvm/asttyped.py | 2 ++ artiq/py2llvm/types.py | 5 +++- artiq/py2llvm/typing.py | 56 ++++++++++++++++++++++++--------------- 3 files changed, 40 insertions(+), 23 deletions(-) diff --git a/artiq/py2llvm/asttyped.py b/artiq/py2llvm/asttyped.py index 848d2625a..415f7dcc0 100644 --- a/artiq/py2llvm/asttyped.py +++ b/artiq/py2llvm/asttyped.py @@ -29,6 +29,8 @@ class ClassDefT(ast.ClassDef, scoped): pass class FunctionDefT(ast.FunctionDef, scoped): pass +class ModuleT(ast.Module, scoped): + pass class AttributeT(ast.Attribute, commontyped): pass diff --git a/artiq/py2llvm/types.py b/artiq/py2llvm/types.py index 60768a349..f56aa34a8 100644 --- a/artiq/py2llvm/types.py +++ b/artiq/py2llvm/types.py @@ -184,8 +184,11 @@ def is_var(typ): return isinstance(typ, TVar) def is_mono(typ, name, **params): + params_match = True + for param in params: + params_match = params_match and typ.params[param] == params[param] return isinstance(typ, TMono) and \ - typ.name == name and typ.params == params + typ.name == name and params_match def is_numeric(typ): return isinstance(typ, TMono) and \ diff --git a/artiq/py2llvm/typing.py b/artiq/py2llvm/typing.py index 943b5b6fe..aa11f0e53 100644 --- a/artiq/py2llvm/typing.py +++ b/artiq/py2llvm/typing.py @@ -190,43 +190,35 @@ class Inferencer(algorithm.Transformer): "name '{name}' is not bound to anything", {"name":name}, loc) self.engine.process(diag) - def visit_root(self, node): - extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) - extractor.visit(node) - self.env_stack.append(extractor.typing_env) - - return self.visit(node) - # Visitors that replace node with a typed node # - def visit_arg(self, node): - return asttyped.argT(type=self._find_name(node.arg, node.loc), - arg=node.arg, annotation=self.visit(node.annotation), - arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc) + def visit_Module(self, node): + extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) + extractor.visit(node) + + node = asttyped.ModuleT( + typing_env=extractor.typing_env, globals_in_scope=extractor.global_, + body=node.body, loc=node.loc) + return self.visit(node) def visit_FunctionDef(self, node): extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) extractor.visit(node) - self.env_stack.append(extractor.typing_env) - node = asttyped.FunctionDefT( typing_env=extractor.typing_env, globals_in_scope=extractor.global_, return_type=types.TVar(), - name=node.name, args=node.args, returns=node.returns, body=node.body, decorator_list=node.decorator_list, keyword_loc=node.keyword_loc, name_loc=node.name_loc, arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs, loc=node.loc) + return self.visit(node) - old_function, self.function = self.function, node - self.generic_visit(node) - self.function = old_function - - self.env_stack.pop() - - return node + def visit_arg(self, node): + return asttyped.argT(type=self._find_name(node.arg, node.loc), + arg=node.arg, annotation=self.visit(node.annotation), + arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc) def visit_Num(self, node): if isinstance(node.n, int): @@ -346,6 +338,26 @@ class Inferencer(algorithm.Transformer): self.engine.process(diag) return node + def visit_ModuleT(self, node): + self.env_stack.append(node.typing_env) + + node = self.generic_visit(node) + + self.env_stack.pop() + + return node + + def visit_FunctionDefT(self, node): + self.env_stack.append(node.typing_env) + old_function, self.function = self.function, node + + node = self.generic_visit(node) + + self.function = old_function + self.env_stack.pop() + + return node + def visit_Assign(self, node): node = self.generic_visit(node) if len(node.targets) > 1: @@ -455,7 +467,7 @@ def main(): buf = source.Buffer("".join(fileinput.input()), os.path.basename(fileinput.filename())) parsed, comments = parse_buffer(buf, engine=engine) - typed = Inferencer(engine=engine).visit_root(parsed) + typed = Inferencer(engine=engine).visit(parsed) printer = Printer(buf) printer.visit(typed) for comment in comments: