diff --git a/artiq/py2llvm/asttyped.py b/artiq/py2llvm/asttyped.py new file mode 100644 index 000000000..21f3b07a1 --- /dev/null +++ b/artiq/py2llvm/asttyped.py @@ -0,0 +1,53 @@ +""" +The typedtree module exports the PythonParser AST enriched with +typing information. +""" + +from pythonparser import ast +from pythonparser.algorithm import Visitor as ASTVisitor + +class commontyped(ast.commonloc): + """A mixin for typed AST nodes.""" + + _types = ('type',) + + def _reprfields(self): + return self._fields + self._locs + self._types + +class scoped(object): + """ + :ivar typing_env: (dict with string keys and :class:`.types.Type` values) + map of variable names to variable types + :ivar globals_in_scope: (set of string keys) + list of variables resolved as globals + """ + +class ClassDefT(ast.ClassDef, scoped): + pass + +class FunctionDefT(ast.FunctionDef, scoped): + pass + +class LambdaT(ast.Lambda, scoped): + pass + +class DictCompT(ast.DictComp, scoped): + pass + +class ListCompT(ast.ListComp, scoped): + pass + +class SetCompT(ast.SetComp, scoped): + pass + +class argT(ast.arg, commontyped): + pass + +class NumT(ast.Num, commontyped): + pass + +class NameT(ast.Name, commontyped): + pass + +class NameConstantT(ast.NameConstant, commontyped): + pass diff --git a/artiq/py2llvm/types.py b/artiq/py2llvm/types.py new file mode 100644 index 000000000..52cbe7518 --- /dev/null +++ b/artiq/py2llvm/types.py @@ -0,0 +1,189 @@ +""" +The :mod:`types` module contains the classes describing the types +in :mod:`asttyped`. +""" + +import string + +def genalnum(): + ident = ["a"] + while True: + yield "".join(ident) + pos = len(ident) - 1 + while pos >= 0: + cur_n = string.ascii_lowercase.index(ident[pos]) + if cur_n < 26: + ident[pos] = string.ascii_lowercase[cur_n + 1] + break + else: + ident[pos] = "a" + pos -= 1 + if pos < 0: + ident = "a" + ident + +class UnificationError(Exception): + def __init__(self, typea, typeb): + self.typea, self.typeb = typea, typeb + + +class Type(object): + pass + +class TVar(Type): + """ + A type variable. + + In effect, the classic union-find data structure is intrusively + folded into this class. + """ + + def __init__(self): + self.parent = self + + def find(self): + if self.parent is self: + return self + else: + root = self.parent.find() + self.parent = root # path compression + return root + + def unify(self, other): + other = other.find() + + if self.parent is self: + self.parent = other + else: + self.find().unify(other) + + def __repr__(self): + if self.parent is self: + return "TVar(%d)" % id(self) + else: + return repr(self.find()) + + # __eq__ and __hash__ are not overridden and default to + # comparison by identity. Use .find() explicitly before + # any lookups or comparisons. + +class TMono(Type): + """A monomorphic type, possibly parametric.""" + + def __init__(self, name, params={}): + self.name, self.params = name, params + + def find(self): + return self + + def unify(self, other): + if isinstance(other, TMono) and self.name == other.name: + assert self.params.keys() == other.params.keys() + for param in self.params: + self.params[param].unify(other.params[param]) + else: + raise UnificationError(self, other) + + def __repr__(self): + return "TMono(%s, %s)" % (repr(self.name), repr(self.params)) + + def __eq__(self, other): + return isinstance(other, TMono) and \ + self.name == other.name and \ + self.params == other.params + + def __ne__(self, other): + return not (self == other) + +class TTuple(Type): + """A tuple type.""" + + def __init__(self, elts=[]): + self.elts = elts + + def find(self): + return self + + def unify(self, other): + if isinstance(other, TTuple) and len(self.elts) == len(other.elts): + for selfelt, otherelt in zip(self.elts, other.elts): + selfelt.unify(otherelt) + else: + raise UnificationError(self, other) + + def __repr__(self): + return "TTuple(%s)" % (", ".join(map(repr, self.elts))) + + def __eq__(self, other): + return isinstance(other, TTuple) and \ + self.elts == other.elts + + def __ne__(self, other): + return not (self == other) + +class TValue(Type): + """ + A type-level value (such as the integer denoting width of + a generic integer type. + """ + + def __init__(self, value): + self.value = value + + def find(self): + return self + + def unify(self, other): + if self != other: + raise UnificationError(self, other) + + def __repr__(self): + return "TValue(%s)" % repr(self.value) + + def __eq__(self, other): + return isinstance(other, TValue) and \ + self.value == other.value + + def __ne__(self, other): + return not (self == other) + +def TBool(): + """A boolean type.""" + return TMono("bool") + +def TInt(width=TVar()): + """A generic integer type.""" + return TMono("int", {"width": width}) + +def TFloat(): + """A double-precision floating point type.""" + return TMono("float") + + +class TypePrinter(object): + """ + A class that prints types using Python-like syntax and gives + type variables sequential alphabetic names. + """ + + def __init__(self): + self.gen = genalnum() + self.map = {} + + def name(self, typ): + typ = typ.find() + if isinstance(typ, TVar): + if typ not in self.map: + self.map[typ] = "'%s" % next(self.gen) + return self.map[typ] + elif isinstance(typ, TMono): + return "%s(%s)" % (typ.name, ", ".join( + ["%s=%s" % (k, self.name(typ.params[k])) for k in typ.params])) + elif isinstance(typ, TTuple): + if len(typ.elts) == 1: + return "(%s,)" % self.name(typ.elts[0]) + else: + return "(%s)" % ", ".join(list(map(self.name, typ.elts))) + elif isinstance(typ, TValue): + return repr(typ.value) + else: + assert False diff --git a/artiq/py2llvm/typing.py b/artiq/py2llvm/typing.py new file mode 100644 index 000000000..effbca78f --- /dev/null +++ b/artiq/py2llvm/typing.py @@ -0,0 +1,226 @@ +from pythonparser import source, ast, algorithm, diagnostic, parse_buffer +from . import asttyped, types + +# This visitor will be called for every node with a scope, +# i.e.: class, function, comprehension, lambda +class LocalExtractor(algorithm.Visitor): + def __init__(self, engine): + super().__init__() + self.engine = engine + + self.in_root = False + self.in_assign = False + self.typing_env = {} + + # which names are global have to be recorded in the current scope + self.global_ = set() + + # which names are nonlocal only affects whether the current scope + # gets a new binding or not, so we throw this away + self.nonlocal_ = set() + + # parameters can't be declared as global or nonlocal + self.params = set() + + def visit_in_assign(self, node): + try: + self.in_assign = True + return self.visit(node) + finally: + self.in_assign = False + + def visit_Assign(self, node): + for target in node.targets: + self.visit_in_assign(target) + self.visit(node.value) + + def visit_AugAssign(self, node): + self.visit_in_assign(node.target) + self.visit(node.op) + self.visit(node.value) + + def visit_For(self, node): + self.visit_in_assign(node.target) + self.visit(node.iter) + self.visit(node.body) + self.visit(node.orelse) + + def visit_withitem(self, node): + self.visit(node.context_expr) + if node.optional_vars is not None: + self.visit_in_assign(node.optional_vars) + + def visit_comprehension(self, node): + self.visit_in_assign(node.target) + self.visit(node.iter) + for if_ in node.ifs: + self.visit(node.ifs) + + def visit_root(self, node): + if self.in_root: + return + self.in_root = True + self.generic_visit(node) + + visit_ClassDef = visit_root # don't look at inner scopes + visit_FunctionDef = visit_root + visit_Lambda = visit_root + visit_DictComp = visit_root + visit_ListComp = visit_root + visit_SetComp = visit_root + + def _assignable(self, name): + if name not in self.typing_env and name not in self.nonlocal_: + self.typing_env[name] = types.TVar() + + def visit_arg(self, node): + self._assignable(node.arg) + self.params.add(node.arg) + + def visit_Name(self, node): + if self.in_assign: + # code like: + # x = 1 + # def f(): + # x = 1 + # creates a new binding for x in f's scope + self._assignable(node.id) + + def _check_not_in(self, name, names, curkind, newkind, loc): + if name in names: + diag = diagnostic.Diagnostic('fatal', + "name '{name}' cannot be {curkind} and {newkind} simultaneously", + {"name": name, "curkind": curkind, "newkind": newkind}, loc) + self.engine.process(diag) + + def visit_Global(self, node): + for name, loc in zip(node.names, node.name_locs): + self._check_not_in(name, self.nonlocal_, 'nonlocal', 'global', loc) + self._check_not_in(name, self.params, 'a parameter', 'global', loc) + self.global_.add(name) + + def visit_Nonlocal(self, node): + for name, loc in zip(node.names, node.name_locs): + self._check_not_in(name, self.global_, 'global', 'nonlocal', loc) + self._check_not_in(name, self.params, 'a parameter', 'nonlocal', loc) + self.nonlocal_.add(name) + + def visit_ExceptHandler(self, node): + self.visit(node.type) + self._assignable(node.name) + for stmt in node.body: + self.visit(stmt) + + +class Inferencer(algorithm.Transformer): + def __init__(self, engine): + self.engine = engine + self.env_stack = [{}] + + def _unify(self, typea, typeb, loca, locb): + try: + typea.unify(typeb) + except types.UnificationError as e: + note1 = diagnostic.Diagnostic('note', + "expression of type {typea}", + {"typea": types.TypePrinter().name(typea)}, + loca) + note2 = diagnostic.Diagnostic('note', + "expression of type {typeb}", + {"typeb": types.TypePrinter().name(typeb)}, + locb) + diag = diagnostic.Diagnostic('fatal', + "cannot unify {typea} with {typeb}: {fraga} is incompatible with {fragb}", + {"typea": types.TypePrinter().name(typea), + "typeb": types.TypePrinter().name(typeb), + "fraga": types.TypePrinter().name(e.typea), + "fragb": types.TypePrinter().name(e.typeb),}, + loca, [locb], notes=[note1, note2]) + self.engine.process(diag) + + def visit_FunctionDef(self, node): + extractor = LocalExtractor(engine=self.engine) + extractor.visit(node) + + self.env_stack.append(extractor.typing_env) + node = asttyped.FunctionDefT( + typing_env=extractor.typing_env, globals_in_scope=extractor.global_, + name=node.name, args=self.visit(node.args), returns=self.visit(node.returns), + body=[self.visit(x) for x in node.body], decorator_list=node.decorator_list, + keyword_loc=node.keyword_loc, name_loc=node.name_loc, + arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs, + loc=node.loc) + self.generic_visit(node) + self.env_stack.pop() + + return node + + def _find_name(self, name, loc): + for typing_env in reversed(self.env_stack): + if name in typing_env: + return typing_env[name] + diag = diagnostic.Diagnostic('fatal', + "name '{name}' is not bound to anything", {"name":name}, loc) + self.engine.process(diag) + + def visit_arg(self, node): + return asttyped.argT(type=self._find_name(node.arg, node.loc), + arg=node.arg, annotation=self.visit(node.annotation), + arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc) + + def visit_Num(self, node): + if isinstance(node.n, int): + typ = types.TInt() + elif isinstance(node.n, float): + typ = types.TFloat() + else: + diag = diagnostic.Diagnostic('fatal', + "numeric type {type} is not supported", node.n.__class__.__name__, + node.loc) + self.engine.process(diag) + return asttyped.NumT(type=typ, + n=node.n, loc=node.loc) + + def visit_Name(self, node): + return asttyped.NameT(type=self._find_name(node.id, node.loc), + id=node.id, ctx=node.ctx, loc=node.loc) + + def visit_Assign(self, node): + node = self.generic_visit(node) + if len(node.targets) > 1: + self._unify(types.TTuple([x.type for x in node.targets]), node.value.type, + node.targets[0].loc.join(node.targets[-1].loc), node.value.loc) + else: + self._unify(node.targets[0].type, node.value.type, + node.targets[0].loc, node.value.loc) + return node + +class Printer(algorithm.Visitor): + def __init__(self, buf): + self.rewriter = source.Rewriter(buf) + self.type_printer = types.TypePrinter() + + def rewrite(self): + return self.rewriter.rewrite() + + def generic_visit(self, node): + if hasattr(node, 'type'): + self.rewriter.insert_after(node.loc, " : %s" % self.type_printer.name(node.type)) + + super().generic_visit(node) + +def main(): + import sys, fileinput + engine = diagnostic.Engine(all_errors_are_fatal=True) + try: + buf = source.Buffer("".join(fileinput.input()), fileinput.filename()) + parsed = parse_buffer(buf, engine=engine) + typed = Inferencer(engine=engine).visit(parsed) + printer = Printer(buf) + printer.visit(typed) + print(printer.rewrite().source) + except diagnostic.Error as e: + print("\n".join(e.diagnostic.render()), file=sys.stderr) + +if __name__ == "__main__": + main()