forked from M-Labs/artiq
Add support for Return.
This commit is contained in:
parent
d08598fa0f
commit
5f06c6af10
|
@ -130,6 +130,7 @@ class Inferencer(algorithm.Transformer):
|
||||||
def __init__(self, engine):
|
def __init__(self, engine):
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
self.env_stack = [{}]
|
self.env_stack = [{}]
|
||||||
|
self.function = None # currently visited function
|
||||||
|
|
||||||
def _unify(self, typea, typeb, loca, locb, kind):
|
def _unify(self, typea, typeb, loca, locb, kind):
|
||||||
try:
|
try:
|
||||||
|
@ -137,21 +138,32 @@ class Inferencer(algorithm.Transformer):
|
||||||
except types.UnificationError as e:
|
except types.UnificationError as e:
|
||||||
printer = types.TypePrinter()
|
printer = types.TypePrinter()
|
||||||
|
|
||||||
if kind == "generic":
|
if kind == "expects":
|
||||||
note1 = diagnostic.Diagnostic("note",
|
|
||||||
"expression of type {typea}",
|
|
||||||
{"typea": printer.name(typea)},
|
|
||||||
loca)
|
|
||||||
elif kind == "expects":
|
|
||||||
note1 = diagnostic.Diagnostic("note",
|
note1 = diagnostic.Diagnostic("note",
|
||||||
"expression expecting an operand of type {typea}",
|
"expression expecting an operand of type {typea}",
|
||||||
{"typea": printer.name(typea)},
|
{"typea": printer.name(typea)},
|
||||||
loca)
|
loca)
|
||||||
|
elif kind == "return_type" or kind == "return_type_none":
|
||||||
|
note1 = diagnostic.Diagnostic("note",
|
||||||
|
"function with return type {typea}",
|
||||||
|
{"typea": printer.name(typea)},
|
||||||
|
loca)
|
||||||
|
else:
|
||||||
|
note1 = diagnostic.Diagnostic("note",
|
||||||
|
"expression of type {typea}",
|
||||||
|
{"typea": printer.name(typea)},
|
||||||
|
loca)
|
||||||
|
|
||||||
note2 = diagnostic.Diagnostic("note",
|
if kind == "return_type_none":
|
||||||
"expression of type {typeb}",
|
note2 = diagnostic.Diagnostic("note",
|
||||||
{"typeb": printer.name(typeb)},
|
"implied expression of type {typeb}",
|
||||||
locb)
|
{"typeb": printer.name(typeb)},
|
||||||
|
locb)
|
||||||
|
else:
|
||||||
|
note2 = diagnostic.Diagnostic("note",
|
||||||
|
"expression of type {typeb}",
|
||||||
|
{"typeb": printer.name(typeb)},
|
||||||
|
locb)
|
||||||
|
|
||||||
if e.typea.find() == typea.find() and e.typeb.find() == typeb.find():
|
if e.typea.find() == typea.find() and e.typeb.find() == typeb.find():
|
||||||
diag = diagnostic.Diagnostic("fatal",
|
diag = diagnostic.Diagnostic("fatal",
|
||||||
|
@ -166,23 +178,6 @@ class Inferencer(algorithm.Transformer):
|
||||||
loca, [locb], notes=[note1, note2])
|
loca, [locb], notes=[note1, note2])
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
|
|
||||||
def visit_FunctionDef(self, node):
|
|
||||||
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
|
||||||
extractor.visit(node)
|
|
||||||
|
|
||||||
self.env_stack.append(extractor.typing_env)
|
|
||||||
node = asttyped.FunctionDefT(
|
|
||||||
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
|
|
||||||
name=node.name, args=self.visit(node.args), returns=self.visit(node.returns),
|
|
||||||
body=[self.visit(x) for x in node.body], decorator_list=node.decorator_list,
|
|
||||||
keyword_loc=node.keyword_loc, name_loc=node.name_loc,
|
|
||||||
arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs,
|
|
||||||
loc=node.loc)
|
|
||||||
self.generic_visit(node)
|
|
||||||
self.env_stack.pop()
|
|
||||||
|
|
||||||
return node
|
|
||||||
|
|
||||||
def _find_name(self, name, loc):
|
def _find_name(self, name, loc):
|
||||||
for typing_env in reversed(self.env_stack):
|
for typing_env in reversed(self.env_stack):
|
||||||
if name in typing_env:
|
if name in typing_env:
|
||||||
|
@ -198,6 +193,39 @@ class Inferencer(algorithm.Transformer):
|
||||||
arg=node.arg, annotation=self.visit(node.annotation),
|
arg=node.arg, annotation=self.visit(node.annotation),
|
||||||
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
|
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
|
||||||
|
|
||||||
|
def visit_FunctionDef(self, node):
|
||||||
|
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
||||||
|
extractor.visit(node)
|
||||||
|
|
||||||
|
self.env_stack.append(extractor.typing_env)
|
||||||
|
|
||||||
|
node = asttyped.FunctionDefT(
|
||||||
|
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
|
||||||
|
return_type=types.TVar(),
|
||||||
|
|
||||||
|
name=node.name, args=node.args, returns=node.returns,
|
||||||
|
body=node.body, decorator_list=node.decorator_list,
|
||||||
|
keyword_loc=node.keyword_loc, name_loc=node.name_loc,
|
||||||
|
arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs,
|
||||||
|
loc=node.loc)
|
||||||
|
|
||||||
|
old_function, self.function = self.function, node
|
||||||
|
self.generic_visit(node)
|
||||||
|
self.function = old_function
|
||||||
|
|
||||||
|
self.env_stack.pop()
|
||||||
|
|
||||||
|
return node
|
||||||
|
|
||||||
|
def visit_Return(self, node):
|
||||||
|
node = self.generic_visit(node)
|
||||||
|
if node.value is None:
|
||||||
|
self._unify(self.function.return_type, types.TNone(),
|
||||||
|
self.function.name_loc, node.value.loc, kind="return_type_none")
|
||||||
|
else:
|
||||||
|
self._unify(self.function.return_type, node.value.type,
|
||||||
|
self.function.name_loc, node.value.loc, kind="return_type")
|
||||||
|
|
||||||
def visit_Num(self, node):
|
def visit_Num(self, node):
|
||||||
if isinstance(node.n, int):
|
if isinstance(node.n, int):
|
||||||
typ = types.TInt()
|
typ = types.TInt()
|
||||||
|
@ -307,6 +335,12 @@ class Printer(algorithm.Visitor):
|
||||||
def rewrite(self):
|
def rewrite(self):
|
||||||
return self.rewriter.rewrite()
|
return self.rewriter.rewrite()
|
||||||
|
|
||||||
|
def visit_FunctionDefT(self, node):
|
||||||
|
self.rewriter.insert_before(node.colon_loc,
|
||||||
|
"->{}".format(self.type_printer.name(node.return_type)))
|
||||||
|
|
||||||
|
super().generic_visit(node)
|
||||||
|
|
||||||
def generic_visit(self, node):
|
def generic_visit(self, node):
|
||||||
if hasattr(node, "type"):
|
if hasattr(node, "type"):
|
||||||
self.rewriter.insert_after(node.loc,
|
self.rewriter.insert_after(node.loc,
|
||||||
|
|
Loading…
Reference in New Issue