mirror of
https://github.com/m-labs/artiq.git
synced 2024-12-25 11:18:27 +08:00
Add support for BinOp.
This commit is contained in:
parent
faaf189961
commit
7b78e7de67
@ -22,6 +22,7 @@ class scoped(object):
|
||||
list of variables resolved as globals
|
||||
"""
|
||||
|
||||
# Typed versions of untyped nodes
|
||||
class argT(ast.arg, commontyped):
|
||||
pass
|
||||
|
||||
@ -82,3 +83,7 @@ class YieldT(ast.Yield, commontyped):
|
||||
pass
|
||||
class YieldFromT(ast.YieldFrom, commontyped):
|
||||
pass
|
||||
|
||||
# Novel typed nodes
|
||||
class CoerceT(ast.expr, commontyped):
|
||||
_fields = ('expr',) # other_expr deliberately not in _fields
|
||||
|
@ -23,36 +23,6 @@ class TFloat(types.TMono):
|
||||
def __init__(self):
|
||||
super().__init__("float")
|
||||
|
||||
class TTuple(types.Type):
|
||||
"""A tuple type."""
|
||||
|
||||
attributes = {}
|
||||
|
||||
def __init__(self, elts=[]):
|
||||
self.elts = elts
|
||||
|
||||
def find(self):
|
||||
return self
|
||||
|
||||
def unify(self, other):
|
||||
if isinstance(other, TTuple) and len(self.elts) == len(other.elts):
|
||||
for selfelt, otherelt in zip(self.elts, other.elts):
|
||||
selfelt.unify(otherelt)
|
||||
elif isinstance(other, TVar):
|
||||
other.unify(self)
|
||||
else:
|
||||
raise UnificationError(self, other)
|
||||
|
||||
def __repr__(self):
|
||||
return "TTuple(%s)" % (", ".join(map(repr, self.elts)))
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TTuple) and \
|
||||
self.elts == other.elts
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
class TList(types.TMono):
|
||||
def __init__(self, elt=None):
|
||||
if elt is None:
|
||||
@ -60,12 +30,37 @@ class TList(types.TMono):
|
||||
super().__init__("list", {"elt": elt})
|
||||
|
||||
|
||||
def is_none(typ):
|
||||
return types.is_mono(typ, "NoneType")
|
||||
|
||||
def is_bool(typ):
|
||||
return types.is_mono(typ, "bool")
|
||||
|
||||
def is_int(typ, width=None):
|
||||
if width:
|
||||
return types.is_mono(typ, "int", {"width": width})
|
||||
else:
|
||||
return types.is_mono(typ, "int")
|
||||
|
||||
def get_int_width(typ):
|
||||
if is_int(typ):
|
||||
return types.get_value(typ["width"])
|
||||
|
||||
def is_float(typ):
|
||||
return types.is_mono(typ, "float")
|
||||
|
||||
def is_numeric(typ):
|
||||
typ = typ.find()
|
||||
return isinstance(typ, types.TMono) and \
|
||||
typ.name in ('int', 'float')
|
||||
|
||||
def is_list(typ, elt=None):
|
||||
if elt:
|
||||
return types.is_mono(typ, "list", {"elt": elt})
|
||||
else:
|
||||
return types.is_mono(typ, "list")
|
||||
|
||||
def is_collection(typ):
|
||||
typ = typ.find()
|
||||
return isinstance(typ, types.TTuple) or \
|
||||
types.is_mono(typ, "list")
|
||||
|
@ -101,6 +101,36 @@ class TMono(Type):
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
class TTuple(Type):
|
||||
"""A tuple type."""
|
||||
|
||||
attributes = {}
|
||||
|
||||
def __init__(self, elts=[]):
|
||||
self.elts = elts
|
||||
|
||||
def find(self):
|
||||
return self
|
||||
|
||||
def unify(self, other):
|
||||
if isinstance(other, TTuple) and len(self.elts) == len(other.elts):
|
||||
for selfelt, otherelt in zip(self.elts, other.elts):
|
||||
selfelt.unify(otherelt)
|
||||
elif isinstance(other, TVar):
|
||||
other.unify(self)
|
||||
else:
|
||||
raise UnificationError(self, other)
|
||||
|
||||
def __repr__(self):
|
||||
return "TTuple(%s)" % (", ".join(map(repr, self.elts)))
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TTuple) and \
|
||||
self.elts == other.elts
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
class TValue(Type):
|
||||
"""
|
||||
A type-level value (such as the integer denoting width of
|
||||
@ -131,15 +161,32 @@ class TValue(Type):
|
||||
|
||||
|
||||
def is_var(typ):
|
||||
return isinstance(typ, TVar)
|
||||
return isinstance(typ.find(), TVar)
|
||||
|
||||
def is_mono(typ, name, **params):
|
||||
typ = typ.find()
|
||||
params_match = True
|
||||
for param in params:
|
||||
params_match = params_match and typ.params[param] == params[param]
|
||||
return isinstance(typ, TMono) and \
|
||||
typ.name == name and params_match
|
||||
|
||||
def is_tuple(typ, elts=None):
|
||||
typ = typ.find()
|
||||
if elts:
|
||||
return isinstance(typ, TTuple) and \
|
||||
elts == typ.elts
|
||||
else:
|
||||
return isinstance(typ, TTuple)
|
||||
|
||||
def get_value(typ):
|
||||
typ = typ.find()
|
||||
if isinstance(typ, TVar):
|
||||
return None
|
||||
elif isinstance(typ, TValue):
|
||||
return typ.value
|
||||
else:
|
||||
assert False
|
||||
|
||||
class TypePrinter(object):
|
||||
"""
|
||||
|
@ -215,7 +215,7 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||
|
||||
def visit_Tuple(self, node):
|
||||
node = self.generic_visit(node)
|
||||
return asttyped.TupleT(type=builtins.TTuple([x.type for x in node.elts]),
|
||||
return asttyped.TupleT(type=types.TTuple([x.type for x in node.elts]),
|
||||
elts=node.elts, ctx=node.ctx, loc=node.loc)
|
||||
|
||||
def visit_List(self, node):
|
||||
@ -282,7 +282,6 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||
self.engine.process(diag)
|
||||
|
||||
# expr
|
||||
visit_BinOp = visit_unsupported
|
||||
visit_Call = visit_unsupported
|
||||
visit_Compare = visit_unsupported
|
||||
visit_Dict = visit_unsupported
|
||||
@ -375,16 +374,18 @@ class Inferencer(algorithm.Visitor):
|
||||
return makenotes
|
||||
|
||||
def visit_ListT(self, node):
|
||||
self.generic_visit(node)
|
||||
for elt in node.elts:
|
||||
self._unify(node.type["elt"], elt.type,
|
||||
node.loc, elt.loc, self._makenotes_elts(node.elts, "a list element"))
|
||||
|
||||
def visit_AttributeT(self, node):
|
||||
self.generic_visit(node)
|
||||
object_type = node.value.type.find()
|
||||
if not types.is_var(object_type):
|
||||
if node.attr in object_type.attributes:
|
||||
# assumes no free type variables in .attributes
|
||||
node.type = object_type.attributes[node.attr]
|
||||
node.type.unify(object_type.attributes[node.attr]) # should never fail
|
||||
else:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"type {type} does not have an attribute '{attr}'",
|
||||
@ -393,48 +394,221 @@ class Inferencer(algorithm.Visitor):
|
||||
self.engine.process(diag)
|
||||
|
||||
def visit_SubscriptT(self, node):
|
||||
self.generic_visit(node)
|
||||
# TODO: support more than just lists
|
||||
self._unify(builtins.TList(node.type), node.value.type,
|
||||
node.loc, node.value.loc)
|
||||
|
||||
def visit_IfExpT(self, node):
|
||||
self.generic_visit(node)
|
||||
self._unify(node.body.type, node.orelse.type,
|
||||
node.body.loc, node.orelse.loc)
|
||||
node.type = node.body.type
|
||||
node.type.unify(node.body.type) # should never fail
|
||||
|
||||
def visit_BoolOpT(self, node):
|
||||
self.generic_visit(node)
|
||||
for value in node.values:
|
||||
self._unify(node.type, value.type,
|
||||
node.loc, value.loc, self._makenotes_elts(node.values, "an operand"))
|
||||
|
||||
def visit_UnaryOpT(self, node):
|
||||
self.generic_visit(node)
|
||||
operand_type = node.operand.type.find()
|
||||
if isinstance(node.op, ast.Not):
|
||||
node.type = builtins.TBool()
|
||||
node.type.unify(builtins.TBool()) # should never fail
|
||||
elif isinstance(node.op, ast.Invert):
|
||||
if builtins.is_int(operand_type):
|
||||
node.type = operand_type
|
||||
node.type.unify(operand_type) # should never fail
|
||||
elif not types.is_var(operand_type):
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"expected ~ operand to be of integer type, not {type}",
|
||||
"expected '~' operand to be of integer type, not {type}",
|
||||
{"type": types.TypePrinter().name(operand_type)},
|
||||
node.operand.loc)
|
||||
self.engine.process(diag)
|
||||
else: # UAdd, USub
|
||||
if builtins.is_numeric(operand_type):
|
||||
node.type = operand_type
|
||||
node.type.unify(operand_type) # should never fail
|
||||
elif not types.is_var(operand_type):
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"expected unary {op} operand to be of numeric type, not {type}",
|
||||
"expected unary '{op}' operand to be of numeric type, not {type}",
|
||||
{"op": node.op.loc.source(),
|
||||
"type": types.TypePrinter().name(operand_type)},
|
||||
node.operand.loc)
|
||||
self.engine.process(diag)
|
||||
|
||||
def visit_CoerceT(self, node):
|
||||
self.generic_visit(node)
|
||||
if builtins.is_numeric(node.type) and builtins.is_numeric(node.expr.type):
|
||||
pass
|
||||
else:
|
||||
printer = types.TypePrinter()
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"expression that required coercion to {typeb}",
|
||||
{"typeb": printer.name(node.type)},
|
||||
node.other_expr.loc)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"cannot coerce {typea} to {typeb}",
|
||||
{"typea": printer.name(node.expr.type), "typeb": printer.name(node.type)},
|
||||
node.loc, notes=[note])
|
||||
self.engine.process(diag)
|
||||
|
||||
def _coerce_one(self, typ, coerced_node, other_node):
|
||||
if coerced_node.type.find() == typ.find():
|
||||
return coerced_node
|
||||
else:
|
||||
node = asttyped.CoerceT(type=typ, expr=coerced_node, other_expr=other_node,
|
||||
loc=coerced_node.loc)
|
||||
self.visit(node)
|
||||
return node
|
||||
|
||||
def _coerce_numeric(self, return_type, left, right):
|
||||
# Implements the coercion protocol.
|
||||
# See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.
|
||||
if builtins.is_float(left.type) or builtins.is_float(right.type):
|
||||
typ = builtins.TFloat()
|
||||
elif builtins.is_int(left.type) or builtins.is_int(right.type):
|
||||
left_width, right_width = \
|
||||
builtins.get_int_width(left.type), builtins.get_int_width(left.type)
|
||||
if left_width and right_width:
|
||||
typ = builtins.TInt(types.TValue(max(left_width, right_width)))
|
||||
else:
|
||||
typ = builtins.TInt()
|
||||
elif types.is_var(left.type) or types.is_var(right.type): # not enough info yet
|
||||
return left, right
|
||||
else: # conflicting types
|
||||
printer = types.TypePrinter()
|
||||
note1 = diagnostic.Diagnostic("note",
|
||||
"expression of type {typea}", {"typea": printer.name(left.type)},
|
||||
left.loc)
|
||||
note2 = diagnostic.Diagnostic("note",
|
||||
"expression of type {typeb}", {"typeb": printer.name(right.type)},
|
||||
right.loc)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"cannot coerce {typea} and {typeb} to a common numeric type",
|
||||
{"typea": printer.name(left.type), "typeb": printer.name(right.type)},
|
||||
left.loc, [right.loc],
|
||||
[note1, note2])
|
||||
self.engine.process(diag)
|
||||
return left, right
|
||||
|
||||
# On 1st invocation, return_type is always a type variable.
|
||||
# On further invocations, coerce will only ever refine the type,
|
||||
# so this should never fail.
|
||||
return_type.unify(typ)
|
||||
|
||||
return self._coerce_one(typ, left, other_node=right), \
|
||||
self._coerce_one(typ, right, other_node=left)
|
||||
|
||||
def _order_by_pred(self, pred, left, right):
|
||||
if pred(left.type):
|
||||
return left, right
|
||||
elif pred(right.type):
|
||||
return right, left
|
||||
else:
|
||||
assert False
|
||||
|
||||
def visit_BinOpT(self, node):
|
||||
self.generic_visit(node)
|
||||
if isinstance(node.op, (ast.BitAnd, ast.BitOr, ast.BitXor,
|
||||
ast.LShift, ast.RShift)):
|
||||
# bitwise operators require integers
|
||||
for operand in (node.left, node.right):
|
||||
if not types.is_var(operand.type) and not builtins.is_int(operand.type):
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"expected '{op}' operand to be of integer type, not {type}",
|
||||
{"op": node.op.loc.source(),
|
||||
"type": types.TypePrinter().name(operand.type)},
|
||||
node.op.loc, [operand.loc])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
node.left, node.right = \
|
||||
self._coerce_numeric(node.type, node.left, node.right)
|
||||
elif isinstance(node.op, ast.Add):
|
||||
# add works on numbers and also collections
|
||||
if builtins.is_collection(node.left.type) or builtins.is_collection(node.right.type):
|
||||
collection, other = \
|
||||
self._order_by_pred(builtins.is_collection, node.left, node.right)
|
||||
if types.is_tuple(collection.type):
|
||||
pred, kind = types.is_tuple, "tuple"
|
||||
elif builtins.is_list(collection.type):
|
||||
pred, kind = builtins.is_list, "list"
|
||||
else:
|
||||
assert False
|
||||
if not pred(other.type):
|
||||
printer = types.TypePrinter()
|
||||
note1 = diagnostic.Diagnostic("note",
|
||||
"{kind} of type {typea}",
|
||||
{"typea": printer.name(collection.type), "kind": kind},
|
||||
collection.loc)
|
||||
note2 = diagnostic.Diagnostic("note",
|
||||
"{typeb}, which cannot be added to a {kind}",
|
||||
{"typeb": printer.name(other.type), "kind": kind},
|
||||
other.loc)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"expected every '+' operand to be a {kind} in this context",
|
||||
{"kind": kind},
|
||||
node.op.loc, [other.loc, collection.loc],
|
||||
[note1, note2])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
if types.is_tuple(collection.type):
|
||||
# should never fail
|
||||
node.type.unify(types.TTuple(node.left.type.find().elts +
|
||||
node.right.type.find().elts))
|
||||
elif builtins.is_list(collection.type):
|
||||
self._unify(node.left.type, node.right.type,
|
||||
node.left.loc, node.right.loc)
|
||||
node.type.unify(node.left.type) # should never fail
|
||||
else:
|
||||
node.left, node.right = \
|
||||
self._coerce_numeric(node.type, node.left, node.right)
|
||||
elif isinstance(node.op, ast.Mult):
|
||||
# mult works on numbers and also number & collection
|
||||
if types.is_tuple(node.left.type) or types.is_tuple(node.right.type):
|
||||
tuple_, other = self._order_by_pred(types.is_tuple, node.left, node.right)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"py2llvm does not support passing tuples to '*'", {},
|
||||
node.op.loc, [tuple_.loc])
|
||||
self.engine.process(diag)
|
||||
elif builtins.is_list(node.left.type) or builtins.is_list(node.right.type):
|
||||
list_, other = self._order_by_pred(builtins.is_list, node.left, node.right)
|
||||
if not builtins.is_int(other.type):
|
||||
printer = types.TypePrinter()
|
||||
note1 = diagnostic.Diagnostic("note",
|
||||
"list operand of type {typea}",
|
||||
{"typea": printer.name(list_.type)},
|
||||
list_.loc)
|
||||
note2 = diagnostic.Diagnostic("note",
|
||||
"operand of type {typeb}, which is not a valid repetition amount",
|
||||
{"typeb": printer.name(other.type)},
|
||||
other.loc)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"expected '*' operands to be a list and an integer in this context", {},
|
||||
node.op.loc, [list_.loc, other.loc],
|
||||
[note1, note2])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
node.type.unify(list_.type)
|
||||
else:
|
||||
node.left, node.right = \
|
||||
self._coerce_numeric(node.type, node.left, node.right)
|
||||
elif isinstance(node.op, (ast.Div, ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)):
|
||||
# numeric operators work on any kind of number
|
||||
node.left, node.right = \
|
||||
self._coerce_numeric(node.type, node.left, node.right)
|
||||
else: # MatMult
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"operator '{op}' is not supported", {"op": node.op.loc.source()},
|
||||
node.op.loc)
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
def visit_Assign(self, node):
|
||||
self.generic_visit(node)
|
||||
if len(node.targets) > 1:
|
||||
self._unify(builtins.TTuple([x.type for x in node.targets]), node.value.type,
|
||||
self._unify(types.TTuple([x.type for x in node.targets]), node.value.type,
|
||||
node.targets[0].loc.join(node.targets[-1].loc), node.value.loc)
|
||||
else:
|
||||
self._unify(node.targets[0].type, node.value.type,
|
||||
|
26
lit-test/py2llvm/typing/coerce.py
Normal file
26
lit-test/py2llvm/typing/coerce.py
Normal file
@ -0,0 +1,26 @@
|
||||
# RUN: %python -m artiq.py2llvm.typing %s >%t
|
||||
# RUN: OutputCheck %s --file-to-check=%t
|
||||
|
||||
1 | 2
|
||||
# CHECK-L: 1:int(width='a):int(width='b) | 2:int(width='c):int(width='b):int(width='b)
|
||||
|
||||
1 + 2
|
||||
# CHECK-L: 1:int(width='d):int(width='e) + 2:int(width='f):int(width='e):int(width='e)
|
||||
|
||||
(1,) + (2.0,)
|
||||
# CHECK-L: (1:int(width='g),):(int(width='g),) + (2.0:float,):(float,):(int(width='g), float)
|
||||
|
||||
[1] + [2]
|
||||
# CHECK-L: [1:int(width='h)]:list(elt=int(width='h)) + [2:int(width='h)]:list(elt=int(width='h)):list(elt=int(width='h))
|
||||
|
||||
1 * 2
|
||||
# CHECK-L: 1:int(width='i):int(width='j) * 2:int(width='k):int(width='j):int(width='j)
|
||||
|
||||
[1] * 2
|
||||
# CHECK-L: [1:int(width='l)]:list(elt=int(width='l)) * 2:int(width='m):list(elt=int(width='l))
|
||||
|
||||
1 / 2
|
||||
# CHECK-L: 1:int(width='n):int(width='o) / 2:int(width='p):int(width='o):int(width='o)
|
||||
|
||||
1 + 1.0
|
||||
# CHECK-L: 1:int(width='q):float + 1.0:float:float
|
35
lit-test/py2llvm/typing/error_coerce.py
Normal file
35
lit-test/py2llvm/typing/error_coerce.py
Normal file
@ -0,0 +1,35 @@
|
||||
# RUN: %python -m artiq.py2llvm.typing +diag %s >%t
|
||||
# RUN: OutputCheck %s --file-to-check=%t
|
||||
|
||||
# CHECK-L: ${LINE:+1}: error: expected '<<' operand to be of integer type, not float
|
||||
1 << 2.0
|
||||
|
||||
# CHECK-L: ${LINE:+3}: error: expected every '+' operand to be a list in this context
|
||||
# CHECK-L: ${LINE:+2}: note: list of type list(elt=int(width='a))
|
||||
# CHECK-L: ${LINE:+1}: note: int(width='b), which cannot be added to a list
|
||||
[1] + 2
|
||||
|
||||
# CHECK-L: ${LINE:+1}: error: cannot unify list(elt=int(width='a)) with list(elt=float): int(width='a) is incompatible with float
|
||||
[1] + [2.0]
|
||||
|
||||
# CHECK-L: ${LINE:+3}: error: expected every '+' operand to be a tuple in this context
|
||||
# CHECK-L: ${LINE:+2}: note: tuple of type (int(width='a),)
|
||||
# CHECK-L: ${LINE:+1}: note: int(width='b), which cannot be added to a tuple
|
||||
(1,) + 2
|
||||
|
||||
# CHECK-L: ${LINE:+1}: error: py2llvm does not support passing tuples to '*'
|
||||
(1,) * 2
|
||||
|
||||
# CHECK-L: ${LINE:+3}: error: expected '*' operands to be a list and an integer in this context
|
||||
# CHECK-L: ${LINE:+2}: note: list operand of type list(elt=int(width='a))
|
||||
# CHECK-L: ${LINE:+1}: note: operand of type list(elt='b), which is not a valid repetition amount
|
||||
[1] * []
|
||||
|
||||
# CHECK-L: ${LINE:+3}: error: cannot coerce list(elt='a) and NoneType to a common numeric type
|
||||
# CHECK-L: ${LINE:+2}: note: expression of type list(elt='a)
|
||||
# CHECK-L: ${LINE:+1}: note: expression of type NoneType
|
||||
[] - None
|
||||
|
||||
# CHECK-L: ${LINE:+2}: error: cannot coerce list(elt='a) to float
|
||||
# CHECK-L: ${LINE:+1}: note: expression that required coercion to float
|
||||
[] - 1.0
|
@ -17,10 +17,10 @@ a = b
|
||||
# CHECK-L: note: an operand of type int(width='a)
|
||||
# CHECK-L: note: an operand of type bool
|
||||
|
||||
# CHECK-L: ${LINE:+1}: error: expected unary + operand to be of numeric type, not list(elt='a)
|
||||
# CHECK-L: ${LINE:+1}: error: expected unary '+' operand to be of numeric type, not list(elt='a)
|
||||
+[]
|
||||
|
||||
# CHECK-L: ${LINE:+1}: error: expected ~ operand to be of integer type, not float
|
||||
# CHECK-L: ${LINE:+1}: error: expected '~' operand to be of integer type, not float
|
||||
~1.0
|
||||
|
||||
# CHECK-L: ${LINE:+1}: error: type int(width='a) does not have an attribute 'x'
|
||||
|
Loading…
Reference in New Issue
Block a user