forked from M-Labs/artiq
Implement BinOp coercion rules for AugAssign.
This commit is contained in:
parent
7b78e7de67
commit
fe69c5b465
@ -461,7 +461,7 @@ class Inferencer(algorithm.Visitor):
|
||||
self.visit(node)
|
||||
return node
|
||||
|
||||
def _coerce_numeric(self, return_type, left, right):
|
||||
def _coerce_numeric(self, left, right):
|
||||
# Implements the coercion protocol.
|
||||
# See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.
|
||||
if builtins.is_float(left.type) or builtins.is_float(right.type):
|
||||
@ -474,7 +474,7 @@ class Inferencer(algorithm.Visitor):
|
||||
else:
|
||||
typ = builtins.TInt()
|
||||
elif types.is_var(left.type) or types.is_var(right.type): # not enough info yet
|
||||
return left, right
|
||||
return
|
||||
else: # conflicting types
|
||||
printer = types.TypePrinter()
|
||||
note1 = diagnostic.Diagnostic("note",
|
||||
@ -489,15 +489,9 @@ class Inferencer(algorithm.Visitor):
|
||||
left.loc, [right.loc],
|
||||
[note1, note2])
|
||||
self.engine.process(diag)
|
||||
return left, right
|
||||
return
|
||||
|
||||
# On 1st invocation, return_type is always a type variable.
|
||||
# On further invocations, coerce will only ever refine the type,
|
||||
# so this should never fail.
|
||||
return_type.unify(typ)
|
||||
|
||||
return self._coerce_one(typ, left, other_node=right), \
|
||||
self._coerce_one(typ, right, other_node=left)
|
||||
return typ, typ, typ
|
||||
|
||||
def _order_by_pred(self, pred, left, right):
|
||||
if pred(left.type):
|
||||
@ -507,28 +501,26 @@ class Inferencer(algorithm.Visitor):
|
||||
else:
|
||||
assert False
|
||||
|
||||
def visit_BinOpT(self, node):
|
||||
self.generic_visit(node)
|
||||
if isinstance(node.op, (ast.BitAnd, ast.BitOr, ast.BitXor,
|
||||
def _coerce_binop(self, op, left, right):
|
||||
if isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor,
|
||||
ast.LShift, ast.RShift)):
|
||||
# bitwise operators require integers
|
||||
for operand in (node.left, node.right):
|
||||
for operand in (left, right):
|
||||
if not types.is_var(operand.type) and not builtins.is_int(operand.type):
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"expected '{op}' operand to be of integer type, not {type}",
|
||||
{"op": node.op.loc.source(),
|
||||
{"op": op.loc.source(),
|
||||
"type": types.TypePrinter().name(operand.type)},
|
||||
node.op.loc, [operand.loc])
|
||||
op.loc, [operand.loc])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
node.left, node.right = \
|
||||
self._coerce_numeric(node.type, node.left, node.right)
|
||||
elif isinstance(node.op, ast.Add):
|
||||
return self._coerce_numeric(left, right)
|
||||
elif isinstance(op, ast.Add):
|
||||
# add works on numbers and also collections
|
||||
if builtins.is_collection(node.left.type) or builtins.is_collection(node.right.type):
|
||||
if builtins.is_collection(left.type) or builtins.is_collection(right.type):
|
||||
collection, other = \
|
||||
self._order_by_pred(builtins.is_collection, node.left, node.right)
|
||||
self._order_by_pred(builtins.is_collection, left, right)
|
||||
if types.is_tuple(collection.type):
|
||||
pred, kind = types.is_tuple, "tuple"
|
||||
elif builtins.is_list(collection.type):
|
||||
@ -548,32 +540,32 @@ class Inferencer(algorithm.Visitor):
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"expected every '+' operand to be a {kind} in this context",
|
||||
{"kind": kind},
|
||||
node.op.loc, [other.loc, collection.loc],
|
||||
op.loc, [other.loc, collection.loc],
|
||||
[note1, note2])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
if types.is_tuple(collection.type):
|
||||
# should never fail
|
||||
node.type.unify(types.TTuple(node.left.type.find().elts +
|
||||
node.right.type.find().elts))
|
||||
return types.TTuple(left.type.find().elts +
|
||||
right.type.find().elts), left.type, right.type
|
||||
elif builtins.is_list(collection.type):
|
||||
self._unify(node.left.type, node.right.type,
|
||||
node.left.loc, node.right.loc)
|
||||
node.type.unify(node.left.type) # should never fail
|
||||
self._unify(left.type, right.type,
|
||||
left.loc, right.loc)
|
||||
return left.type, left.type, right.type
|
||||
else:
|
||||
node.left, node.right = \
|
||||
self._coerce_numeric(node.type, node.left, node.right)
|
||||
elif isinstance(node.op, ast.Mult):
|
||||
return self._coerce_numeric(left, right)
|
||||
elif isinstance(op, ast.Mult):
|
||||
# mult works on numbers and also number & collection
|
||||
if types.is_tuple(node.left.type) or types.is_tuple(node.right.type):
|
||||
tuple_, other = self._order_by_pred(types.is_tuple, node.left, node.right)
|
||||
if types.is_tuple(left.type) or types.is_tuple(right.type):
|
||||
tuple_, other = self._order_by_pred(types.is_tuple, left, right)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"py2llvm does not support passing tuples to '*'", {},
|
||||
node.op.loc, [tuple_.loc])
|
||||
op.loc, [tuple_.loc])
|
||||
self.engine.process(diag)
|
||||
elif builtins.is_list(node.left.type) or builtins.is_list(node.right.type):
|
||||
list_, other = self._order_by_pred(builtins.is_list, node.left, node.right)
|
||||
return
|
||||
elif builtins.is_list(left.type) or builtins.is_list(right.type):
|
||||
list_, other = self._order_by_pred(builtins.is_list, left, right)
|
||||
if not builtins.is_int(other.type):
|
||||
printer = types.TypePrinter()
|
||||
note1 = diagnostic.Diagnostic("note",
|
||||
@ -586,25 +578,33 @@ class Inferencer(algorithm.Visitor):
|
||||
other.loc)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"expected '*' operands to be a list and an integer in this context", {},
|
||||
node.op.loc, [list_.loc, other.loc],
|
||||
op.loc, [list_.loc, other.loc],
|
||||
[note1, note2])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
node.type.unify(list_.type)
|
||||
|
||||
return list_.type, left.type, right.type
|
||||
else:
|
||||
node.left, node.right = \
|
||||
self._coerce_numeric(node.type, node.left, node.right)
|
||||
elif isinstance(node.op, (ast.Div, ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)):
|
||||
return self._coerce_numeric(left, right)
|
||||
elif isinstance(op, (ast.Div, ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)):
|
||||
# numeric operators work on any kind of number
|
||||
node.left, node.right = \
|
||||
self._coerce_numeric(node.type, node.left, node.right)
|
||||
return self._coerce_numeric(left, right)
|
||||
else: # MatMult
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"operator '{op}' is not supported", {"op": node.op.loc.source()},
|
||||
node.op.loc)
|
||||
"operator '{op}' is not supported", {"op": op.loc.source()},
|
||||
op.loc)
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
def visit_BinOpT(self, node):
|
||||
self.generic_visit(node)
|
||||
coerced = self._coerce_binop(node.op, node.left, node.right)
|
||||
if coerced:
|
||||
return_type, left_type, right_type = coerced
|
||||
node.left = self._coerce_one(left_type, node.left, other_node=node.right)
|
||||
node.right = self._coerce_one(right_type, node.right, other_node=node.left)
|
||||
node.type.unify(return_type) # should never fail
|
||||
|
||||
def visit_Assign(self, node):
|
||||
self.generic_visit(node)
|
||||
if len(node.targets) > 1:
|
||||
@ -616,8 +616,45 @@ class Inferencer(algorithm.Visitor):
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
self.generic_visit(node)
|
||||
self._unify(node.target.type, node.value.type,
|
||||
node.target.loc, node.value.loc)
|
||||
coerced = self._coerce_binop(node.op, node.target, node.value)
|
||||
if coerced:
|
||||
return_type, target_type, value_type = coerced
|
||||
|
||||
try:
|
||||
node.target.type.unify(target_type)
|
||||
except types.UnificationError as e:
|
||||
printer = types.TypePrinter()
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"expression of type {typec}",
|
||||
{"typec": printer.name(node.value.type)},
|
||||
node.value.loc)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"expression of type {typea} has to be coerced to {typeb}, "
|
||||
"which makes assignment invalid",
|
||||
{"typea": printer.name(node.target.type),
|
||||
"typeb": printer.name(target_type)},
|
||||
node.op.loc, [node.target.loc], [note])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
try:
|
||||
node.target.type.unify(return_type)
|
||||
except types.UnificationError as e:
|
||||
printer = types.TypePrinter()
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"expression of type {typec}",
|
||||
{"typec": printer.name(node.value.type)},
|
||||
node.value.loc)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"the result of this operation has type {typeb}, "
|
||||
"which makes assignment to a slot of type {typea} invalid",
|
||||
{"typea": printer.name(node.target.type),
|
||||
"typeb": printer.name(return_type)},
|
||||
node.op.loc, [node.target.loc], [note])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
node.value = self._coerce_one(value_type, node.value, other_node=node.target)
|
||||
|
||||
def visit_For(self, node):
|
||||
old_in_loop, self.in_loop = self.in_loop, True
|
||||
|
@ -24,3 +24,6 @@
|
||||
|
||||
1 + 1.0
|
||||
# CHECK-L: 1:int(width='q):float + 1.0:float:float
|
||||
|
||||
a = []; a += [1]
|
||||
# CHECK-L: a:list(elt=int(width='r)) = []:list(elt=int(width='r)); a:list(elt=int(width='r)) += [1:int(width='r)]:list(elt=int(width='r))
|
||||
|
@ -33,3 +33,11 @@
|
||||
# CHECK-L: ${LINE:+2}: error: cannot coerce list(elt='a) to float
|
||||
# CHECK-L: ${LINE:+1}: note: expression that required coercion to float
|
||||
[] - 1.0
|
||||
|
||||
# CHECK-L: ${LINE:+2}: error: expression of type int(width='a) has to be coerced to float, which makes assignment invalid
|
||||
# CHECK-L: ${LINE:+1}: note: expression of type float
|
||||
a = 1; a += 1.0
|
||||
|
||||
# CHECK-L: ${LINE:+2}: error: the result of this operation has type (int(width='a), float), which makes assignment to a slot of type (int(width='a),) invalid
|
||||
# CHECK-L: ${LINE:+1}: note: expression of type (float,)
|
||||
b = (1,); b += (1.0,)
|
||||
|
Loading…
Reference in New Issue
Block a user