From 96692791cfb7f80baa97ed2d6bde1a7db6365b8f Mon Sep 17 00:00:00 2001 From: David Nadlinger Date: Mon, 11 Jan 2021 23:05:12 +0100 Subject: [PATCH] compiler: Implement assigning binops for arrays GitHub: Fixes #1579. --- .../compiler/transforms/artiq_ir_generator.py | 24 +++++++--- artiq/test/lit/integration/array_binops.py | 46 +++++++++++++++++++ 2 files changed, 64 insertions(+), 6 deletions(-) diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index afa75b61b..afccb5aba 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -390,6 +390,14 @@ class ARTIQIRGenerator(algorithm.Visitor): def visit_AugAssign(self, node): lhs = self.visit(node.target) rhs = self.visit(node.value) + + if builtins.is_array(lhs.type): + name = type(node.op).__name__ + def make_op(l, r): + return self.append(ir.Arith(node.op, l, r)) + self._broadcast_binop(name, make_op, lhs.type, lhs, rhs, assign_to_lhs=True) + return + value = self.append(ir.Arith(node.op, lhs, rhs)) try: self.current_assign = value @@ -1715,7 +1723,7 @@ class ARTIQIRGenerator(algorithm.Visitor): return self.append(ir.Alloc([result_buffer, shape], node.type)) return self.append(ir.GetElem(result_buffer, ir.Constant(0, self._size_type))) - def _broadcast_binop(self, name, make_op, result_type, lhs, rhs): + def _broadcast_binop(self, name, make_op, result_type, lhs, rhs, assign_to_lhs): # Broadcast scalars (broadcasting higher dimensions is not yet allowed in the # language). broadcast = False @@ -1736,9 +1744,11 @@ class ARTIQIRGenerator(algorithm.Visitor): builtins.TException("ValueError"), ir.Constant("operands could not be broadcast together", builtins.TStr()))) - - elt = result_type.find()["elt"] - result, _ = self._allocate_new_array(elt, shape) + if assign_to_lhs: + result = lhs + else: + elt = result_type.find()["elt"] + result, _ = self._allocate_new_array(elt, shape) func = self._get_array_elementwise_binop(name, make_op, result_type, lhs.type, rhs.type) self._invoke_arrayop(func, [result, lhs, rhs]) @@ -1753,7 +1763,8 @@ class ARTIQIRGenerator(algorithm.Visitor): name = type(node.op).__name__ def make_op(l, r): return self.append(ir.Arith(node.op, l, r)) - return self._broadcast_binop(name, make_op, node.type, lhs, rhs) + return self._broadcast_binop(name, make_op, node.type, lhs, rhs, + assign_to_lhs=False) elif builtins.is_numeric(node.type): lhs = self.visit(node.left) rhs = self.visit(node.right) @@ -2446,7 +2457,8 @@ class ARTIQIRGenerator(algorithm.Visitor): self._invoke_arrayop(func, [result, args[0]]) insn = result elif len(args) == 2: - insn = self._broadcast_binop(name, make_call, node.type, *args) + insn = self._broadcast_binop(name, make_call, node.type, *args, + assign_to_lhs=False) else: assert False, "Broadcasting for {} arguments not implemented".format(len) else: diff --git a/artiq/test/lit/integration/array_binops.py b/artiq/test/lit/integration/array_binops.py index e60052277..cfca32d92 100644 --- a/artiq/test/lit/integration/array_binops.py +++ b/artiq/test/lit/integration/array_binops.py @@ -8,31 +8,61 @@ assert c[0] == 5 assert c[1] == 7 assert c[2] == 9 +c += a +assert c[0] == 6 +assert c[1] == 9 +assert c[2] == 12 + c = b - a assert c[0] == 3 assert c[1] == 3 assert c[2] == 3 +c -= a +assert c[0] == 2 +assert c[1] == 1 +assert c[2] == 0 + c = a * b assert c[0] == 4 assert c[1] == 10 assert c[2] == 18 +c *= a +assert c[0] == 4 +assert c[1] == 20 +assert c[2] == 54 + c = b // a assert c[0] == 4 assert c[1] == 2 assert c[2] == 2 +c //= a +assert c[0] == 4 +assert c[1] == 1 +assert c[2] == 0 + c = a ** b assert c[0] == 1 assert c[1] == 32 assert c[2] == 729 +c **= a +assert c[0] == 1 +assert c[1] == 1024 +assert c[2] == 387420489 + c = b % a assert c[0] == 0 assert c[1] == 1 assert c[2] == 0 +c %= a +assert c[0] == 0 +assert c[1] == 1 +assert c[2] == 0 + cf = b / a assert cf[0] == 4.0 assert cf[1] == 2.5 @@ -43,6 +73,16 @@ assert cf2[0] == 5.0 assert cf2[1] == 4.5 assert cf2[2] == 5.0 +cf2 += a +assert cf2[0] == 6.0 +assert cf2[1] == 6.5 +assert cf2[2] == 8.0 + +cf /= a +assert cf[0] == 4.0 +assert cf[1] == 1.25 +assert cf[2] == 2.0 / 3.0 + d = array([[1, 2], [3, 4]]) e = array([[5, 6], [7, 8]]) f = d + e @@ -50,3 +90,9 @@ assert f[0][0] == 6 assert f[0][1] == 8 assert f[1][0] == 10 assert f[1][1] == 12 + +f += d +assert f[0][0] == 7 +assert f[0][1] == 10 +assert f[1][0] == 13 +assert f[1][1] == 16