forked from M-Labs/artiq
Split ASTTypedRewriter off Inferencer.
This commit is contained in:
parent
61434a8da3
commit
4c95647162
|
@ -126,61 +126,10 @@ class LocalExtractor(algorithm.Visitor):
|
||||||
self.visit(stmt)
|
self.visit(stmt)
|
||||||
|
|
||||||
|
|
||||||
class Inferencer(algorithm.Transformer):
|
class ASTTypedRewriter(algorithm.Transformer):
|
||||||
def __init__(self, engine):
|
def __init__(self, engine):
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
self.env_stack = []
|
self.env_stack = []
|
||||||
self.function = None # currently visited function
|
|
||||||
|
|
||||||
def _unify(self, typea, typeb, loca, locb, makenotes=None):
|
|
||||||
try:
|
|
||||||
typea.unify(typeb)
|
|
||||||
except types.UnificationError as e:
|
|
||||||
printer = types.TypePrinter()
|
|
||||||
|
|
||||||
if makenotes:
|
|
||||||
notes = makenotes(printer, typea, typeb, loca, locb)
|
|
||||||
else:
|
|
||||||
notes = [
|
|
||||||
diagnostic.Diagnostic("note",
|
|
||||||
"expression of type {typea}",
|
|
||||||
{"typea": printer.name(typea)},
|
|
||||||
loca),
|
|
||||||
diagnostic.Diagnostic("note",
|
|
||||||
"expression of type {typeb}",
|
|
||||||
{"typeb": printer.name(typeb)},
|
|
||||||
locb)
|
|
||||||
]
|
|
||||||
|
|
||||||
highlights = [locb] if locb else []
|
|
||||||
if e.typea.find() == typea.find() and e.typeb.find() == typeb.find():
|
|
||||||
diag = diagnostic.Diagnostic("error",
|
|
||||||
"cannot unify {typea} with {typeb}",
|
|
||||||
{"typea": printer.name(typea), "typeb": printer.name(typeb)},
|
|
||||||
loca, highlights, notes)
|
|
||||||
else: # give more detail
|
|
||||||
diag = diagnostic.Diagnostic("error",
|
|
||||||
"cannot unify {typea} with {typeb}: {fraga} is incompatible with {fragb}",
|
|
||||||
{"typea": printer.name(typea), "typeb": printer.name(typeb),
|
|
||||||
"fraga": printer.name(e.typea), "fragb": printer.name(e.typeb)},
|
|
||||||
loca, highlights, notes)
|
|
||||||
self.engine.process(diag)
|
|
||||||
|
|
||||||
# makenotes for the case where types of multiple elements are unified
|
|
||||||
# with the type of parent expression
|
|
||||||
def _makenotes_elts(self, elts, kind):
|
|
||||||
def makenotes(printer, typea, typeb, loca, locb):
|
|
||||||
return [
|
|
||||||
diagnostic.Diagnostic("note",
|
|
||||||
"{kind} of type {typea}",
|
|
||||||
{"kind": kind, "typea": printer.name(elts[0].type)},
|
|
||||||
elts[0].loc),
|
|
||||||
diagnostic.Diagnostic("note",
|
|
||||||
"{kind} of type {typeb}",
|
|
||||||
{"kind": kind, "typeb": printer.name(typeb)},
|
|
||||||
locb)
|
|
||||||
]
|
|
||||||
return makenotes
|
|
||||||
|
|
||||||
def _find_name(self, name, loc):
|
def _find_name(self, name, loc):
|
||||||
for typing_env in reversed(self.env_stack):
|
for typing_env in reversed(self.env_stack):
|
||||||
|
@ -199,7 +148,12 @@ class Inferencer(algorithm.Transformer):
|
||||||
node = asttyped.ModuleT(
|
node = asttyped.ModuleT(
|
||||||
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
|
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
|
||||||
body=node.body, loc=node.loc)
|
body=node.body, loc=node.loc)
|
||||||
return self.visit(node)
|
|
||||||
|
try:
|
||||||
|
self.env_stack.append(node.typing_env)
|
||||||
|
return self.generic_visit(node)
|
||||||
|
finally:
|
||||||
|
self.env_stack.pop()
|
||||||
|
|
||||||
def visit_FunctionDef(self, node):
|
def visit_FunctionDef(self, node):
|
||||||
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
|
||||||
|
@ -213,7 +167,12 @@ class Inferencer(algorithm.Transformer):
|
||||||
keyword_loc=node.keyword_loc, name_loc=node.name_loc,
|
keyword_loc=node.keyword_loc, name_loc=node.name_loc,
|
||||||
arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs,
|
arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs,
|
||||||
loc=node.loc)
|
loc=node.loc)
|
||||||
return self.visit(node)
|
|
||||||
|
try:
|
||||||
|
self.env_stack.append(node.typing_env)
|
||||||
|
return self.generic_visit(node)
|
||||||
|
finally:
|
||||||
|
self.env_stack.pop()
|
||||||
|
|
||||||
def visit_arg(self, node):
|
def visit_arg(self, node):
|
||||||
return asttyped.argT(type=self._find_name(node.arg, node.loc),
|
return asttyped.argT(type=self._find_name(node.arg, node.loc),
|
||||||
|
@ -297,31 +256,107 @@ class Inferencer(algorithm.Transformer):
|
||||||
loc=node.loc)
|
loc=node.loc)
|
||||||
return self.visit(node)
|
return self.visit(node)
|
||||||
|
|
||||||
# Visitors that just unify types
|
# Unsupported visitors
|
||||||
#
|
#
|
||||||
|
def visit_unsupported(self, node):
|
||||||
|
diag = diagnostic.Diagnostic("fatal",
|
||||||
|
"this syntax is not supported", {},
|
||||||
|
node.loc)
|
||||||
|
self.engine.process(diag)
|
||||||
|
|
||||||
|
# expr
|
||||||
|
visit_Attribute = visit_unsupported
|
||||||
|
visit_BinOp = visit_unsupported
|
||||||
|
visit_Call = visit_unsupported
|
||||||
|
visit_Compare = visit_unsupported
|
||||||
|
visit_Dict = visit_unsupported
|
||||||
|
visit_DictComp = visit_unsupported
|
||||||
|
visit_Ellipsis = visit_unsupported
|
||||||
|
visit_GeneratorExp = visit_unsupported
|
||||||
|
visit_Lambda = visit_unsupported
|
||||||
|
visit_ListComp = visit_unsupported
|
||||||
|
visit_Set = visit_unsupported
|
||||||
|
visit_SetComp = visit_unsupported
|
||||||
|
visit_Str = visit_unsupported
|
||||||
|
visit_Starred = visit_unsupported
|
||||||
|
visit_Yield = visit_unsupported
|
||||||
|
visit_YieldFrom = visit_unsupported
|
||||||
|
|
||||||
|
class Inferencer(algorithm.Visitor):
|
||||||
|
def __init__(self, engine):
|
||||||
|
self.engine = engine
|
||||||
|
# currently visited function, for Return inference
|
||||||
|
self.function = None
|
||||||
|
|
||||||
|
def _unify(self, typea, typeb, loca, locb, makenotes=None):
|
||||||
|
try:
|
||||||
|
typea.unify(typeb)
|
||||||
|
except types.UnificationError as e:
|
||||||
|
printer = types.TypePrinter()
|
||||||
|
|
||||||
|
if makenotes:
|
||||||
|
notes = makenotes(printer, typea, typeb, loca, locb)
|
||||||
|
else:
|
||||||
|
notes = [
|
||||||
|
diagnostic.Diagnostic("note",
|
||||||
|
"expression of type {typea}",
|
||||||
|
{"typea": printer.name(typea)},
|
||||||
|
loca),
|
||||||
|
diagnostic.Diagnostic("note",
|
||||||
|
"expression of type {typeb}",
|
||||||
|
{"typeb": printer.name(typeb)},
|
||||||
|
locb)
|
||||||
|
]
|
||||||
|
|
||||||
|
highlights = [locb] if locb else []
|
||||||
|
if e.typea.find() == typea.find() and e.typeb.find() == typeb.find():
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"cannot unify {typea} with {typeb}",
|
||||||
|
{"typea": printer.name(typea), "typeb": printer.name(typeb)},
|
||||||
|
loca, highlights, notes)
|
||||||
|
else: # give more detail
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"cannot unify {typea} with {typeb}: {fraga} is incompatible with {fragb}",
|
||||||
|
{"typea": printer.name(typea), "typeb": printer.name(typeb),
|
||||||
|
"fraga": printer.name(e.typea), "fragb": printer.name(e.typeb)},
|
||||||
|
loca, highlights, notes)
|
||||||
|
self.engine.process(diag)
|
||||||
|
|
||||||
|
# makenotes for the case where types of multiple elements are unified
|
||||||
|
# with the type of parent expression
|
||||||
|
def _makenotes_elts(self, elts, kind):
|
||||||
|
def makenotes(printer, typea, typeb, loca, locb):
|
||||||
|
return [
|
||||||
|
diagnostic.Diagnostic("note",
|
||||||
|
"{kind} of type {typea}",
|
||||||
|
{"kind": kind, "typea": printer.name(elts[0].type)},
|
||||||
|
elts[0].loc),
|
||||||
|
diagnostic.Diagnostic("note",
|
||||||
|
"{kind} of type {typeb}",
|
||||||
|
{"kind": kind, "typeb": printer.name(typeb)},
|
||||||
|
locb)
|
||||||
|
]
|
||||||
|
return makenotes
|
||||||
|
|
||||||
def visit_ListT(self, node):
|
def visit_ListT(self, node):
|
||||||
for elt in node.elts:
|
for elt in node.elts:
|
||||||
self._unify(node.type["elt"], elt.type,
|
self._unify(node.type["elt"], elt.type,
|
||||||
node.loc, elt.loc, self._makenotes_elts(node.elts, "a list element"))
|
node.loc, elt.loc, self._makenotes_elts(node.elts, "a list element"))
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_SubscriptT(self, node):
|
def visit_SubscriptT(self, node):
|
||||||
# TODO: support more than just lists
|
# TODO: support more than just lists
|
||||||
self._unify(builtins.TList(node.type), node.value.type,
|
self._unify(builtins.TList(node.type), node.value.type,
|
||||||
node.loc, node.value.loc)
|
node.loc, node.value.loc)
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_IfExpT(self, node):
|
def visit_IfExpT(self, node):
|
||||||
self._unify(node.body.type, node.orelse.type,
|
self._unify(node.body.type, node.orelse.type,
|
||||||
node.body.loc, node.orelse.loc)
|
node.body.loc, node.orelse.loc)
|
||||||
node.type = node.body.type
|
node.type = node.body.type
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_BoolOpT(self, node):
|
def visit_BoolOpT(self, node):
|
||||||
for value in node.values:
|
for value in node.values:
|
||||||
self._unify(node.type, value.type,
|
self._unify(node.type, value.type,
|
||||||
node.loc, value.loc, self._makenotes_elts(node.values, "an operand"))
|
node.loc, value.loc, self._makenotes_elts(node.values, "an operand"))
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_UnaryOpT(self, node):
|
def visit_UnaryOpT(self, node):
|
||||||
if isinstance(node.op, ast.Not):
|
if isinstance(node.op, ast.Not):
|
||||||
|
@ -336,53 +371,34 @@ class Inferencer(algorithm.Transformer):
|
||||||
{"type": types.TypePrinter().name(operand_type)},
|
{"type": types.TypePrinter().name(operand_type)},
|
||||||
node.operand.loc)
|
node.operand.loc)
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_ModuleT(self, node):
|
|
||||||
self.env_stack.append(node.typing_env)
|
|
||||||
|
|
||||||
node = self.generic_visit(node)
|
|
||||||
|
|
||||||
self.env_stack.pop()
|
|
||||||
|
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_FunctionDefT(self, node):
|
def visit_FunctionDefT(self, node):
|
||||||
self.env_stack.append(node.typing_env)
|
|
||||||
old_function, self.function = self.function, node
|
old_function, self.function = self.function, node
|
||||||
|
|
||||||
node = self.generic_visit(node)
|
node = self.generic_visit(node)
|
||||||
|
|
||||||
self.function = old_function
|
self.function = old_function
|
||||||
self.env_stack.pop()
|
|
||||||
|
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_Assign(self, node):
|
def visit_Assign(self, node):
|
||||||
node = self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
if len(node.targets) > 1:
|
if len(node.targets) > 1:
|
||||||
self._unify(builtins.TTuple([x.type for x in node.targets]), node.value.type,
|
self._unify(builtins.TTuple([x.type for x in node.targets]), node.value.type,
|
||||||
node.targets[0].loc.join(node.targets[-1].loc), node.value.loc)
|
node.targets[0].loc.join(node.targets[-1].loc), node.value.loc)
|
||||||
else:
|
else:
|
||||||
self._unify(node.targets[0].type, node.value.type,
|
self._unify(node.targets[0].type, node.value.type,
|
||||||
node.targets[0].loc, node.value.loc)
|
node.targets[0].loc, node.value.loc)
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_AugAssign(self, node):
|
def visit_AugAssign(self, node):
|
||||||
node = self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
self._unify(node.target.type, node.value.type,
|
self._unify(node.target.type, node.value.type,
|
||||||
node.target.loc, node.value.loc)
|
node.target.loc, node.value.loc)
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_For(self, node):
|
def visit_For(self, node):
|
||||||
node = self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
# TODO: support more than just lists
|
# TODO: support more than just lists
|
||||||
self._unify(builtins.TList(node.target.type), node.iter.type,
|
self._unify(builtins.TList(node.target.type), node.iter.type,
|
||||||
node.target.loc, node.iter.loc)
|
node.target.loc, node.iter.loc)
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_Return(self, node):
|
def visit_Return(self, node):
|
||||||
node = self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
def makenotes(printer, typea, typeb, loca, locb):
|
def makenotes(printer, typea, typeb, loca, locb):
|
||||||
return [
|
return [
|
||||||
diagnostic.Diagnostic("note",
|
diagnostic.Diagnostic("note",
|
||||||
|
@ -401,31 +417,6 @@ class Inferencer(algorithm.Transformer):
|
||||||
self._unify(self.function.return_type, node.value.type,
|
self._unify(self.function.return_type, node.value.type,
|
||||||
self.function.name_loc, node.value.loc, makenotes)
|
self.function.name_loc, node.value.loc, makenotes)
|
||||||
|
|
||||||
# Unsupported visitors
|
|
||||||
#
|
|
||||||
def visit_unsupported(self, node):
|
|
||||||
diag = diagnostic.Diagnostic("fatal",
|
|
||||||
"this syntax is not supported", {},
|
|
||||||
node.loc)
|
|
||||||
self.engine.process(diag)
|
|
||||||
|
|
||||||
visit_Attribute = visit_unsupported
|
|
||||||
visit_BinOp = visit_unsupported
|
|
||||||
visit_Call = visit_unsupported
|
|
||||||
visit_Compare = visit_unsupported
|
|
||||||
visit_Dict = visit_unsupported
|
|
||||||
visit_DictComp = visit_unsupported
|
|
||||||
visit_Ellipsis = visit_unsupported
|
|
||||||
visit_GeneratorExp = visit_unsupported
|
|
||||||
visit_Lambda = visit_unsupported
|
|
||||||
visit_ListComp = visit_unsupported
|
|
||||||
visit_Set = visit_unsupported
|
|
||||||
visit_SetComp = visit_unsupported
|
|
||||||
visit_Str = visit_unsupported
|
|
||||||
visit_Starred = visit_unsupported
|
|
||||||
visit_Yield = visit_unsupported
|
|
||||||
visit_YieldFrom = visit_unsupported
|
|
||||||
|
|
||||||
class Printer(algorithm.Visitor):
|
class Printer(algorithm.Visitor):
|
||||||
def __init__(self, buf):
|
def __init__(self, buf):
|
||||||
self.rewriter = source.Rewriter(buf)
|
self.rewriter = source.Rewriter(buf)
|
||||||
|
@ -467,7 +458,9 @@ def main():
|
||||||
|
|
||||||
buf = source.Buffer("".join(fileinput.input()), os.path.basename(fileinput.filename()))
|
buf = source.Buffer("".join(fileinput.input()), os.path.basename(fileinput.filename()))
|
||||||
parsed, comments = parse_buffer(buf, engine=engine)
|
parsed, comments = parse_buffer(buf, engine=engine)
|
||||||
typed = Inferencer(engine=engine).visit(parsed)
|
typed = ASTTypedRewriter(engine=engine).visit(parsed)
|
||||||
|
Inferencer(engine=engine).visit(typed)
|
||||||
|
|
||||||
printer = Printer(buf)
|
printer = Printer(buf)
|
||||||
printer.visit(typed)
|
printer.visit(typed)
|
||||||
for comment in comments:
|
for comment in comments:
|
||||||
|
|
Loading…
Reference in New Issue