Make typing.Inferencer idempotent.

This commit is contained in:
whitequark 2015-06-13 09:28:40 +03:00
parent c89bf6fae0
commit 3e2d104014
3 changed files with 40 additions and 23 deletions

View File

@ -29,6 +29,8 @@ class ClassDefT(ast.ClassDef, scoped):
pass pass
class FunctionDefT(ast.FunctionDef, scoped): class FunctionDefT(ast.FunctionDef, scoped):
pass pass
class ModuleT(ast.Module, scoped):
pass
class AttributeT(ast.Attribute, commontyped): class AttributeT(ast.Attribute, commontyped):
pass pass

View File

@ -184,8 +184,11 @@ def is_var(typ):
return isinstance(typ, TVar) return isinstance(typ, TVar)
def is_mono(typ, name, **params): def is_mono(typ, name, **params):
params_match = True
for param in params:
params_match = params_match and typ.params[param] == params[param]
return isinstance(typ, TMono) and \ return isinstance(typ, TMono) and \
typ.name == name and typ.params == params typ.name == name and params_match
def is_numeric(typ): def is_numeric(typ):
return isinstance(typ, TMono) and \ return isinstance(typ, TMono) and \

View File

@ -190,43 +190,35 @@ class Inferencer(algorithm.Transformer):
"name '{name}' is not bound to anything", {"name":name}, loc) "name '{name}' is not bound to anything", {"name":name}, loc)
self.engine.process(diag) self.engine.process(diag)
def visit_root(self, node):
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
extractor.visit(node)
self.env_stack.append(extractor.typing_env)
return self.visit(node)
# Visitors that replace node with a typed node # Visitors that replace node with a typed node
# #
def visit_arg(self, node): def visit_Module(self, node):
return asttyped.argT(type=self._find_name(node.arg, node.loc), extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
arg=node.arg, annotation=self.visit(node.annotation), extractor.visit(node)
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
node = asttyped.ModuleT(
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
body=node.body, loc=node.loc)
return self.visit(node)
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
extractor.visit(node) extractor.visit(node)
self.env_stack.append(extractor.typing_env)
node = asttyped.FunctionDefT( node = asttyped.FunctionDefT(
typing_env=extractor.typing_env, globals_in_scope=extractor.global_, typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
return_type=types.TVar(), return_type=types.TVar(),
name=node.name, args=node.args, returns=node.returns, name=node.name, args=node.args, returns=node.returns,
body=node.body, decorator_list=node.decorator_list, body=node.body, decorator_list=node.decorator_list,
keyword_loc=node.keyword_loc, name_loc=node.name_loc, keyword_loc=node.keyword_loc, name_loc=node.name_loc,
arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs, arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs,
loc=node.loc) loc=node.loc)
return self.visit(node)
old_function, self.function = self.function, node def visit_arg(self, node):
self.generic_visit(node) return asttyped.argT(type=self._find_name(node.arg, node.loc),
self.function = old_function arg=node.arg, annotation=self.visit(node.annotation),
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
self.env_stack.pop()
return node
def visit_Num(self, node): def visit_Num(self, node):
if isinstance(node.n, int): if isinstance(node.n, int):
@ -346,6 +338,26 @@ class Inferencer(algorithm.Transformer):
self.engine.process(diag) self.engine.process(diag)
return node return node
def visit_ModuleT(self, node):
self.env_stack.append(node.typing_env)
node = self.generic_visit(node)
self.env_stack.pop()
return node
def visit_FunctionDefT(self, node):
self.env_stack.append(node.typing_env)
old_function, self.function = self.function, node
node = self.generic_visit(node)
self.function = old_function
self.env_stack.pop()
return node
def visit_Assign(self, node): def visit_Assign(self, node):
node = self.generic_visit(node) node = self.generic_visit(node)
if len(node.targets) > 1: if len(node.targets) > 1:
@ -455,7 +467,7 @@ def main():
buf = source.Buffer("".join(fileinput.input()), os.path.basename(fileinput.filename())) buf = source.Buffer("".join(fileinput.input()), os.path.basename(fileinput.filename()))
parsed, comments = parse_buffer(buf, engine=engine) parsed, comments = parse_buffer(buf, engine=engine)
typed = Inferencer(engine=engine).visit_root(parsed) typed = Inferencer(engine=engine).visit(parsed)
printer = Printer(buf) printer = Printer(buf)
printer.visit(typed) printer.visit(typed)
for comment in comments: for comment in comments: