From 995d84d4eed083085105647aeadb746e3445cd75 Mon Sep 17 00:00:00 2001 From: whitequark Date: Thu, 4 Jun 2015 14:12:41 +0300 Subject: [PATCH] Add inferencing for Tuple, List, For. --- artiq/py2llvm/asttyped.py | 69 +++++++++++++++++++++++++++------------ artiq/py2llvm/types.py | 13 +++++++- artiq/py2llvm/typing.py | 27 ++++++++++++++- 3 files changed, 87 insertions(+), 22 deletions(-) diff --git a/artiq/py2llvm/asttyped.py b/artiq/py2llvm/asttyped.py index 21f3b07a1..24d341faa 100644 --- a/artiq/py2llvm/asttyped.py +++ b/artiq/py2llvm/asttyped.py @@ -22,32 +22,61 @@ class scoped(object): list of variables resolved as globals """ -class ClassDefT(ast.ClassDef, scoped): - pass - -class FunctionDefT(ast.FunctionDef, scoped): - pass - -class LambdaT(ast.Lambda, scoped): - pass - -class DictCompT(ast.DictComp, scoped): - pass - -class ListCompT(ast.ListComp, scoped): - pass - -class SetCompT(ast.SetComp, scoped): - pass - class argT(ast.arg, commontyped): pass -class NumT(ast.Num, commontyped): +class ClassDefT(ast.ClassDef, scoped): + pass +class FunctionDefT(ast.FunctionDef, scoped): pass +class AttributeT(ast.Attribute, commontyped): + pass +class BinOpT(ast.BinOp, commontyped): + pass +class BoolOpT(ast.BoolOp, commontyped): + pass +class CallT(ast.Call, commontyped): + pass +class CompareT(ast.Compare, commontyped): + pass +class DictT(ast.Dict, commontyped): + pass +class DictCompT(ast.DictComp, commontyped, scoped): + pass +class EllipsisT(ast.Ellipsis, commontyped): + pass +class GeneratorExpT(ast.GeneratorExp, commontyped, scoped): + pass +class IfExpT(ast.IfExp, commontyped): + pass +class LambdaT(ast.Lambda, commontyped, scoped): + pass +class ListT(ast.List, commontyped): + pass +class ListCompT(ast.ListComp, commontyped, scoped): + pass class NameT(ast.Name, commontyped): pass - class NameConstantT(ast.NameConstant, commontyped): pass +class NumT(ast.Num, commontyped): + pass +class SetT(ast.Set, commontyped): + pass +class SetCompT(ast.SetComp, commontyped, scoped): + pass +class StrT(ast.Str, commontyped): + pass +class StarredT(ast.Starred, commontyped): + pass +class SubscriptT(ast.Subscript, commontyped): + pass +class TupleT(ast.Tuple, commontyped): + pass +class UnaryOpT(ast.UnaryOp, commontyped): + pass +class YieldT(ast.Yield, commontyped): + pass +class YieldFromT(ast.YieldFrom, commontyped): + pass diff --git a/artiq/py2llvm/types.py b/artiq/py2llvm/types.py index 52cbe7518..ff17c8456 100644 --- a/artiq/py2llvm/types.py +++ b/artiq/py2llvm/types.py @@ -86,6 +86,9 @@ class TMono(Type): def __repr__(self): return "TMono(%s, %s)" % (repr(self.name), repr(self.params)) + def __getitem__(self, param): + return self.params[param] + def __eq__(self, other): return isinstance(other, TMono) and \ self.name == other.name and \ @@ -150,14 +153,22 @@ def TBool(): """A boolean type.""" return TMono("bool") -def TInt(width=TVar()): +def TInt(width=None): """A generic integer type.""" + if width is None: + width = TVar() return TMono("int", {"width": width}) def TFloat(): """A double-precision floating point type.""" return TMono("float") +def TList(elt=None): + """A generic list type.""" + if elt is None: + elt = TVar() + return TMono("list", {"elt": elt}) + class TypePrinter(object): """ diff --git a/artiq/py2llvm/typing.py b/artiq/py2llvm/typing.py index c205ec984..47c0473bb 100644 --- a/artiq/py2llvm/typing.py +++ b/artiq/py2llvm/typing.py @@ -184,6 +184,8 @@ class Inferencer(algorithm.Transformer): "name '{name}' is not bound to anything", {"name":name}, loc) self.engine.process(diag) + # Visitors that replace node with a typed node + # def visit_arg(self, node): return asttyped.argT(type=self._find_name(node.arg, node.loc), arg=node.arg, annotation=self.visit(node.annotation), @@ -206,6 +208,22 @@ class Inferencer(algorithm.Transformer): return asttyped.NameT(type=self._find_name(node.id, node.loc), id=node.id, ctx=node.ctx, loc=node.loc) + def visit_Tuple(self, node): + node = self.generic_visit(node) + return asttyped.TupleT(type=types.TTuple([x.type for x in node.elts]), + elts=node.elts, ctx=node.ctx, loc=node.loc) + + def visit_List(self, node): + node = self.generic_visit(node) + node = asttyped.ListT(type=types.TList(), + elts=node.elts, ctx=node.ctx, loc=node.loc) + for elt in node.elts: + self._unify(node.type['elt'], elt.type, + node.loc, elt.loc) + return node + + # Visitors that just unify types + # def visit_Assign(self, node): node = self.generic_visit(node) if len(node.targets) > 1: @@ -222,6 +240,13 @@ class Inferencer(algorithm.Transformer): node.target.loc, node.value.loc) return node + def visit_For(self, node): + node = self.generic_visit(node) + # TODO: support more than just lists + self._unify(TList(node.target.type), node.iter.type, + node.target.loc, node.iter.loc) + return node + class Printer(algorithm.Visitor): def __init__(self, buf): self.rewriter = source.Rewriter(buf) @@ -232,7 +257,7 @@ class Printer(algorithm.Visitor): def generic_visit(self, node): if hasattr(node, 'type'): - self.rewriter.insert_after(node.loc, " : %s" % self.type_printer.name(node.type)) + self.rewriter.insert_after(node.loc, ":%s" % self.type_printer.name(node.type)) super().generic_visit(node)