forked from M-Labs/artiq
Implement BinOp coercion rules for AugAssign.
This commit is contained in:
parent
7b78e7de67
commit
fe69c5b465
|
@ -461,7 +461,7 @@ class Inferencer(algorithm.Visitor):
|
||||||
self.visit(node)
|
self.visit(node)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def _coerce_numeric(self, return_type, left, right):
|
def _coerce_numeric(self, left, right):
|
||||||
# Implements the coercion protocol.
|
# Implements the coercion protocol.
|
||||||
# See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.
|
# See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.
|
||||||
if builtins.is_float(left.type) or builtins.is_float(right.type):
|
if builtins.is_float(left.type) or builtins.is_float(right.type):
|
||||||
|
@ -474,7 +474,7 @@ class Inferencer(algorithm.Visitor):
|
||||||
else:
|
else:
|
||||||
typ = builtins.TInt()
|
typ = builtins.TInt()
|
||||||
elif types.is_var(left.type) or types.is_var(right.type): # not enough info yet
|
elif types.is_var(left.type) or types.is_var(right.type): # not enough info yet
|
||||||
return left, right
|
return
|
||||||
else: # conflicting types
|
else: # conflicting types
|
||||||
printer = types.TypePrinter()
|
printer = types.TypePrinter()
|
||||||
note1 = diagnostic.Diagnostic("note",
|
note1 = diagnostic.Diagnostic("note",
|
||||||
|
@ -489,15 +489,9 @@ class Inferencer(algorithm.Visitor):
|
||||||
left.loc, [right.loc],
|
left.loc, [right.loc],
|
||||||
[note1, note2])
|
[note1, note2])
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
return left, right
|
return
|
||||||
|
|
||||||
# On 1st invocation, return_type is always a type variable.
|
return typ, typ, typ
|
||||||
# On further invocations, coerce will only ever refine the type,
|
|
||||||
# so this should never fail.
|
|
||||||
return_type.unify(typ)
|
|
||||||
|
|
||||||
return self._coerce_one(typ, left, other_node=right), \
|
|
||||||
self._coerce_one(typ, right, other_node=left)
|
|
||||||
|
|
||||||
def _order_by_pred(self, pred, left, right):
|
def _order_by_pred(self, pred, left, right):
|
||||||
if pred(left.type):
|
if pred(left.type):
|
||||||
|
@ -507,28 +501,26 @@ class Inferencer(algorithm.Visitor):
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
def visit_BinOpT(self, node):
|
def _coerce_binop(self, op, left, right):
|
||||||
self.generic_visit(node)
|
if isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor,
|
||||||
if isinstance(node.op, (ast.BitAnd, ast.BitOr, ast.BitXor,
|
|
||||||
ast.LShift, ast.RShift)):
|
ast.LShift, ast.RShift)):
|
||||||
# bitwise operators require integers
|
# bitwise operators require integers
|
||||||
for operand in (node.left, node.right):
|
for operand in (left, right):
|
||||||
if not types.is_var(operand.type) and not builtins.is_int(operand.type):
|
if not types.is_var(operand.type) and not builtins.is_int(operand.type):
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
"expected '{op}' operand to be of integer type, not {type}",
|
"expected '{op}' operand to be of integer type, not {type}",
|
||||||
{"op": node.op.loc.source(),
|
{"op": op.loc.source(),
|
||||||
"type": types.TypePrinter().name(operand.type)},
|
"type": types.TypePrinter().name(operand.type)},
|
||||||
node.op.loc, [operand.loc])
|
op.loc, [operand.loc])
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
return
|
return
|
||||||
|
|
||||||
node.left, node.right = \
|
return self._coerce_numeric(left, right)
|
||||||
self._coerce_numeric(node.type, node.left, node.right)
|
elif isinstance(op, ast.Add):
|
||||||
elif isinstance(node.op, ast.Add):
|
|
||||||
# add works on numbers and also collections
|
# add works on numbers and also collections
|
||||||
if builtins.is_collection(node.left.type) or builtins.is_collection(node.right.type):
|
if builtins.is_collection(left.type) or builtins.is_collection(right.type):
|
||||||
collection, other = \
|
collection, other = \
|
||||||
self._order_by_pred(builtins.is_collection, node.left, node.right)
|
self._order_by_pred(builtins.is_collection, left, right)
|
||||||
if types.is_tuple(collection.type):
|
if types.is_tuple(collection.type):
|
||||||
pred, kind = types.is_tuple, "tuple"
|
pred, kind = types.is_tuple, "tuple"
|
||||||
elif builtins.is_list(collection.type):
|
elif builtins.is_list(collection.type):
|
||||||
|
@ -548,32 +540,32 @@ class Inferencer(algorithm.Visitor):
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
"expected every '+' operand to be a {kind} in this context",
|
"expected every '+' operand to be a {kind} in this context",
|
||||||
{"kind": kind},
|
{"kind": kind},
|
||||||
node.op.loc, [other.loc, collection.loc],
|
op.loc, [other.loc, collection.loc],
|
||||||
[note1, note2])
|
[note1, note2])
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
return
|
return
|
||||||
|
|
||||||
if types.is_tuple(collection.type):
|
if types.is_tuple(collection.type):
|
||||||
# should never fail
|
# should never fail
|
||||||
node.type.unify(types.TTuple(node.left.type.find().elts +
|
return types.TTuple(left.type.find().elts +
|
||||||
node.right.type.find().elts))
|
right.type.find().elts), left.type, right.type
|
||||||
elif builtins.is_list(collection.type):
|
elif builtins.is_list(collection.type):
|
||||||
self._unify(node.left.type, node.right.type,
|
self._unify(left.type, right.type,
|
||||||
node.left.loc, node.right.loc)
|
left.loc, right.loc)
|
||||||
node.type.unify(node.left.type) # should never fail
|
return left.type, left.type, right.type
|
||||||
else:
|
else:
|
||||||
node.left, node.right = \
|
return self._coerce_numeric(left, right)
|
||||||
self._coerce_numeric(node.type, node.left, node.right)
|
elif isinstance(op, ast.Mult):
|
||||||
elif isinstance(node.op, ast.Mult):
|
|
||||||
# mult works on numbers and also number & collection
|
# mult works on numbers and also number & collection
|
||||||
if types.is_tuple(node.left.type) or types.is_tuple(node.right.type):
|
if types.is_tuple(left.type) or types.is_tuple(right.type):
|
||||||
tuple_, other = self._order_by_pred(types.is_tuple, node.left, node.right)
|
tuple_, other = self._order_by_pred(types.is_tuple, left, right)
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
"py2llvm does not support passing tuples to '*'", {},
|
"py2llvm does not support passing tuples to '*'", {},
|
||||||
node.op.loc, [tuple_.loc])
|
op.loc, [tuple_.loc])
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
elif builtins.is_list(node.left.type) or builtins.is_list(node.right.type):
|
return
|
||||||
list_, other = self._order_by_pred(builtins.is_list, node.left, node.right)
|
elif builtins.is_list(left.type) or builtins.is_list(right.type):
|
||||||
|
list_, other = self._order_by_pred(builtins.is_list, left, right)
|
||||||
if not builtins.is_int(other.type):
|
if not builtins.is_int(other.type):
|
||||||
printer = types.TypePrinter()
|
printer = types.TypePrinter()
|
||||||
note1 = diagnostic.Diagnostic("note",
|
note1 = diagnostic.Diagnostic("note",
|
||||||
|
@ -586,25 +578,33 @@ class Inferencer(algorithm.Visitor):
|
||||||
other.loc)
|
other.loc)
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
"expected '*' operands to be a list and an integer in this context", {},
|
"expected '*' operands to be a list and an integer in this context", {},
|
||||||
node.op.loc, [list_.loc, other.loc],
|
op.loc, [list_.loc, other.loc],
|
||||||
[note1, note2])
|
[note1, note2])
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
return
|
return
|
||||||
node.type.unify(list_.type)
|
|
||||||
|
return list_.type, left.type, right.type
|
||||||
else:
|
else:
|
||||||
node.left, node.right = \
|
return self._coerce_numeric(left, right)
|
||||||
self._coerce_numeric(node.type, node.left, node.right)
|
elif isinstance(op, (ast.Div, ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)):
|
||||||
elif isinstance(node.op, (ast.Div, ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)):
|
|
||||||
# numeric operators work on any kind of number
|
# numeric operators work on any kind of number
|
||||||
node.left, node.right = \
|
return self._coerce_numeric(left, right)
|
||||||
self._coerce_numeric(node.type, node.left, node.right)
|
|
||||||
else: # MatMult
|
else: # MatMult
|
||||||
diag = diagnostic.Diagnostic("error",
|
diag = diagnostic.Diagnostic("error",
|
||||||
"operator '{op}' is not supported", {"op": node.op.loc.source()},
|
"operator '{op}' is not supported", {"op": op.loc.source()},
|
||||||
node.op.loc)
|
op.loc)
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def visit_BinOpT(self, node):
|
||||||
|
self.generic_visit(node)
|
||||||
|
coerced = self._coerce_binop(node.op, node.left, node.right)
|
||||||
|
if coerced:
|
||||||
|
return_type, left_type, right_type = coerced
|
||||||
|
node.left = self._coerce_one(left_type, node.left, other_node=node.right)
|
||||||
|
node.right = self._coerce_one(right_type, node.right, other_node=node.left)
|
||||||
|
node.type.unify(return_type) # should never fail
|
||||||
|
|
||||||
def visit_Assign(self, node):
|
def visit_Assign(self, node):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
if len(node.targets) > 1:
|
if len(node.targets) > 1:
|
||||||
|
@ -616,8 +616,45 @@ class Inferencer(algorithm.Visitor):
|
||||||
|
|
||||||
def visit_AugAssign(self, node):
|
def visit_AugAssign(self, node):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
self._unify(node.target.type, node.value.type,
|
coerced = self._coerce_binop(node.op, node.target, node.value)
|
||||||
node.target.loc, node.value.loc)
|
if coerced:
|
||||||
|
return_type, target_type, value_type = coerced
|
||||||
|
|
||||||
|
try:
|
||||||
|
node.target.type.unify(target_type)
|
||||||
|
except types.UnificationError as e:
|
||||||
|
printer = types.TypePrinter()
|
||||||
|
note = diagnostic.Diagnostic("note",
|
||||||
|
"expression of type {typec}",
|
||||||
|
{"typec": printer.name(node.value.type)},
|
||||||
|
node.value.loc)
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"expression of type {typea} has to be coerced to {typeb}, "
|
||||||
|
"which makes assignment invalid",
|
||||||
|
{"typea": printer.name(node.target.type),
|
||||||
|
"typeb": printer.name(target_type)},
|
||||||
|
node.op.loc, [node.target.loc], [note])
|
||||||
|
self.engine.process(diag)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
node.target.type.unify(return_type)
|
||||||
|
except types.UnificationError as e:
|
||||||
|
printer = types.TypePrinter()
|
||||||
|
note = diagnostic.Diagnostic("note",
|
||||||
|
"expression of type {typec}",
|
||||||
|
{"typec": printer.name(node.value.type)},
|
||||||
|
node.value.loc)
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"the result of this operation has type {typeb}, "
|
||||||
|
"which makes assignment to a slot of type {typea} invalid",
|
||||||
|
{"typea": printer.name(node.target.type),
|
||||||
|
"typeb": printer.name(return_type)},
|
||||||
|
node.op.loc, [node.target.loc], [note])
|
||||||
|
self.engine.process(diag)
|
||||||
|
return
|
||||||
|
|
||||||
|
node.value = self._coerce_one(value_type, node.value, other_node=node.target)
|
||||||
|
|
||||||
def visit_For(self, node):
|
def visit_For(self, node):
|
||||||
old_in_loop, self.in_loop = self.in_loop, True
|
old_in_loop, self.in_loop = self.in_loop, True
|
||||||
|
|
|
@ -24,3 +24,6 @@
|
||||||
|
|
||||||
1 + 1.0
|
1 + 1.0
|
||||||
# CHECK-L: 1:int(width='q):float + 1.0:float:float
|
# CHECK-L: 1:int(width='q):float + 1.0:float:float
|
||||||
|
|
||||||
|
a = []; a += [1]
|
||||||
|
# CHECK-L: a:list(elt=int(width='r)) = []:list(elt=int(width='r)); a:list(elt=int(width='r)) += [1:int(width='r)]:list(elt=int(width='r))
|
||||||
|
|
|
@ -33,3 +33,11 @@
|
||||||
# CHECK-L: ${LINE:+2}: error: cannot coerce list(elt='a) to float
|
# CHECK-L: ${LINE:+2}: error: cannot coerce list(elt='a) to float
|
||||||
# CHECK-L: ${LINE:+1}: note: expression that required coercion to float
|
# CHECK-L: ${LINE:+1}: note: expression that required coercion to float
|
||||||
[] - 1.0
|
[] - 1.0
|
||||||
|
|
||||||
|
# CHECK-L: ${LINE:+2}: error: expression of type int(width='a) has to be coerced to float, which makes assignment invalid
|
||||||
|
# CHECK-L: ${LINE:+1}: note: expression of type float
|
||||||
|
a = 1; a += 1.0
|
||||||
|
|
||||||
|
# CHECK-L: ${LINE:+2}: error: the result of this operation has type (int(width='a), float), which makes assignment to a slot of type (int(width='a),) invalid
|
||||||
|
# CHECK-L: ${LINE:+1}: note: expression of type (float,)
|
||||||
|
b = (1,); b += (1.0,)
|
||||||
|
|
Loading…
Reference in New Issue