forked from M-Labs/artiq
1
0
Fork 0

Implement BinOp coercion rules for AugAssign.

This commit is contained in:
whitequark 2015-06-14 13:10:32 +03:00
parent 7b78e7de67
commit fe69c5b465
3 changed files with 94 additions and 46 deletions

View File

@ -461,7 +461,7 @@ class Inferencer(algorithm.Visitor):
self.visit(node) self.visit(node)
return node return node
def _coerce_numeric(self, return_type, left, right): def _coerce_numeric(self, left, right):
# Implements the coercion protocol. # Implements the coercion protocol.
# See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex. # See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.
if builtins.is_float(left.type) or builtins.is_float(right.type): if builtins.is_float(left.type) or builtins.is_float(right.type):
@ -474,7 +474,7 @@ class Inferencer(algorithm.Visitor):
else: else:
typ = builtins.TInt() typ = builtins.TInt()
elif types.is_var(left.type) or types.is_var(right.type): # not enough info yet elif types.is_var(left.type) or types.is_var(right.type): # not enough info yet
return left, right return
else: # conflicting types else: # conflicting types
printer = types.TypePrinter() printer = types.TypePrinter()
note1 = diagnostic.Diagnostic("note", note1 = diagnostic.Diagnostic("note",
@ -489,15 +489,9 @@ class Inferencer(algorithm.Visitor):
left.loc, [right.loc], left.loc, [right.loc],
[note1, note2]) [note1, note2])
self.engine.process(diag) self.engine.process(diag)
return left, right return
# On 1st invocation, return_type is always a type variable. return typ, typ, typ
# On further invocations, coerce will only ever refine the type,
# so this should never fail.
return_type.unify(typ)
return self._coerce_one(typ, left, other_node=right), \
self._coerce_one(typ, right, other_node=left)
def _order_by_pred(self, pred, left, right): def _order_by_pred(self, pred, left, right):
if pred(left.type): if pred(left.type):
@ -507,28 +501,26 @@ class Inferencer(algorithm.Visitor):
else: else:
assert False assert False
def visit_BinOpT(self, node): def _coerce_binop(self, op, left, right):
self.generic_visit(node) if isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor,
if isinstance(node.op, (ast.BitAnd, ast.BitOr, ast.BitXor,
ast.LShift, ast.RShift)): ast.LShift, ast.RShift)):
# bitwise operators require integers # bitwise operators require integers
for operand in (node.left, node.right): for operand in (left, right):
if not types.is_var(operand.type) and not builtins.is_int(operand.type): if not types.is_var(operand.type) and not builtins.is_int(operand.type):
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"expected '{op}' operand to be of integer type, not {type}", "expected '{op}' operand to be of integer type, not {type}",
{"op": node.op.loc.source(), {"op": op.loc.source(),
"type": types.TypePrinter().name(operand.type)}, "type": types.TypePrinter().name(operand.type)},
node.op.loc, [operand.loc]) op.loc, [operand.loc])
self.engine.process(diag) self.engine.process(diag)
return return
node.left, node.right = \ return self._coerce_numeric(left, right)
self._coerce_numeric(node.type, node.left, node.right) elif isinstance(op, ast.Add):
elif isinstance(node.op, ast.Add):
# add works on numbers and also collections # add works on numbers and also collections
if builtins.is_collection(node.left.type) or builtins.is_collection(node.right.type): if builtins.is_collection(left.type) or builtins.is_collection(right.type):
collection, other = \ collection, other = \
self._order_by_pred(builtins.is_collection, node.left, node.right) self._order_by_pred(builtins.is_collection, left, right)
if types.is_tuple(collection.type): if types.is_tuple(collection.type):
pred, kind = types.is_tuple, "tuple" pred, kind = types.is_tuple, "tuple"
elif builtins.is_list(collection.type): elif builtins.is_list(collection.type):
@ -548,32 +540,32 @@ class Inferencer(algorithm.Visitor):
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"expected every '+' operand to be a {kind} in this context", "expected every '+' operand to be a {kind} in this context",
{"kind": kind}, {"kind": kind},
node.op.loc, [other.loc, collection.loc], op.loc, [other.loc, collection.loc],
[note1, note2]) [note1, note2])
self.engine.process(diag) self.engine.process(diag)
return return
if types.is_tuple(collection.type): if types.is_tuple(collection.type):
# should never fail # should never fail
node.type.unify(types.TTuple(node.left.type.find().elts + return types.TTuple(left.type.find().elts +
node.right.type.find().elts)) right.type.find().elts), left.type, right.type
elif builtins.is_list(collection.type): elif builtins.is_list(collection.type):
self._unify(node.left.type, node.right.type, self._unify(left.type, right.type,
node.left.loc, node.right.loc) left.loc, right.loc)
node.type.unify(node.left.type) # should never fail return left.type, left.type, right.type
else: else:
node.left, node.right = \ return self._coerce_numeric(left, right)
self._coerce_numeric(node.type, node.left, node.right) elif isinstance(op, ast.Mult):
elif isinstance(node.op, ast.Mult):
# mult works on numbers and also number & collection # mult works on numbers and also number & collection
if types.is_tuple(node.left.type) or types.is_tuple(node.right.type): if types.is_tuple(left.type) or types.is_tuple(right.type):
tuple_, other = self._order_by_pred(types.is_tuple, node.left, node.right) tuple_, other = self._order_by_pred(types.is_tuple, left, right)
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"py2llvm does not support passing tuples to '*'", {}, "py2llvm does not support passing tuples to '*'", {},
node.op.loc, [tuple_.loc]) op.loc, [tuple_.loc])
self.engine.process(diag) self.engine.process(diag)
elif builtins.is_list(node.left.type) or builtins.is_list(node.right.type): return
list_, other = self._order_by_pred(builtins.is_list, node.left, node.right) elif builtins.is_list(left.type) or builtins.is_list(right.type):
list_, other = self._order_by_pred(builtins.is_list, left, right)
if not builtins.is_int(other.type): if not builtins.is_int(other.type):
printer = types.TypePrinter() printer = types.TypePrinter()
note1 = diagnostic.Diagnostic("note", note1 = diagnostic.Diagnostic("note",
@ -586,25 +578,33 @@ class Inferencer(algorithm.Visitor):
other.loc) other.loc)
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"expected '*' operands to be a list and an integer in this context", {}, "expected '*' operands to be a list and an integer in this context", {},
node.op.loc, [list_.loc, other.loc], op.loc, [list_.loc, other.loc],
[note1, note2]) [note1, note2])
self.engine.process(diag) self.engine.process(diag)
return return
node.type.unify(list_.type)
return list_.type, left.type, right.type
else: else:
node.left, node.right = \ return self._coerce_numeric(left, right)
self._coerce_numeric(node.type, node.left, node.right) elif isinstance(op, (ast.Div, ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)):
elif isinstance(node.op, (ast.Div, ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)):
# numeric operators work on any kind of number # numeric operators work on any kind of number
node.left, node.right = \ return self._coerce_numeric(left, right)
self._coerce_numeric(node.type, node.left, node.right)
else: # MatMult else: # MatMult
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"operator '{op}' is not supported", {"op": node.op.loc.source()}, "operator '{op}' is not supported", {"op": op.loc.source()},
node.op.loc) op.loc)
self.engine.process(diag) self.engine.process(diag)
return return
def visit_BinOpT(self, node):
self.generic_visit(node)
coerced = self._coerce_binop(node.op, node.left, node.right)
if coerced:
return_type, left_type, right_type = coerced
node.left = self._coerce_one(left_type, node.left, other_node=node.right)
node.right = self._coerce_one(right_type, node.right, other_node=node.left)
node.type.unify(return_type) # should never fail
def visit_Assign(self, node): def visit_Assign(self, node):
self.generic_visit(node) self.generic_visit(node)
if len(node.targets) > 1: if len(node.targets) > 1:
@ -616,8 +616,45 @@ class Inferencer(algorithm.Visitor):
def visit_AugAssign(self, node): def visit_AugAssign(self, node):
self.generic_visit(node) self.generic_visit(node)
self._unify(node.target.type, node.value.type, coerced = self._coerce_binop(node.op, node.target, node.value)
node.target.loc, node.value.loc) if coerced:
return_type, target_type, value_type = coerced
try:
node.target.type.unify(target_type)
except types.UnificationError as e:
printer = types.TypePrinter()
note = diagnostic.Diagnostic("note",
"expression of type {typec}",
{"typec": printer.name(node.value.type)},
node.value.loc)
diag = diagnostic.Diagnostic("error",
"expression of type {typea} has to be coerced to {typeb}, "
"which makes assignment invalid",
{"typea": printer.name(node.target.type),
"typeb": printer.name(target_type)},
node.op.loc, [node.target.loc], [note])
self.engine.process(diag)
return
try:
node.target.type.unify(return_type)
except types.UnificationError as e:
printer = types.TypePrinter()
note = diagnostic.Diagnostic("note",
"expression of type {typec}",
{"typec": printer.name(node.value.type)},
node.value.loc)
diag = diagnostic.Diagnostic("error",
"the result of this operation has type {typeb}, "
"which makes assignment to a slot of type {typea} invalid",
{"typea": printer.name(node.target.type),
"typeb": printer.name(return_type)},
node.op.loc, [node.target.loc], [note])
self.engine.process(diag)
return
node.value = self._coerce_one(value_type, node.value, other_node=node.target)
def visit_For(self, node): def visit_For(self, node):
old_in_loop, self.in_loop = self.in_loop, True old_in_loop, self.in_loop = self.in_loop, True

