Unbreak return type inference.

This commit is contained in:
whitequark 2015-07-04 02:23:55 +03:00
parent 561d403ddd
commit 4358c5c453
1 changed files with 18 additions and 9 deletions

View File

@ -767,14 +767,6 @@ class Inferencer(algorithm.Visitor):
arg.loc, default.loc) arg.loc, default.loc)
def visit_FunctionDefT(self, node): def visit_FunctionDefT(self, node):
old_function, self.function = self.function, node
old_in_loop, self.in_loop = self.in_loop, False
old_has_return, self.has_return = self.has_return, False
self.generic_visit(node)
self.function = old_function
self.in_loop = old_in_loop
self.has_return = old_has_return
if any(node.decorator_list): if any(node.decorator_list):
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"decorators are not supported", {}, "decorators are not supported", {},
@ -782,13 +774,30 @@ class Inferencer(algorithm.Visitor):
self.engine.process(diag) self.engine.process(diag)
return return
old_function, self.function = self.function, node
old_in_loop, self.in_loop = self.in_loop, False
old_has_return, self.has_return = self.has_return, False
self.generic_visit(node)
# Lack of return statements is not the only case where the return # Lack of return statements is not the only case where the return
# type cannot be inferred. The other one is infinite (possibly mutual) # type cannot be inferred. The other one is infinite (possibly mutual)
# recursion. Since Python functions don't have to return a value, # recursion. Since Python functions don't have to return a value,
# we ignore that one. # we ignore that one.
if not self.has_return: if not self.has_return:
def makenotes(printer, typea, typeb, loca, locb):
return [
diagnostic.Diagnostic("note",
"function with return type {typea}",
{"typea": printer.name(typea)},
node.name_loc),
]
self._unify(node.return_type, builtins.TNone(), self._unify(node.return_type, builtins.TNone(),
node.name_loc, None) node.name_loc, None, makenotes)
self.function = old_function
self.in_loop = old_in_loop
self.has_return = old_has_return
signature_type = self._type_from_arguments(node.args, node.return_type) signature_type = self._type_from_arguments(node.args, node.return_type)
if signature_type: if signature_type: