compiler: Implement assigning binops for arrays

GitHub: Fixes #1579.
This commit is contained in:
David Nadlinger 2021-01-11 23:05:12 +01:00
parent 5b5db1433b
commit 96692791cf
2 changed files with 64 additions and 6 deletions

View File

@ -390,6 +390,14 @@ class ARTIQIRGenerator(algorithm.Visitor):
def visit_AugAssign(self, node):
lhs = self.visit(node.target)
rhs = self.visit(node.value)
if builtins.is_array(lhs.type):
name = type(node.op).__name__
def make_op(l, r):
return self.append(ir.Arith(node.op, l, r))
self._broadcast_binop(name, make_op, lhs.type, lhs, rhs, assign_to_lhs=True)
return
value = self.append(ir.Arith(node.op, lhs, rhs))
try:
self.current_assign = value
@ -1715,7 +1723,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
return self.append(ir.Alloc([result_buffer, shape], node.type))
return self.append(ir.GetElem(result_buffer, ir.Constant(0, self._size_type)))
def _broadcast_binop(self, name, make_op, result_type, lhs, rhs):
def _broadcast_binop(self, name, make_op, result_type, lhs, rhs, assign_to_lhs):
# Broadcast scalars (broadcasting higher dimensions is not yet allowed in the
# language).
broadcast = False
@ -1736,7 +1744,9 @@ class ARTIQIRGenerator(algorithm.Visitor):
builtins.TException("ValueError"),
ir.Constant("operands could not be broadcast together",
builtins.TStr())))
if assign_to_lhs:
result = lhs
else:
elt = result_type.find()["elt"]
result, _ = self._allocate_new_array(elt, shape)
func = self._get_array_elementwise_binop(name, make_op, result_type, lhs.type,
@ -1753,7 +1763,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
name = type(node.op).__name__
def make_op(l, r):
return self.append(ir.Arith(node.op, l, r))
return self._broadcast_binop(name, make_op, node.type, lhs, rhs)
return self._broadcast_binop(name, make_op, node.type, lhs, rhs,
assign_to_lhs=False)
elif builtins.is_numeric(node.type):
lhs = self.visit(node.left)
rhs = self.visit(node.right)
@ -2446,7 +2457,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
self._invoke_arrayop(func, [result, args[0]])
insn = result
elif len(args) == 2:
insn = self._broadcast_binop(name, make_call, node.type, *args)
insn = self._broadcast_binop(name, make_call, node.type, *args,
assign_to_lhs=False)
else:
assert False, "Broadcasting for {} arguments not implemented".format(len)
else:

View File

@ -8,31 +8,61 @@ assert c[0] == 5
assert c[1] == 7
assert c[2] == 9
c += a
assert c[0] == 6
assert c[1] == 9
assert c[2] == 12
c = b - a
assert c[0] == 3
assert c[1] == 3
assert c[2] == 3
c -= a
assert c[0] == 2
assert c[1] == 1
assert c[2] == 0
c = a * b
assert c[0] == 4
assert c[1] == 10
assert c[2] == 18
c *= a
assert c[0] == 4
assert c[1] == 20
assert c[2] == 54
c = b // a
assert c[0] == 4
assert c[1] == 2
assert c[2] == 2
c //= a
assert c[0] == 4
assert c[1] == 1
assert c[2] == 0
c = a ** b
assert c[0] == 1
assert c[1] == 32
assert c[2] == 729
c **= a
assert c[0] == 1
assert c[1] == 1024
assert c[2] == 387420489
c = b % a
assert c[0] == 0
assert c[1] == 1
assert c[2] == 0
c %= a
assert c[0] == 0
assert c[1] == 1
assert c[2] == 0
cf = b / a
assert cf[0] == 4.0
assert cf[1] == 2.5
@ -43,6 +73,16 @@ assert cf2[0] == 5.0
assert cf2[1] == 4.5
assert cf2[2] == 5.0
cf2 += a
assert cf2[0] == 6.0
assert cf2[1] == 6.5
assert cf2[2] == 8.0
cf /= a
assert cf[0] == 4.0
assert cf[1] == 1.25
assert cf[2] == 2.0 / 3.0
d = array([[1, 2], [3, 4]])
e = array([[5, 6], [7, 8]])
f = d + e
@ -50,3 +90,9 @@ assert f[0][0] == 6
assert f[0][1] == 8
assert f[1][0] == 10
assert f[1][1] == 12
f += d
assert f[0][0] == 7
assert f[0][1] == 10
assert f[1][0] == 13
assert f[1][1] == 16