forked from M-Labs/artiq
Add new type inferencer.
This commit is contained in:
parent
74080f2cc6
commit
abbc87e981
|
@ -0,0 +1,53 @@
|
||||||
|
"""
|
||||||
|
The typedtree module exports the PythonParser AST enriched with
|
||||||
|
typing information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pythonparser import ast
|
||||||
|
from pythonparser.algorithm import Visitor as ASTVisitor
|
||||||
|
|
||||||
|
class commontyped(ast.commonloc):
|
||||||
|
"""A mixin for typed AST nodes."""
|
||||||
|
|
||||||
|
_types = ('type',)
|
||||||
|
|
||||||
|
def _reprfields(self):
|
||||||
|
return self._fields + self._locs + self._types
|
||||||
|
|
||||||
|
class scoped(object):
|
||||||
|
"""
|
||||||
|
:ivar typing_env: (dict with string keys and :class:`.types.Type` values)
|
||||||
|
map of variable names to variable types
|
||||||
|
:ivar globals_in_scope: (set of string keys)
|
||||||
|
list of variables resolved as globals
|
||||||
|
"""
|
||||||
|
|
||||||
|
class ClassDefT(ast.ClassDef, scoped):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class FunctionDefT(ast.FunctionDef, scoped):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class LambdaT(ast.Lambda, scoped):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class DictCompT(ast.DictComp, scoped):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ListCompT(ast.ListComp, scoped):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class SetCompT(ast.SetComp, scoped):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class argT(ast.arg, commontyped):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class NumT(ast.Num, commontyped):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class NameT(ast.Name, commontyped):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class NameConstantT(ast.NameConstant, commontyped):
|
||||||
|
pass
|
|
@ -0,0 +1,189 @@
|
||||||
|
"""
|
||||||
|
The :mod:`types` module contains the classes describing the types
|
||||||
|
in :mod:`asttyped`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import string
|
||||||
|
|
||||||
|
def genalnum():
|
||||||
|
ident = ["a"]
|
||||||
|
while True:
|
||||||
|
yield "".join(ident)
|
||||||
|
pos = len(ident) - 1
|
||||||
|
while pos >= 0:
|
||||||
|
cur_n = string.ascii_lowercase.index(ident[pos])
|
||||||
|
if cur_n < 26:
|
||||||
|
ident[pos] = string.ascii_lowercase[cur_n + 1]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
ident[pos] = "a"
|
||||||
|
pos -= 1
|
||||||
|
if pos < 0:
|
||||||
|
ident = "a" + ident
|
||||||
|
|
||||||
|
class UnificationError(Exception):
|
||||||
|
def __init__(self, typea, typeb):
|
||||||
|
self.typea, self.typeb = typea, typeb
|
||||||
|
|
||||||
|
|
||||||
|
class Type(object):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class TVar(Type):
|
||||||
|
"""
|
||||||
|
A type variable.
|
||||||
|
|
||||||
|
In effect, the classic union-find data structure is intrusively
|
||||||
|
folded into this class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.parent = self
|
||||||
|
|
||||||
|
def find(self):
|
||||||
|
if self.parent is self:
|
||||||
|
return self
|
||||||
|
else:
|
||||||
|
root = self.parent.find()
|
||||||
|
self.parent = root # path compression
|
||||||
|
return root
|
||||||
|
|
||||||
|
def unify(self, other):
|
||||||
|
other = other.find()
|
||||||
|
|
||||||
|
if self.parent is self:
|
||||||
|
self.parent = other
|
||||||
|
else:
|
||||||
|
self.find().unify(other)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
if self.parent is self:
|
||||||
|
return "TVar(%d)" % id(self)
|
||||||
|
else:
|
||||||
|
return repr(self.find())
|
||||||
|
|
||||||
|
# __eq__ and __hash__ are not overridden and default to
|
||||||
|
# comparison by identity. Use .find() explicitly before
|
||||||
|
# any lookups or comparisons.
|
||||||
|
|
||||||
|
class TMono(Type):
|
||||||
|
"""A monomorphic type, possibly parametric."""
|
||||||
|
|
||||||
|
def __init__(self, name, params={}):
|
||||||
|
self.name, self.params = name, params
|
||||||
|
|
||||||
|
def find(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def unify(self, other):
|
||||||
|
if isinstance(other, TMono) and self.name == other.name:
|
||||||
|
assert self.params.keys() == other.params.keys()
|
||||||
|
for param in self.params:
|
||||||
|
self.params[param].unify(other.params[param])
|
||||||
|
else:
|
||||||
|
raise UnificationError(self, other)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "TMono(%s, %s)" % (repr(self.name), repr(self.params))
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return isinstance(other, TMono) and \
|
||||||
|
self.name == other.name and \
|
||||||
|
self.params == other.params
|
||||||
|
|
||||||
|
def __ne__(self, other):
|
||||||
|
return not (self == other)
|
||||||
|
|
||||||
|
class TTuple(Type):
|
||||||
|
"""A tuple type."""
|
||||||
|
|
||||||
|
def __init__(self, elts=[]):
|
||||||
|
self.elts = elts
|
||||||
|
|
||||||
|
def find(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def unify(self, other):
|
||||||
|
if isinstance(other, TTuple) and len(self.elts) == len(other.elts):
|
||||||
|
for selfelt, otherelt in zip(self.elts, other.elts):
|
||||||
|
selfelt.unify(otherelt)
|
||||||
|
else:
|
||||||
|
raise UnificationError(self, other)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "TTuple(%s)" % (", ".join(map(repr, self.elts)))
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return isinstance(other, TTuple) and \
|
||||||
|
self.elts == other.elts
|
||||||
|
|
||||||
|
def __ne__(self, other):
|
||||||
|
return not (self == other)
|
||||||
|
|
||||||
|
class TValue(Type):
|
||||||
|
"""
|
||||||
|
A type-level value (such as the integer denoting width of
|
||||||
|
a generic integer type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def find(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def unify(self, other):
|
||||||
|
if self != other:
|
||||||
|
raise UnificationError(self, other)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "TValue(%s)" % repr(self.value)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return isinstance(other, TValue) and \
|
||||||
|
self.value == other.value
|
||||||
|
|
||||||
|
def __ne__(self, other):
|
||||||
|
return not (self == other)
|
||||||
|
|
||||||
|
def TBool():
|
||||||
|
"""A boolean type."""
|
||||||
|
return TMono("bool")
|
||||||
|
|
||||||
|
def TInt(width=TVar()):
|
||||||
|
"""A generic integer type."""
|
||||||
|
return TMono("int", {"width": width})
|
||||||
|
|
||||||
|
def TFloat():
|
||||||
|
"""A double-precision floating point type."""
|
||||||
|
return TMono("float")
|
||||||
|
|
||||||
|
|
||||||
|
class TypePrinter(object):
|
||||||
|
"""
|
||||||
|
A class that prints types using Python-like syntax and gives
|
||||||
|
type variables sequential alphabetic names.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.gen = genalnum()
|
||||||
|
self.map = {}
|
||||||
|
|
||||||
|
def name(self, typ):
|
||||||
|
typ = typ.find()
|
||||||
|
if isinstance(typ, TVar):
|
||||||
|
if typ not in self.map:
|
||||||
|
self.map[typ] = "'%s" % next(self.gen)
|
||||||
|
return self.map[typ]
|
||||||
|
elif isinstance(typ, TMono):
|
||||||
|
return "%s(%s)" % (typ.name, ", ".join(
|
||||||
|
["%s=%s" % (k, self.name(typ.params[k])) for k in typ.params]))
|
||||||
|
elif isinstance(typ, TTuple):
|
||||||
|
if len(typ.elts) == 1:
|
||||||
|
return "(%s,)" % self.name(typ.elts[0])
|
||||||
|
else:
|
||||||
|
return "(%s)" % ", ".join(list(map(self.name, typ.elts)))
|
||||||
|
elif isinstance(typ, TValue):
|
||||||
|
return repr(typ.value)
|
||||||
|
else:
|
||||||
|
assert False
|
|
@ -0,0 +1,226 @@
|
||||||
|
from pythonparser import source, ast, algorithm, diagnostic, parse_buffer
|
||||||
|
from . import asttyped, types
|
||||||
|
|
||||||
|
# This visitor will be called for every node with a scope,
|
||||||
|
# i.e.: class, function, comprehension, lambda
|
||||||
|
class LocalExtractor(algorithm.Visitor):
|
||||||
|
def __init__(self, engine):
|
||||||
|
super().__init__()
|
||||||
|
self.engine = engine
|
||||||
|
|
||||||
|
self.in_root = False
|
||||||
|
self.in_assign = False
|
||||||
|
self.typing_env = {}
|
||||||
|
|
||||||
|
# which names are global have to be recorded in the current scope
|
||||||
|
self.global_ = set()
|
||||||
|
|
||||||
|
# which names are nonlocal only affects whether the current scope
|
||||||
|
# gets a new binding or not, so we throw this away
|
||||||
|
self.nonlocal_ = set()
|
||||||
|
|
||||||
|
# parameters can't be declared as global or nonlocal
|
||||||
|
self.params = set()
|
||||||
|
|
||||||
|
def visit_in_assign(self, node):
|
||||||
|
try:
|
||||||
|
self.in_assign = True
|
||||||
|
return self.visit(node)
|
||||||
|
finally:
|
||||||
|
self.in_assign = False
|
||||||
|
|
||||||
|
def visit_Assign(self, node):
|
||||||
|
for target in node.targets:
|
||||||
|
self.visit_in_assign(target)
|
||||||
|
self.visit(node.value)
|
||||||
|
|
||||||
|
def visit_AugAssign(self, node):
|
||||||
|
self.visit_in_assign(node.target)
|
||||||
|
self.visit(node.op)
|
||||||
|
self.visit(node.value)
|
||||||
|
|
||||||
|
def visit_For(self, node):
|
||||||
|
self.visit_in_assign(node.target)
|
||||||
|
self.visit(node.iter)
|
||||||
|
self.visit(node.body)
|
||||||
|
self.visit(node.orelse)
|
||||||
|
|
||||||
|
def visit_withitem(self, node):
|
||||||
|
self.visit(node.context_expr)
|
||||||
|
if node.optional_vars is not None:
|
||||||
|
self.visit_in_assign(node.optional_vars)
|
||||||
|
|
||||||
|
def visit_comprehension(self, node):
|
||||||
|
self.visit_in_assign(node.target)
|
||||||
|
self.visit(node.iter)
|
||||||
|
for if_ in node.ifs:
|
||||||
|
self.visit(node.ifs)
|
||||||
|
|
||||||
|
def visit_root(self, node):
|
||||||
|
if self.in_root:
|
||||||
|
return
|
||||||
|
self.in_root = True
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
visit_ClassDef = visit_root # don't look at inner scopes
|
||||||
|
visit_FunctionDef = visit_root
|
||||||
|
visit_Lambda = visit_root
|
||||||
|
visit_DictComp = visit_root
|
||||||
|
visit_ListComp = visit_root
|
||||||
|
visit_SetComp = visit_root
|
||||||
|
|
||||||
|
def _assignable(self, name):
|
||||||
|
if name not in self.typing_env and name not in self.nonlocal_:
|
||||||
|
self.typing_env[name] = types.TVar()
|
||||||
|
|
||||||
|
def visit_arg(self, node):
|
||||||
|
self._assignable(node.arg)
|
||||||
|
self.params.add(node.arg)
|
||||||
|
|
||||||
|
def visit_Name(self, node):
|
||||||
|
if self.in_assign:
|
||||||
|
# code like:
|
||||||
|
# x = 1
|
||||||
|
# def f():
|
||||||
|
# x = 1
|
||||||
|
# creates a new binding for x in f's scope
|
||||||
|
self._assignable(node.id)
|
||||||
|
|
||||||
|
def _check_not_in(self, name, names, curkind, newkind, loc):
|
||||||
|
if name in names:
|
||||||
|
diag = diagnostic.Diagnostic('fatal',
|
||||||
|
"name '{name}' cannot be {curkind} and {newkind} simultaneously",
|
||||||
|
{"name": name, "curkind": curkind, "newkind": newkind}, loc)
|
||||||
|
self.engine.process(diag)
|
||||||
|
|
||||||
|
def visit_Global(self, node):
|
||||||
|
for name, loc in zip(node.names, node.name_locs):
|
||||||
|
self._check_not_in(name, self.nonlocal_, 'nonlocal', 'global', loc)
|
||||||
|
self._check_not_in(name, self.params, 'a parameter', 'global', loc)
|
||||||
|
self.global_.add(name)
|
||||||
|
|
||||||
|
def visit_Nonlocal(self, node):
|
||||||
|
for name, loc in zip(node.names, node.name_locs):
|
||||||
|
self._check_not_in(name, self.global_, 'global', 'nonlocal', loc)
|
||||||
|
self._check_not_in(name, self.params, 'a parameter', 'nonlocal', loc)
|
||||||
|
self.nonlocal_.add(name)
|
||||||
|
|
||||||
|
def visit_ExceptHandler(self, node):
|
||||||
|
self.visit(node.type)
|
||||||
|
self._assignable(node.name)
|
||||||
|
for stmt in node.body:
|
||||||
|
self.visit(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
class Inferencer(algorithm.Transformer):
|
||||||
|
def __init__(self, engine):
|
||||||
|
self.engine = engine
|
||||||
|
self.env_stack = [{}]
|
||||||
|
|
||||||
|
def _unify(self, typea, typeb, loca, locb):
|
||||||
|
try:
|
||||||
|
typea.unify(typeb)
|
||||||
|
except types.UnificationError as e:
|
||||||
|
note1 = diagnostic.Diagnostic('note',
|
||||||
|
"expression of type {typea}",
|
||||||
|
{"typea": types.TypePrinter().name(typea)},
|
||||||
|
loca)
|
||||||
|
note2 = diagnostic.Diagnostic('note',
|
||||||
|
"expression of type {typeb}",
|
||||||
|
{"typeb": types.TypePrinter().name(typeb)},
|
||||||
|
locb)
|
||||||
|
diag = diagnostic.Diagnostic('fatal',
|
||||||
|
"cannot unify {typea} with {typeb}: {fraga} is incompatible with {fragb}",
|
||||||
|
{"typea": types.TypePrinter().name(typea),
|
||||||
|
"typeb": types.TypePrinter().name(typeb),
|
||||||
|
"fraga": types.TypePrinter().name(e.typea),
|
||||||
|
"fragb": types.TypePrinter().name(e.typeb),},
|
||||||
|
loca, [locb], notes=[note1, note2])
|
||||||
|
self.engine.process(diag)
|
||||||
|
|
||||||
|
def visit_FunctionDef(self, node):
|
||||||
|
extractor = LocalExtractor(engine=self.engine)
|
||||||
|
extractor.visit(node)
|
||||||
|
|
||||||
|
self.env_stack.append(extractor.typing_env)
|
||||||
|
node = asttyped.FunctionDefT(
|
||||||
|
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
|
||||||
|
name=node.name, args=self.visit(node.args), returns=self.visit(node.returns),
|
||||||
|
body=[self.visit(x) for x in node.body], decorator_list=node.decorator_list,
|
||||||
|
keyword_loc=node.keyword_loc, name_loc=node.name_loc,
|
||||||
|
arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs,
|
||||||
|
loc=node.loc)
|
||||||
|
self.generic_visit(node)
|
||||||
|
self.env_stack.pop()
|
||||||
|
|
||||||
|
return node
|
||||||
|
|
||||||
|
def _find_name(self, name, loc):
|
||||||
|
for typing_env in reversed(self.env_stack):
|
||||||
|
if name in typing_env:
|
||||||
|
return typing_env[name]
|
||||||
|
diag = diagnostic.Diagnostic('fatal',
|
||||||
|
"name '{name}' is not bound to anything", {"name":name}, loc)
|
||||||
|
self.engine.process(diag)
|
||||||
|
|
||||||
|
def visit_arg(self, node):
|
||||||
|
return asttyped.argT(type=self._find_name(node.arg, node.loc),
|
||||||
|
arg=node.arg, annotation=self.visit(node.annotation),
|
||||||
|
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
|
||||||
|
|
||||||
|
def visit_Num(self, node):
|
||||||
|
if isinstance(node.n, int):
|
||||||
|
typ = types.TInt()
|
||||||
|
elif isinstance(node.n, float):
|
||||||
|
typ = types.TFloat()
|
||||||
|
else:
|
||||||
|
diag = diagnostic.Diagnostic('fatal',
|
||||||
|
"numeric type {type} is not supported", node.n.__class__.__name__,
|
||||||
|
node.loc)
|
||||||
|
self.engine.process(diag)
|
||||||
|
return asttyped.NumT(type=typ,
|
||||||
|
n=node.n, loc=node.loc)
|
||||||
|
|
||||||
|
def visit_Name(self, node):
|
||||||
|
return asttyped.NameT(type=self._find_name(node.id, node.loc),
|
||||||
|
id=node.id, ctx=node.ctx, loc=node.loc)
|
||||||
|
|
||||||
|
def visit_Assign(self, node):
|
||||||
|
node = self.generic_visit(node)
|
||||||
|
if len(node.targets) > 1:
|
||||||
|
self._unify(types.TTuple([x.type for x in node.targets]), node.value.type,
|
||||||
|
node.targets[0].loc.join(node.targets[-1].loc), node.value.loc)
|
||||||
|
else:
|
||||||
|
self._unify(node.targets[0].type, node.value.type,
|
||||||
|
node.targets[0].loc, node.value.loc)
|
||||||
|
return node
|
||||||
|
|
||||||
|
class Printer(algorithm.Visitor):
|
||||||
|
def __init__(self, buf):
|
||||||
|
self.rewriter = source.Rewriter(buf)
|
||||||
|
self.type_printer = types.TypePrinter()
|
||||||
|
|
||||||
|
def rewrite(self):
|
||||||
|
return self.rewriter.rewrite()
|
||||||
|
|
||||||
|
def generic_visit(self, node):
|
||||||
|
if hasattr(node, 'type'):
|
||||||
|
self.rewriter.insert_after(node.loc, " : %s" % self.type_printer.name(node.type))
|
||||||
|
|
||||||
|
super().generic_visit(node)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import sys, fileinput
|
||||||
|
engine = diagnostic.Engine(all_errors_are_fatal=True)
|
||||||
|
try:
|
||||||
|
buf = source.Buffer("".join(fileinput.input()), fileinput.filename())
|
||||||
|
parsed = parse_buffer(buf, engine=engine)
|
||||||
|
typed = Inferencer(engine=engine).visit(parsed)
|
||||||
|
printer = Printer(buf)
|
||||||
|
printer.visit(typed)
|
||||||
|
print(printer.rewrite().source)
|
||||||
|
except diagnostic.Error as e:
|
||||||
|
print("\n".join(e.diagnostic.render()), file=sys.stderr)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
Reference in New Issue