compiler: Implement assigning binops for arrays

GitHub: Fixes #1579.
This commit is contained in:
David Nadlinger 2021-01-11 23:05:12 +01:00
parent 5b5db1433b
commit 96692791cf
2 changed files with 64 additions and 6 deletions

View File

@ -390,6 +390,14 @@ class ARTIQIRGenerator(algorithm.Visitor):
def visit_AugAssign(self, node): def visit_AugAssign(self, node):
lhs = self.visit(node.target) lhs = self.visit(node.target)
rhs = self.visit(node.value) rhs = self.visit(node.value)
if builtins.is_array(lhs.type):
name = type(node.op).__name__
def make_op(l, r):
return self.append(ir.Arith(node.op, l, r))
self._broadcast_binop(name, make_op, lhs.type, lhs, rhs, assign_to_lhs=True)
return
value = self.append(ir.Arith(node.op, lhs, rhs)) value = self.append(ir.Arith(node.op, lhs, rhs))
try: try:
self.current_assign = value self.current_assign = value
@ -1715,7 +1723,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
return self.append(ir.Alloc([result_buffer, shape], node.type)) return self.append(ir.Alloc([result_buffer, shape], node.type))
return self.append(ir.GetElem(result_buffer, ir.Constant(0, self._size_type))) return self.append(ir.GetElem(result_buffer, ir.Constant(0, self._size_type)))
def _broadcast_binop(self, name, make_op, result_type, lhs, rhs): def _broadcast_binop(self, name, make_op, result_type, lhs, rhs, assign_to_lhs):
# Broadcast scalars (broadcasting higher dimensions is not yet allowed in the # Broadcast scalars (broadcasting higher dimensions is not yet allowed in the
# language). # language).
broadcast = False broadcast = False
@ -1736,9 +1744,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
builtins.TException("ValueError"), builtins.TException("ValueError"),
ir.Constant("operands could not be broadcast together", ir.Constant("operands could not be broadcast together",
builtins.TStr()))) builtins.TStr())))
if assign_to_lhs:
elt = result_type.find()["elt"] result = lhs
result, _ = self._allocate_new_array(elt, shape) else:
elt = result_type.find()["elt"]
result, _ = self._allocate_new_array(elt, shape)
func = self._get_array_elementwise_binop(name, make_op, result_type, lhs.type, func = self._get_array_elementwise_binop(name, make_op, result_type, lhs.type,
rhs.type) rhs.type)
self._invoke_arrayop(func, [result, lhs, rhs]) self._invoke_arrayop(func, [result, lhs, rhs])
@ -1753,7 +1763,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
name = type(node.op).__name__ name = type(node.op).__name__
def make_op(l, r): def make_op(l, r):
return self.append(ir.Arith(node.op, l, r)) return self.append(ir.Arith(node.op, l, r))
return self._broadcast_binop(name, make_op, node.type, lhs, rhs) return self._broadcast_binop(name, make_op, node.type, lhs, rhs,
assign_to_lhs=False)
elif builtins.is_numeric(node.type): elif builtins.is_numeric(node.type):
lhs = self.visit(node.left) lhs = self.visit(node.left)
rhs = self.visit(node.right) rhs = self.visit(node.right)
@ -2446,7 +2457,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
self._invoke_arrayop(func, [result, args[0]]) self._invoke_arrayop(func, [result, args[0]])
insn = result insn = result
elif len(args) == 2: elif len(args) == 2:
insn = self._broadcast_binop(name, make_call, node.type, *args) insn = self._broadcast_binop(name, make_call, node.type, *args,
assign_to_lhs=False)
else: else:
assert False, "Broadcasting for {} arguments not implemented".format(len) assert False, "Broadcasting for {} arguments not implemented".format(len)
else: else:

View File

@ -8,31 +8,61 @@ assert c[0] == 5
assert c[1] == 7 assert c[1] == 7
assert c[2] == 9 assert c[2] == 9
c += a
assert c[0] == 6
assert c[1] == 9
assert c[2] == 12
c = b - a c = b - a
assert c[0] == 3 assert c[0] == 3
assert c[1] == 3 assert c[1] == 3
assert c[2] == 3 assert c[2] == 3
c -= a
assert c[0] == 2
assert c[1] == 1
assert c[2] == 0
c = a * b c = a * b
assert c[0] == 4 assert c[0] == 4
assert c[1] == 10 assert c[1] == 10
assert c[2] == 18 assert c[2] == 18
c *= a
assert c[0] == 4
assert c[1] == 20
assert c[2] == 54
c = b // a c = b // a
assert c[0] == 4 assert c[0] == 4
assert c[1] == 2 assert c[1] == 2
assert c[2] == 2 assert c[2] == 2
c //= a
assert c[0] == 4
assert c[1] == 1
assert c[2] == 0
c = a ** b c = a ** b
assert c[0] == 1 assert c[0] == 1
assert c[1] == 32 assert c[1] == 32
assert c[2] == 729 assert c[2] == 729
c **= a
assert c[0] == 1
assert c[1] == 1024
assert c[2] == 387420489
c = b % a c = b % a
assert c[0] == 0 assert c[0] == 0
assert c[1] == 1 assert c[1] == 1
assert c[2] == 0 assert c[2] == 0
c %= a
assert c[0] == 0
assert c[1] == 1
assert c[2] == 0
cf = b / a cf = b / a
assert cf[0] == 4.0 assert cf[0] == 4.0
assert cf[1] == 2.5 assert cf[1] == 2.5
@ -43,6 +73,16 @@ assert cf2[0] == 5.0
assert cf2[1] == 4.5 assert cf2[1] == 4.5
assert cf2[2] == 5.0 assert cf2[2] == 5.0
cf2 += a
assert cf2[0] == 6.0
assert cf2[1] == 6.5
assert cf2[2] == 8.0
cf /= a
assert cf[0] == 4.0
assert cf[1] == 1.25
assert cf[2] == 2.0 / 3.0
d = array([[1, 2], [3, 4]]) d = array([[1, 2], [3, 4]])
e = array([[5, 6], [7, 8]]) e = array([[5, 6], [7, 8]])
f = d + e f = d + e
@ -50,3 +90,9 @@ assert f[0][0] == 6
assert f[0][1] == 8 assert f[0][1] == 8
assert f[1][0] == 10 assert f[1][0] == 10
assert f[1][1] == 12 assert f[1][1] == 12
f += d
assert f[0][0] == 7
assert f[0][1] == 10
assert f[1][0] == 13
assert f[1][1] == 16