From faea886c44201c31d8ba27a9ec2a83f6b17caf16 Mon Sep 17 00:00:00 2001
From: David Nadlinger <code@klickverbot.at>
Date: Sun, 2 Aug 2020 21:39:36 +0100
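Subject: [PATCH] compiler: Implement array vs. scalar broadcasting

Arithmetic between an array and a scalar now applies the scalar to every
element of the array. The inferencer only reports a dimension mismatch
when both operands are arrays, and the IR generator takes the result
shape from the array operand, emitting the runtime shape check only when
both operands are arrays. Broadcasting between arrays of differing
dimensions is still rejected.
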
---
 .../compiler/transforms/artiq_ir_generator.py | 60 +++++++++++++------
 artiq/compiler/transforms/inferencer.py       | 19 +++---
 artiq/test/lit/integration/array_broadcast.py | 55 +++++++++++++++++
 3 files changed, 107 insertions(+), 27 deletions(-)
 create mode 100644 artiq/test/lit/integration/array_broadcast.py

diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py
index a77969fc9..e148ce184 100644
--- a/artiq/compiler/transforms/artiq_ir_generator.py
+++ b/artiq/compiler/transforms/artiq_ir_generator.py
@@ -1513,17 +1513,30 @@ class ARTIQIRGenerator(algorithm.Visitor):
                 # At this point, shapes are assumed to match; could just pass buffer
                 # pointer for two of the three arrays as well.
                 result_buffer = self.append(ir.GetAttr(result, "buffer"))
-                lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
-                rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
                 shape = self.append(ir.GetAttr(result, "shape"))
                 num_total_elts = self._get_total_array_len(shape)
 
+                if builtins.is_array(lhs.type):
+                    lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
+                    def get_left(index):
+                        return self.append(ir.GetElem(lhs_buffer, index))
+                else:
+                    def get_left(index):
+                        return lhs
+
+                if builtins.is_array(rhs.type):
+                    rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
+                    def get_right(index):
+                        return self.append(ir.GetElem(rhs_buffer, index))
+                else:
+                    def get_right(index):
+                        return rhs
+
                 def loop_gen(index):
-                    l = self.append(ir.GetElem(lhs_buffer, index))
-                    r = self.append(ir.GetElem(rhs_buffer, index))
-                    self.append(
-                        ir.SetElem(result_buffer, index, self.append(ir.Arith(op, l,
-                                                                              r))))
+                    l = get_left(index)
+                    r = get_right(index)
+                    result = self.append(ir.Arith(op, l, r))
+                    self.append(ir.SetElem(result_buffer, index, result))
                     return self.append(
                         ir.Arith(ast.Add(loc=None), index,
                                  ir.Constant(1, self._size_type)))
@@ -1700,20 +1713,29 @@ class ARTIQIRGenerator(algorithm.Visitor):
             lhs = self.visit(node.left)
             rhs = self.visit(node.right)
 
-            shape = self.append(ir.GetAttr(lhs, "shape"))
-            # TODO: Broadcasts; select the widest shape.
-            rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
-            self._make_check(
-                self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)),
-                lambda: self.alloc_exn(
-                    builtins.TException("ValueError"),
-                    ir.Constant("operands could not be broadcast together",
-                                builtins.TStr())))
+            # Broadcast scalars.
+            broadcast = False
+            array_arg = lhs
+            if not builtins.is_array(lhs.type):
+                broadcast = True
+                array_arg = rhs
+            elif not builtins.is_array(rhs.type):
+                broadcast = True
+
+            shape = self.append(ir.GetAttr(array_arg, "shape"))
+
+            if not broadcast:
+                rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
+                self._make_check(
+                    self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)),
+                    lambda: self.alloc_exn(
+                        builtins.TException("ValueError"),
+                        ir.Constant("operands could not be broadcast together",
+                                    builtins.TStr())))
+
             result = self._allocate_new_array(node.type.find()["elt"], shape)
-
-            func = self._get_array_binop(node.op, node.type, node.left.type, node.right.type)
+            func = self._get_array_binop(node.op, node.type, lhs.type, rhs.type)
             self._invoke_arrayop(func, [result, lhs, rhs])
-
             return result
         elif builtins.is_numeric(node.type):
             lhs = self.visit(node.left)
diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py
index 7649682c2..445d64bce 100644
--- a/artiq/compiler/transforms/inferencer.py
+++ b/artiq/compiler/transforms/inferencer.py
@@ -480,10 +480,10 @@ class Inferencer(algorithm.Visitor):
                     return typ.find()["num_dims"].value
                 return 0
 
-            # TODO: Broadcasting.
             left_dims = num_dims(left.type)
             right_dims = num_dims(right.type)
-            if left_dims != right_dims:
+            if left_dims != right_dims and left_dims != 0 and right_dims != 0:
+                # Mismatch (only scalar broadcast supported for now).
                 note1 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
                                               {"num_dims": left_dims}, left.loc)
                 note2 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
@@ -495,16 +495,19 @@ class Inferencer(algorithm.Visitor):
                 return
 
             def map_node_type(typ):
-                if builtins.is_array(typ):
-                    return typ.find()["elt"]
-                else:
-                    # This is (if later valid) a single value broadcast across the array.
+                if not builtins.is_array(typ):
+                    # This is a single value broadcast across the array.
                     return typ
+                return typ.find()["elt"]
 
+            # Figure out result type, handling broadcasts.
+            result_dims = left_dims if left_dims else right_dims
             def map_return(typ):
                 elt = builtins.TFloat() if isinstance(op, ast.Div) else typ
-                a = builtins.TArray(elt=elt, num_dims=left_dims)
-                return (a, a, a)
+                result = builtins.TArray(elt=elt, num_dims=result_dims)
+                left = builtins.TArray(elt=elt, num_dims=left_dims) if left_dims else elt
+                right = builtins.TArray(elt=elt, num_dims=right_dims) if right_dims else elt
+                return (result, left, right)
 
             return self._coerce_numeric((left, right),
                                         map_return=map_return,
diff --git a/artiq/test/lit/integration/array_broadcast.py b/artiq/test/lit/integration/array_broadcast.py
new file mode 100644
index 000000000..d7cbc5998
--- /dev/null
+++ b/artiq/test/lit/integration/array_broadcast.py
@@ -0,0 +1,55 @@
+# RUN: %python -m artiq.compiler.testbench.jit %s
+
+a = array([1, 2, 3])  # 1-D array used as the array operand in most cases below
+
+c = a + 1  # scalar on the right-hand side
+assert c[0] == 2
+assert c[1] == 3
+assert c[2] == 4
+
+c = 1 - a  # scalar on the left-hand side of a non-commutative operator
+assert c[0] == 0
+assert c[1] == -1
+assert c[2] == -2
+
+c = a * 1  # multiplication by a scalar
+assert c[0] == 1
+assert c[1] == 2
+assert c[2] == 3
+
+c = a // 2  # floor division by a scalar
+assert c[0] == 0
+assert c[1] == 1
+assert c[2] == 1
+
+c = a ** 2  # each array element raised to a scalar power
+assert c[0] == 1
+assert c[1] == 4
+assert c[2] == 9
+
+c = 2 ** a  # scalar base raised to each array element
+assert c[0] == 2
+assert c[1] == 4
+assert c[2] == 8
+
+c = a % 2  # modulo by a scalar
+assert c[0] == 1
+assert c[1] == 0
+assert c[2] == 1
+
+cf = a / 2  # true division by a scalar; results are checked against float values
+assert cf[0] == 0.5
+assert cf[1] == 1.0
+assert cf[2] == 1.5
+
+cf2 = 2 / array([1, 2, 4])  # scalar dividend, array divisor
+assert cf2[0] == 2.0
+assert cf2[1] == 1.0
+assert cf2[2] == 0.5
+
+d = array([[1, 2], [3, 4]])  # 2-D array operand
+e = d + 1  # scalar applied to every element of the 2-D array
+assert e[0][0] == 2
+assert e[0][1] == 3
+assert e[1][0] == 4
+assert e[1][1] == 5