1
0
forked from M-Labs/artiq

Add support for BinOp.

This commit is contained in:
whitequark 2015-06-14 12:07:13 +03:00
parent faaf189961
commit 7b78e7de67
7 changed files with 325 additions and 43 deletions

View File

@ -22,6 +22,7 @@ class scoped(object):
list of variables resolved as globals
"""
# Typed versions of untyped nodes
class argT(ast.arg, commontyped):
pass
@ -82,3 +83,7 @@ class YieldT(ast.Yield, commontyped):
pass
class YieldFromT(ast.YieldFrom, commontyped):
pass
# Novel typed nodes
class CoerceT(ast.expr, commontyped):
_fields = ('expr',) # other_expr deliberately not in _fields

View File

@ -23,36 +23,6 @@ class TFloat(types.TMono):
def __init__(self):
super().__init__("float")
class TTuple(types.Type):
"""A tuple type."""
attributes = {}
def __init__(self, elts=[]):
self.elts = elts
def find(self):
return self
def unify(self, other):
if isinstance(other, TTuple) and len(self.elts) == len(other.elts):
for selfelt, otherelt in zip(self.elts, other.elts):
selfelt.unify(otherelt)
elif isinstance(other, TVar):
other.unify(self)
else:
raise UnificationError(self, other)
def __repr__(self):
return "TTuple(%s)" % (", ".join(map(repr, self.elts)))
def __eq__(self, other):
return isinstance(other, TTuple) and \
self.elts == other.elts
def __ne__(self, other):
return not (self == other)
class TList(types.TMono):
def __init__(self, elt=None):
if elt is None:
@ -60,12 +30,37 @@ class TList(types.TMono):
super().__init__("list", {"elt": elt})
def is_none(typ):
return types.is_mono(typ, "NoneType")
def is_bool(typ):
return types.is_mono(typ, "bool")
def is_int(typ, width=None):
if width:
return types.is_mono(typ, "int", {"width": width})
else:
return types.is_mono(typ, "int")
def get_int_width(typ):
if is_int(typ):
return types.get_value(typ["width"])
def is_float(typ):
return types.is_mono(typ, "float")
def is_numeric(typ):
typ = typ.find()
return isinstance(typ, types.TMono) and \
typ.name in ('int', 'float')
def is_list(typ, elt=None):
if elt:
return types.is_mono(typ, "list", {"elt": elt})
else:
return types.is_mono(typ, "list")
def is_collection(typ):
typ = typ.find()
return isinstance(typ, types.TTuple) or \
types.is_mono(typ, "list")

View File

@ -101,6 +101,36 @@ class TMono(Type):
def __ne__(self, other):
return not (self == other)
class TTuple(Type):
"""A tuple type."""
attributes = {}
def __init__(self, elts=[]):
self.elts = elts
def find(self):
return self
def unify(self, other):
if isinstance(other, TTuple) and len(self.elts) == len(other.elts):
for selfelt, otherelt in zip(self.elts, other.elts):
selfelt.unify(otherelt)
elif isinstance(other, TVar):
other.unify(self)
else:
raise UnificationError(self, other)
def __repr__(self):
return "TTuple(%s)" % (", ".join(map(repr, self.elts)))
def __eq__(self, other):
return isinstance(other, TTuple) and \
self.elts == other.elts
def __ne__(self, other):
return not (self == other)
class TValue(Type):
"""
A type-level value (such as the integer denoting width of
@ -131,15 +161,32 @@ class TValue(Type):
def is_var(typ):
return isinstance(typ, TVar)
return isinstance(typ.find(), TVar)
def is_mono(typ, name, **params):
typ = typ.find()
params_match = True
for param in params:
params_match = params_match and typ.params[param] == params[param]
return isinstance(typ, TMono) and \
typ.name == name and params_match
def is_tuple(typ, elts=None):
typ = typ.find()
if elts:
return isinstance(typ, TTuple) and \
elts == typ.elts
else:
return isinstance(typ, TTuple)
def get_value(typ):
typ = typ.find()
if isinstance(typ, TVar):
return None
elif isinstance(typ, TValue):
return typ.value
else:
assert False
class TypePrinter(object):
"""

View File

@ -215,7 +215,7 @@ class ASTTypedRewriter(algorithm.Transformer):
def visit_Tuple(self, node):
node = self.generic_visit(node)
return asttyped.TupleT(type=builtins.TTuple([x.type for x in node.elts]),
return asttyped.TupleT(type=types.TTuple([x.type for x in node.elts]),
elts=node.elts, ctx=node.ctx, loc=node.loc)
def visit_List(self, node):
@ -282,7 +282,6 @@ class ASTTypedRewriter(algorithm.Transformer):
self.engine.process(diag)
# expr
visit_BinOp = visit_unsupported
visit_Call = visit_unsupported
visit_Compare = visit_unsupported
visit_Dict = visit_unsupported
@ -375,16 +374,18 @@ class Inferencer(algorithm.Visitor):
return makenotes
def visit_ListT(self, node):
self.generic_visit(node)
for elt in node.elts:
self._unify(node.type["elt"], elt.type,
node.loc, elt.loc, self._makenotes_elts(node.elts, "a list element"))
def visit_AttributeT(self, node):
self.generic_visit(node)
object_type = node.value.type.find()
if not types.is_var(object_type):
if node.attr in object_type.attributes:
# assumes no free type variables in .attributes
node.type = object_type.attributes[node.attr]
node.type.unify(object_type.attributes[node.attr]) # should never fail
else:
diag = diagnostic.Diagnostic("error",
"type {type} does not have an attribute '{attr}'",
@ -393,48 +394,221 @@ class Inferencer(algorithm.Visitor):
self.engine.process(diag)
def visit_SubscriptT(self, node):
self.generic_visit(node)
# TODO: support more than just lists
self._unify(builtins.TList(node.type), node.value.type,
node.loc, node.value.loc)
def visit_IfExpT(self, node):
self.generic_visit(node)
self._unify(node.body.type, node.orelse.type,
node.body.loc, node.orelse.loc)
node.type = node.body.type
node.type.unify(node.body.type) # should never fail
def visit_BoolOpT(self, node):
self.generic_visit(node)
for value in node.values:
self._unify(node.type, value.type,
node.loc, value.loc, self._makenotes_elts(node.values, "an operand"))
def visit_UnaryOpT(self, node):
self.generic_visit(node)
operand_type = node.operand.type.find()
if isinstance(node.op, ast.Not):
node.type = builtins.TBool()
node.type.unify(builtins.TBool()) # should never fail
elif isinstance(node.op, ast.Invert):
if builtins.is_int(operand_type):
node.type = operand_type
node.type.unify(operand_type) # should never fail
elif not types.is_var(operand_type):
diag = diagnostic.Diagnostic("error",
"expected ~ operand to be of integer type, not {type}",
"expected '~' operand to be of integer type, not {type}",
{"type": types.TypePrinter().name(operand_type)},
node.operand.loc)
self.engine.process(diag)
else: # UAdd, USub
if builtins.is_numeric(operand_type):
node.type = operand_type
node.type.unify(operand_type) # should never fail
elif not types.is_var(operand_type):
diag = diagnostic.Diagnostic("error",
"expected unary {op} operand to be of numeric type, not {type}",
"expected unary '{op}' operand to be of numeric type, not {type}",
{"op": node.op.loc.source(),
"type": types.TypePrinter().name(operand_type)},
node.operand.loc)
self.engine.process(diag)
def visit_CoerceT(self, node):
self.generic_visit(node)
if builtins.is_numeric(node.type) and builtins.is_numeric(node.expr.type):
pass
else:
printer = types.TypePrinter()
note = diagnostic.Diagnostic("note",
"expression that required coercion to {typeb}",
{"typeb": printer.name(node.type)},
node.other_expr.loc)
diag = diagnostic.Diagnostic("error",
"cannot coerce {typea} to {typeb}",
{"typea": printer.name(node.expr.type), "typeb": printer.name(node.type)},
node.loc, notes=[note])
self.engine.process(diag)
def _coerce_one(self, typ, coerced_node, other_node):
if coerced_node.type.find() == typ.find():
return coerced_node
else:
node = asttyped.CoerceT(type=typ, expr=coerced_node, other_expr=other_node,
loc=coerced_node.loc)
self.visit(node)
return node
def _coerce_numeric(self, return_type, left, right):
# Implements the coercion protocol.
# See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.
if builtins.is_float(left.type) or builtins.is_float(right.type):
typ = builtins.TFloat()
elif builtins.is_int(left.type) or builtins.is_int(right.type):
left_width, right_width = \
builtins.get_int_width(left.type), builtins.get_int_width(left.type)
if left_width and right_width:
typ = builtins.TInt(types.TValue(max(left_width, right_width)))
else:
typ = builtins.TInt()
elif types.is_var(left.type) or types.is_var(right.type): # not enough info yet
return left, right
else: # conflicting types
printer = types.TypePrinter()
note1 = diagnostic.Diagnostic("note",
"expression of type {typea}", {"typea": printer.name(left.type)},
left.loc)
note2 = diagnostic.Diagnostic("note",
"expression of type {typeb}", {"typeb": printer.name(right.type)},
right.loc)
diag = diagnostic.Diagnostic("error",
"cannot coerce {typea} and {typeb} to a common numeric type",
{"typea": printer.name(left.type), "typeb": printer.name(right.type)},
left.loc, [right.loc],
[note1, note2])
self.engine.process(diag)
return left, right
# On 1st invocation, return_type is always a type variable.
# On further invocations, coerce will only ever refine the type,
# so this should never fail.
return_type.unify(typ)
return self._coerce_one(typ, left, other_node=right), \
self._coerce_one(typ, right, other_node=left)
def _order_by_pred(self, pred, left, right):
if pred(left.type):
return left, right
elif pred(right.type):
return right, left
else:
assert False
def visit_BinOpT(self, node):
self.generic_visit(node)
if isinstance(node.op, (ast.BitAnd, ast.BitOr, ast.BitXor,
ast.LShift, ast.RShift)):
# bitwise operators require integers
for operand in (node.left, node.right):
if not types.is_var(operand.type) and not builtins.is_int(operand.type):
diag = diagnostic.Diagnostic("error",
"expected '{op}' operand to be of integer type, not {type}",
{"op": node.op.loc.source(),
"type": types.TypePrinter().name(operand.type)},
node.op.loc, [operand.loc])
self.engine.process(diag)
return
node.left, node.right = \
self._coerce_numeric(node.type, node.left, node.right)
elif isinstance(node.op, ast.Add):
# add works on numbers and also collections
if builtins.is_collection(node.left.type) or builtins.is_collection(node.right.type):
collection, other = \
self._order_by_pred(builtins.is_collection, node.left, node.right)
if types.is_tuple(collection.type):
pred, kind = types.is_tuple, "tuple"
elif builtins.is_list(collection.type):
pred, kind = builtins.is_list, "list"
else:
assert False
if not pred(other.type):
printer = types.TypePrinter()
note1 = diagnostic.Diagnostic("note",
"{kind} of type {typea}",
{"typea": printer.name(collection.type), "kind": kind},
collection.loc)
note2 = diagnostic.Diagnostic("note",
"{typeb}, which cannot be added to a {kind}",
{"typeb": printer.name(other.type), "kind": kind},
other.loc)
diag = diagnostic.Diagnostic("error",
"expected every '+' operand to be a {kind} in this context",
{"kind": kind},
node.op.loc, [other.loc, collection.loc],
[note1, note2])
self.engine.process(diag)
return
if types.is_tuple(collection.type):
# should never fail
node.type.unify(types.TTuple(node.left.type.find().elts +
node.right.type.find().elts))
elif builtins.is_list(collection.type):
self._unify(node.left.type, node.right.type,
node.left.loc, node.right.loc)
node.type.unify(node.left.type) # should never fail
else:
node.left, node.right = \
self._coerce_numeric(node.type, node.left, node.right)
elif isinstance(node.op, ast.Mult):
# mult works on numbers and also number & collection
if types.is_tuple(node.left.type) or types.is_tuple(node.right.type):
tuple_, other = self._order_by_pred(types.is_tuple, node.left, node.right)
diag = diagnostic.Diagnostic("error",
"py2llvm does not support passing tuples to '*'", {},
node.op.loc, [tuple_.loc])
self.engine.process(diag)
elif builtins.is_list(node.left.type) or builtins.is_list(node.right.type):
list_, other = self._order_by_pred(builtins.is_list, node.left, node.right)
if not builtins.is_int(other.type):
printer = types.TypePrinter()
note1 = diagnostic.Diagnostic("note",
"list operand of type {typea}",
{"typea": printer.name(list_.type)},
list_.loc)
note2 = diagnostic.Diagnostic("note",
"operand of type {typeb}, which is not a valid repetition amount",
{"typeb": printer.name(other.type)},
other.loc)
diag = diagnostic.Diagnostic("error",
"expected '*' operands to be a list and an integer in this context", {},
node.op.loc, [list_.loc, other.loc],
[note1, note2])
self.engine.process(diag)
return
node.type.unify(list_.type)
else:
node.left, node.right = \
self._coerce_numeric(node.type, node.left, node.right)
elif isinstance(node.op, (ast.Div, ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)):
# numeric operators work on any kind of number
node.left, node.right = \
self._coerce_numeric(node.type, node.left, node.right)
else: # MatMult
diag = diagnostic.Diagnostic("error",
"operator '{op}' is not supported", {"op": node.op.loc.source()},
node.op.loc)
self.engine.process(diag)
return
def visit_Assign(self, node):
self.generic_visit(node)
if len(node.targets) > 1:
self._unify(builtins.TTuple([x.type for x in node.targets]), node.value.type,
self._unify(types.TTuple([x.type for x in node.targets]), node.value.type,
node.targets[0].loc.join(node.targets[-1].loc), node.value.loc)
else:
self._unify(node.targets[0].type, node.value.type,

View File

@ -0,0 +1,26 @@
# RUN: %python -m artiq.py2llvm.typing %s >%t
# RUN: OutputCheck %s --file-to-check=%t
1 | 2
# CHECK-L: 1:int(width='a):int(width='b) | 2:int(width='c):int(width='b):int(width='b)
1 + 2
# CHECK-L: 1:int(width='d):int(width='e) + 2:int(width='f):int(width='e):int(width='e)
(1,) + (2.0,)
# CHECK-L: (1:int(width='g),):(int(width='g),) + (2.0:float,):(float,):(int(width='g), float)
[1] + [2]
# CHECK-L: [1:int(width='h)]:list(elt=int(width='h)) + [2:int(width='h)]:list(elt=int(width='h)):list(elt=int(width='h))
1 * 2
# CHECK-L: 1:int(width='i):int(width='j) * 2:int(width='k):int(width='j):int(width='j)
[1] * 2
# CHECK-L: [1:int(width='l)]:list(elt=int(width='l)) * 2:int(width='m):list(elt=int(width='l))
1 / 2
# CHECK-L: 1:int(width='n):int(width='o) / 2:int(width='p):int(width='o):int(width='o)
1 + 1.0
# CHECK-L: 1:int(width='q):float + 1.0:float:float

View File

@ -0,0 +1,35 @@
# RUN: %python -m artiq.py2llvm.typing +diag %s >%t
# RUN: OutputCheck %s --file-to-check=%t
# CHECK-L: ${LINE:+1}: error: expected '<<' operand to be of integer type, not float
1 << 2.0
# CHECK-L: ${LINE:+3}: error: expected every '+' operand to be a list in this context
# CHECK-L: ${LINE:+2}: note: list of type list(elt=int(width='a))
# CHECK-L: ${LINE:+1}: note: int(width='b), which cannot be added to a list
[1] + 2
# CHECK-L: ${LINE:+1}: error: cannot unify list(elt=int(width='a)) with list(elt=float): int(width='a) is incompatible with float
[1] + [2.0]
# CHECK-L: ${LINE:+3}: error: expected every '+' operand to be a tuple in this context
# CHECK-L: ${LINE:+2}: note: tuple of type (int(width='a),)
# CHECK-L: ${LINE:+1}: note: int(width='b), which cannot be added to a tuple
(1,) + 2
# CHECK-L: ${LINE:+1}: error: py2llvm does not support passing tuples to '*'
(1,) * 2
# CHECK-L: ${LINE:+3}: error: expected '*' operands to be a list and an integer in this context
# CHECK-L: ${LINE:+2}: note: list operand of type list(elt=int(width='a))
# CHECK-L: ${LINE:+1}: note: operand of type list(elt='b), which is not a valid repetition amount
[1] * []
# CHECK-L: ${LINE:+3}: error: cannot coerce list(elt='a) and NoneType to a common numeric type
# CHECK-L: ${LINE:+2}: note: expression of type list(elt='a)
# CHECK-L: ${LINE:+1}: note: expression of type NoneType
[] - None
# CHECK-L: ${LINE:+2}: error: cannot coerce list(elt='a) to float
# CHECK-L: ${LINE:+1}: note: expression that required coercion to float
[] - 1.0

View File

@ -17,10 +17,10 @@ a = b
# CHECK-L: note: an operand of type int(width='a)
# CHECK-L: note: an operand of type bool
# CHECK-L: ${LINE:+1}: error: expected unary + operand to be of numeric type, not list(elt='a)
# CHECK-L: ${LINE:+1}: error: expected unary '+' operand to be of numeric type, not list(elt='a)
+[]
# CHECK-L: ${LINE:+1}: error: expected ~ operand to be of integer type, not float
# CHECK-L: ${LINE:+1}: error: expected '~' operand to be of integer type, not float
~1.0
# CHECK-L: ${LINE:+1}: error: type int(width='a) does not have an attribute 'x'