Separate inference and asttyped transformation.

This allows to run inference several times on the same tree,
as would be necessary when coercion nodes are added.
This commit is contained in:
whitequark 2015-06-11 06:34:22 +03:00
parent e18ea0daae
commit df686136f1
1 changed files with 65 additions and 57 deletions

View File

@ -166,6 +166,22 @@ class Inferencer(algorithm.Transformer):
loca, highlights, notes) loca, highlights, notes)
self.engine.process(diag) self.engine.process(diag)
# makenotes for the case where types of multiple elements are unified
# with the type of parent expression
def _makenotes_elts(self, elts, kind):
def makenotes(printer, typea, typeb, loca, locb):
return [
diagnostic.Diagnostic("note",
"{kind} of type {typea}",
{"kind": kind, "typea": printer.name(elts[0].type)},
elts[0].loc),
diagnostic.Diagnostic("note",
"{kind} of type {typeb}",
{"kind": kind, "typeb": printer.name(typeb)},
locb)
]
return makenotes
def _find_name(self, name, loc): def _find_name(self, name, loc):
for typing_env in reversed(self.env_stack): for typing_env in reversed(self.env_stack):
if name in typing_env: if name in typing_env:
@ -212,26 +228,6 @@ class Inferencer(algorithm.Transformer):
return node return node
def visit_Return(self, node):
node = self.generic_visit(node)
def makenotes(printer, typea, typeb, loca, locb):
return [
diagnostic.Diagnostic("note",
"function with return type {typea}",
{"typea": printer.name(typea)},
self.function.name_loc),
diagnostic.Diagnostic("note",
"a statement returning {typeb}",
{"typeb": printer.name(typeb)},
node.loc)
]
if node.value is None:
self._unify(self.function.return_type, types.TNone(),
self.function.name_loc, node.loc, makenotes)
else:
self._unify(self.function.return_type, node.value.type,
self.function.name_loc, node.value.loc, makenotes)
def visit_Num(self, node): def visit_Num(self, node):
if isinstance(node.n, int): if isinstance(node.n, int):
typ = types.TInt() typ = types.TInt()
@ -265,63 +261,55 @@ class Inferencer(algorithm.Transformer):
node = self.generic_visit(node) node = self.generic_visit(node)
node = asttyped.ListT(type=types.TList(), node = asttyped.ListT(type=types.TList(),
elts=node.elts, ctx=node.ctx, loc=node.loc) elts=node.elts, ctx=node.ctx, loc=node.loc)
def makenotes(printer, typea, typeb, loca, locb): return self.visit(node)
return [
diagnostic.Diagnostic("note",
"a list element of type {typea}",
{"typea": printer.name(node.elts[0].type)},
node.elts[0].loc),
diagnostic.Diagnostic("note",
"a list element of type {typeb}",
{"typeb": printer.name(typeb)},
locb)
]
for elt in node.elts:
self._unify(node.type["elt"], elt.type,
node.loc, elt.loc, makenotes)
return node
def visit_Subscript(self, node): def visit_Subscript(self, node):
node = self.generic_visit(node) node = self.generic_visit(node)
node = asttyped.SubscriptT(type=types.TVar(), node = asttyped.SubscriptT(type=types.TVar(),
value=node.value, slice=node.slice, ctx=node.ctx, value=node.value, slice=node.slice, ctx=node.ctx,
loc=node.loc) loc=node.loc)
# TODO: support more than just lists return self.visit(node)
self._unify(types.TList(node.type), node.value.type,
node.loc, node.value.loc)
return node
def visit_IfExp(self, node): def visit_IfExp(self, node):
node = self.generic_visit(node) node = self.generic_visit(node)
self._unify(node.body.type, node.orelse.type, node = asttyped.IfExpT(type=types.TVar(),
node.body.loc, node.orelse.loc)
return asttyped.IfExpT(type=node.body.type,
test=node.test, body=node.body, orelse=node.orelse, test=node.test, body=node.body, orelse=node.orelse,
if_loc=node.if_loc, else_loc=node.else_loc, loc=node.loc) if_loc=node.if_loc, else_loc=node.else_loc, loc=node.loc)
return self.visit(node)
def visit_BoolOp(self, node): def visit_BoolOp(self, node):
node = self.generic_visit(node) node = self.generic_visit(node)
node = asttyped.BoolOpT(type=types.TVar(), node = asttyped.BoolOpT(type=types.TVar(),
op=node.op, values=node.values, op=node.op, values=node.values,
op_locs=node.op_locs, loc=node.loc) op_locs=node.op_locs, loc=node.loc)
def makenotes(printer, typea, typeb, loca, locb): return self.visit(node)
return [
diagnostic.Diagnostic("note",
"an operand of type {typea}",
{"typea": printer.name(node.values[0].type)},
node.values[0].loc),
diagnostic.Diagnostic("note",
"an operand of type {typeb}",
{"typeb": printer.name(typeb)},
locb)
]
for value in node.values:
self._unify(node.type, value.type,
node.loc, value.loc, makenotes)
return node
# Visitors that just unify types # Visitors that just unify types
# #
def visit_ListT(self, node):
for elt in node.elts:
self._unify(node.type["elt"], elt.type,
node.loc, elt.loc, self._makenotes_elts(node.elts, "a list element"))
return node
def visit_SubscriptT(self, node):
# TODO: support more than just lists
self._unify(types.TList(node.type), node.value.type,
node.loc, node.value.loc)
return node
def visit_IfExpT(self, node):
self._unify(node.body.type, node.orelse.type,
node.body.loc, node.orelse.loc)
node.type = node.body.type
return node
def visit_BoolOpT(self, node):
for value in node.values:
self._unify(node.type, value.type,
node.loc, value.loc, self._makenotes_elts(node.values, "an operand"))
return node
def visit_Assign(self, node): def visit_Assign(self, node):
node = self.generic_visit(node) node = self.generic_visit(node)
if len(node.targets) > 1: if len(node.targets) > 1:
@ -345,6 +333,26 @@ class Inferencer(algorithm.Transformer):
node.target.loc, node.iter.loc) node.target.loc, node.iter.loc)
return node return node
def visit_Return(self, node):
node = self.generic_visit(node)
def makenotes(printer, typea, typeb, loca, locb):
return [
diagnostic.Diagnostic("note",
"function with return type {typea}",
{"typea": printer.name(typea)},
self.function.name_loc),
diagnostic.Diagnostic("note",
"a statement returning {typeb}",
{"typeb": printer.name(typeb)},
node.loc)
]
if node.value is None:
self._unify(self.function.return_type, types.TNone(),
self.function.name_loc, node.loc, makenotes)
else:
self._unify(self.function.return_type, node.value.type,
self.function.name_loc, node.value.loc, makenotes)
# Unsupported visitors # Unsupported visitors
# #
def visit_unsupported(self, node): def visit_unsupported(self, node):