Add support for Return.

This commit is contained in:
whitequark 2015-06-06 15:12:56 +03:00
parent d08598fa0f
commit 5f06c6af10
1 changed files with 61 additions and 27 deletions

View File

@ -130,6 +130,7 @@ class Inferencer(algorithm.Transformer):
def __init__(self, engine): def __init__(self, engine):
self.engine = engine self.engine = engine
self.env_stack = [{}] self.env_stack = [{}]
self.function = None # currently visited function
def _unify(self, typea, typeb, loca, locb, kind): def _unify(self, typea, typeb, loca, locb, kind):
try: try:
@ -137,21 +138,32 @@ class Inferencer(algorithm.Transformer):
except types.UnificationError as e: except types.UnificationError as e:
printer = types.TypePrinter() printer = types.TypePrinter()
if kind == "generic": if kind == "expects":
note1 = diagnostic.Diagnostic("note",
"expression of type {typea}",
{"typea": printer.name(typea)},
loca)
elif kind == "expects":
note1 = diagnostic.Diagnostic("note", note1 = diagnostic.Diagnostic("note",
"expression expecting an operand of type {typea}", "expression expecting an operand of type {typea}",
{"typea": printer.name(typea)}, {"typea": printer.name(typea)},
loca) loca)
elif kind == "return_type" or kind == "return_type_none":
note1 = diagnostic.Diagnostic("note",
"function with return type {typea}",
{"typea": printer.name(typea)},
loca)
else:
note1 = diagnostic.Diagnostic("note",
"expression of type {typea}",
{"typea": printer.name(typea)},
loca)
note2 = diagnostic.Diagnostic("note", if kind == "return_type_none":
"expression of type {typeb}", note2 = diagnostic.Diagnostic("note",
{"typeb": printer.name(typeb)}, "implied expression of type {typeb}",
locb) {"typeb": printer.name(typeb)},
locb)
else:
note2 = diagnostic.Diagnostic("note",
"expression of type {typeb}",
{"typeb": printer.name(typeb)},
locb)
if e.typea.find() == typea.find() and e.typeb.find() == typeb.find(): if e.typea.find() == typea.find() and e.typeb.find() == typeb.find():
diag = diagnostic.Diagnostic("fatal", diag = diagnostic.Diagnostic("fatal",
@ -166,23 +178,6 @@ class Inferencer(algorithm.Transformer):
loca, [locb], notes=[note1, note2]) loca, [locb], notes=[note1, note2])
self.engine.process(diag) self.engine.process(diag)
def visit_FunctionDef(self, node):
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
extractor.visit(node)
self.env_stack.append(extractor.typing_env)
node = asttyped.FunctionDefT(
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
name=node.name, args=self.visit(node.args), returns=self.visit(node.returns),
body=[self.visit(x) for x in node.body], decorator_list=node.decorator_list,
keyword_loc=node.keyword_loc, name_loc=node.name_loc,
arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs,
loc=node.loc)
self.generic_visit(node)
self.env_stack.pop()
return node
def _find_name(self, name, loc): def _find_name(self, name, loc):
for typing_env in reversed(self.env_stack): for typing_env in reversed(self.env_stack):
if name in typing_env: if name in typing_env:
@ -198,6 +193,39 @@ class Inferencer(algorithm.Transformer):
arg=node.arg, annotation=self.visit(node.annotation), arg=node.arg, annotation=self.visit(node.annotation),
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc) arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
def visit_FunctionDef(self, node):
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
extractor.visit(node)
self.env_stack.append(extractor.typing_env)
node = asttyped.FunctionDefT(
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
return_type=types.TVar(),
name=node.name, args=node.args, returns=node.returns,
body=node.body, decorator_list=node.decorator_list,
keyword_loc=node.keyword_loc, name_loc=node.name_loc,
arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs,
loc=node.loc)
old_function, self.function = self.function, node
self.generic_visit(node)
self.function = old_function
self.env_stack.pop()
return node
def visit_Return(self, node):
node = self.generic_visit(node)
if node.value is None:
self._unify(self.function.return_type, types.TNone(),
self.function.name_loc, node.value.loc, kind="return_type_none")
else:
self._unify(self.function.return_type, node.value.type,
self.function.name_loc, node.value.loc, kind="return_type")
def visit_Num(self, node): def visit_Num(self, node):
if isinstance(node.n, int): if isinstance(node.n, int):
typ = types.TInt() typ = types.TInt()
@ -307,6 +335,12 @@ class Printer(algorithm.Visitor):
def rewrite(self): def rewrite(self):
return self.rewriter.rewrite() return self.rewriter.rewrite()
def visit_FunctionDefT(self, node):
self.rewriter.insert_before(node.colon_loc,
"->{}".format(self.type_printer.name(node.return_type)))
super().generic_visit(node)
def generic_visit(self, node): def generic_visit(self, node):
if hasattr(node, "type"): if hasattr(node, "type"):
self.rewriter.insert_after(node.loc, self.rewriter.insert_after(node.loc,