forked from M-Labs/artiq
Add inferencing for Tuple, List, For.
This commit is contained in:
parent
76ce364fea
commit
995d84d4ee
@ -22,32 +22,61 @@ class scoped(object):
|
|||||||
list of variables resolved as globals
|
list of variables resolved as globals
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class ClassDefT(ast.ClassDef, scoped):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class FunctionDefT(ast.FunctionDef, scoped):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class LambdaT(ast.Lambda, scoped):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class DictCompT(ast.DictComp, scoped):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class ListCompT(ast.ListComp, scoped):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class SetCompT(ast.SetComp, scoped):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class argT(ast.arg, commontyped):
|
class argT(ast.arg, commontyped):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class NumT(ast.Num, commontyped):
|
class ClassDefT(ast.ClassDef, scoped):
|
||||||
|
pass
|
||||||
|
class FunctionDefT(ast.FunctionDef, scoped):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class AttributeT(ast.Attribute, commontyped):
|
||||||
|
pass
|
||||||
|
class BinOpT(ast.BinOp, commontyped):
|
||||||
|
pass
|
||||||
|
class BoolOpT(ast.BoolOp, commontyped):
|
||||||
|
pass
|
||||||
|
class CallT(ast.Call, commontyped):
|
||||||
|
pass
|
||||||
|
class CompareT(ast.Compare, commontyped):
|
||||||
|
pass
|
||||||
|
class DictT(ast.Dict, commontyped):
|
||||||
|
pass
|
||||||
|
class DictCompT(ast.DictComp, commontyped, scoped):
|
||||||
|
pass
|
||||||
|
class EllipsisT(ast.Ellipsis, commontyped):
|
||||||
|
pass
|
||||||
|
class GeneratorExpT(ast.GeneratorExp, commontyped, scoped):
|
||||||
|
pass
|
||||||
|
class IfExpT(ast.IfExp, commontyped):
|
||||||
|
pass
|
||||||
|
class LambdaT(ast.Lambda, commontyped, scoped):
|
||||||
|
pass
|
||||||
|
class ListT(ast.List, commontyped):
|
||||||
|
pass
|
||||||
|
class ListCompT(ast.ListComp, commontyped, scoped):
|
||||||
|
pass
|
||||||
class NameT(ast.Name, commontyped):
|
class NameT(ast.Name, commontyped):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class NameConstantT(ast.NameConstant, commontyped):
|
class NameConstantT(ast.NameConstant, commontyped):
|
||||||
pass
|
pass
|
||||||
|
class NumT(ast.Num, commontyped):
|
||||||
|
pass
|
||||||
|
class SetT(ast.Set, commontyped):
|
||||||
|
pass
|
||||||
|
class SetCompT(ast.SetComp, commontyped, scoped):
|
||||||
|
pass
|
||||||
|
class StrT(ast.Str, commontyped):
|
||||||
|
pass
|
||||||
|
class StarredT(ast.Starred, commontyped):
|
||||||
|
pass
|
||||||
|
class SubscriptT(ast.Subscript, commontyped):
|
||||||
|
pass
|
||||||
|
class TupleT(ast.Tuple, commontyped):
|
||||||
|
pass
|
||||||
|
class UnaryOpT(ast.UnaryOp, commontyped):
|
||||||
|
pass
|
||||||
|
class YieldT(ast.Yield, commontyped):
|
||||||
|
pass
|
||||||
|
class YieldFromT(ast.YieldFrom, commontyped):
|
||||||
|
pass
|
||||||
|
@ -86,6 +86,9 @@ class TMono(Type):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "TMono(%s, %s)" % (repr(self.name), repr(self.params))
|
return "TMono(%s, %s)" % (repr(self.name), repr(self.params))
|
||||||
|
|
||||||
|
def __getitem__(self, param):
|
||||||
|
return self.params[param]
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return isinstance(other, TMono) and \
|
return isinstance(other, TMono) and \
|
||||||
self.name == other.name and \
|
self.name == other.name and \
|
||||||
@ -150,14 +153,22 @@ def TBool():
|
|||||||
"""A boolean type."""
|
"""A boolean type."""
|
||||||
return TMono("bool")
|
return TMono("bool")
|
||||||
|
|
||||||
def TInt(width=TVar()):
|
def TInt(width=None):
|
||||||
"""A generic integer type."""
|
"""A generic integer type."""
|
||||||
|
if width is None:
|
||||||
|
width = TVar()
|
||||||
return TMono("int", {"width": width})
|
return TMono("int", {"width": width})
|
||||||
|
|
||||||
def TFloat():
|
def TFloat():
|
||||||
"""A double-precision floating point type."""
|
"""A double-precision floating point type."""
|
||||||
return TMono("float")
|
return TMono("float")
|
||||||
|
|
||||||
|
def TList(elt=None):
|
||||||
|
"""A generic list type."""
|
||||||
|
if elt is None:
|
||||||
|
elt = TVar()
|
||||||
|
return TMono("list", {"elt": elt})
|
||||||
|
|
||||||
|
|
||||||
class TypePrinter(object):
|
class TypePrinter(object):
|
||||||
"""
|
"""
|
||||||
|
@ -184,6 +184,8 @@ class Inferencer(algorithm.Transformer):
|
|||||||
"name '{name}' is not bound to anything", {"name":name}, loc)
|
"name '{name}' is not bound to anything", {"name":name}, loc)
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
|
|
||||||
|
# Visitors that replace node with a typed node
|
||||||
|
#
|
||||||
def visit_arg(self, node):
|
def visit_arg(self, node):
|
||||||
return asttyped.argT(type=self._find_name(node.arg, node.loc),
|
return asttyped.argT(type=self._find_name(node.arg, node.loc),
|
||||||
arg=node.arg, annotation=self.visit(node.annotation),
|
arg=node.arg, annotation=self.visit(node.annotation),
|
||||||
@ -206,6 +208,22 @@ class Inferencer(algorithm.Transformer):
|
|||||||
return asttyped.NameT(type=self._find_name(node.id, node.loc),
|
return asttyped.NameT(type=self._find_name(node.id, node.loc),
|
||||||
id=node.id, ctx=node.ctx, loc=node.loc)
|
id=node.id, ctx=node.ctx, loc=node.loc)
|
||||||
|
|
||||||
|
def visit_Tuple(self, node):
|
||||||
|
node = self.generic_visit(node)
|
||||||
|
return asttyped.TupleT(type=types.TTuple([x.type for x in node.elts]),
|
||||||
|
elts=node.elts, ctx=node.ctx, loc=node.loc)
|
||||||
|
|
||||||
|
def visit_List(self, node):
|
||||||
|
node = self.generic_visit(node)
|
||||||
|
node = asttyped.ListT(type=types.TList(),
|
||||||
|
elts=node.elts, ctx=node.ctx, loc=node.loc)
|
||||||
|
for elt in node.elts:
|
||||||
|
self._unify(node.type['elt'], elt.type,
|
||||||
|
node.loc, elt.loc)
|
||||||
|
return node
|
||||||
|
|
||||||
|
# Visitors that just unify types
|
||||||
|
#
|
||||||
def visit_Assign(self, node):
|
def visit_Assign(self, node):
|
||||||
node = self.generic_visit(node)
|
node = self.generic_visit(node)
|
||||||
if len(node.targets) > 1:
|
if len(node.targets) > 1:
|
||||||
@ -222,6 +240,13 @@ class Inferencer(algorithm.Transformer):
|
|||||||
node.target.loc, node.value.loc)
|
node.target.loc, node.value.loc)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
def visit_For(self, node):
|
||||||
|
node = self.generic_visit(node)
|
||||||
|
# TODO: support more than just lists
|
||||||
|
self._unify(TList(node.target.type), node.iter.type,
|
||||||
|
node.target.loc, node.iter.loc)
|
||||||
|
return node
|
||||||
|
|
||||||
class Printer(algorithm.Visitor):
|
class Printer(algorithm.Visitor):
|
||||||
def __init__(self, buf):
|
def __init__(self, buf):
|
||||||
self.rewriter = source.Rewriter(buf)
|
self.rewriter = source.Rewriter(buf)
|
||||||
@ -232,7 +257,7 @@ class Printer(algorithm.Visitor):
|
|||||||
|
|
||||||
def generic_visit(self, node):
|
def generic_visit(self, node):
|
||||||
if hasattr(node, 'type'):
|
if hasattr(node, 'type'):
|
||||||
self.rewriter.insert_after(node.loc, " : %s" % self.type_printer.name(node.type))
|
self.rewriter.insert_after(node.loc, ":%s" % self.type_printer.name(node.type))
|
||||||
|
|
||||||
super().generic_visit(node)
|
super().generic_visit(node)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user