View File

@ -24,3 +24,6 @@
1 + 1.0 1 + 1.0
# CHECK-L: 1:int(width='q):float + 1.0:float:float # CHECK-L: 1:int(width='q):float + 1.0:float:float
a = []; a += [1]
# CHECK-L: a:list(elt=int(width='r)) = []:list(elt=int(width='r)); a:list(elt=int(width='r)) += [1:int(width='r)]:list(elt=int(width='r))

View File

@ -33,3 +33,11 @@
# CHECK-L: ${LINE:+2}: error: cannot coerce list(elt='a) to float # CHECK-L: ${LINE:+2}: error: cannot coerce list(elt='a) to float
# CHECK-L: ${LINE:+1}: note: expression that required coercion to float # CHECK-L: ${LINE:+1}: note: expression that required coercion to float
[] - 1.0 [] - 1.0
# CHECK-L: ${LINE:+2}: error: expression of type int(width='a) has to be coerced to float, which makes assignment invalid
# CHECK-L: ${LINE:+1}: note: expression of type float
a = 1; a += 1.0
# CHECK-L: ${LINE:+2}: error: the result of this operation has type (int(width='a), float), which makes assignment to a slot of type (int(width='a),) invalid
# CHECK-L: ${LINE:+1}: note: expression of type (float,)
b = (1,); b += (1.0,)