From faea886c44201c31d8ba27a9ec2a83f6b17caf16 Mon Sep 17 00:00:00 2001 From: David Nadlinger Date: Sun, 2 Aug 2020 21:39:36 +0100 Subject: [PATCH] compiler: Implement array vs. scalar broadcasting --- .../compiler/transforms/artiq_ir_generator.py | 60 +++++++++++++------ artiq/compiler/transforms/inferencer.py | 19 +++--- artiq/test/lit/integration/array_broadcast.py | 55 +++++++++++++++++ 3 files changed, 107 insertions(+), 27 deletions(-) create mode 100644 artiq/test/lit/integration/array_broadcast.py diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index a77969fc9..e148ce184 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -1513,17 +1513,30 @@ class ARTIQIRGenerator(algorithm.Visitor): # At this point, shapes are assumed to match; could just pass buffer # pointer for two of the three arrays as well. result_buffer = self.append(ir.GetAttr(result, "buffer")) - lhs_buffer = self.append(ir.GetAttr(lhs, "buffer")) - rhs_buffer = self.append(ir.GetAttr(rhs, "buffer")) shape = self.append(ir.GetAttr(result, "shape")) num_total_elts = self._get_total_array_len(shape) + if builtins.is_array(lhs.type): + lhs_buffer = self.append(ir.GetAttr(lhs, "buffer")) + def get_left(index): + return self.append(ir.GetElem(lhs_buffer, index)) + else: + def get_left(index): + return lhs + + if builtins.is_array(rhs.type): + rhs_buffer = self.append(ir.GetAttr(rhs, "buffer")) + def get_right(index): + return self.append(ir.GetElem(rhs_buffer, index)) + else: + def get_right(index): + return rhs + def loop_gen(index): - l = self.append(ir.GetElem(lhs_buffer, index)) - r = self.append(ir.GetElem(rhs_buffer, index)) - self.append( - ir.SetElem(result_buffer, index, self.append(ir.Arith(op, l, - r)))) + l = get_left(index) + r = get_right(index) + result = self.append(ir.Arith(op, l, r)) + self.append(ir.SetElem(result_buffer, index, result)) return self.append( ir.Arith(ast.Add(loc=None), index, ir.Constant(1, self._size_type))) @@ -1700,20 +1713,29 @@ class ARTIQIRGenerator(algorithm.Visitor): lhs = self.visit(node.left) rhs = self.visit(node.right) - shape = self.append(ir.GetAttr(lhs, "shape")) - # TODO: Broadcasts; select the widest shape. - rhs_shape = self.append(ir.GetAttr(rhs, "shape")) - self._make_check( - self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)), - lambda: self.alloc_exn( - builtins.TException("ValueError"), - ir.Constant("operands could not be broadcast together", - builtins.TStr()))) + # Broadcast scalars. + broadcast = False + array_arg = lhs + if not builtins.is_array(lhs.type): + broadcast = True + array_arg = rhs + elif not builtins.is_array(rhs.type): + broadcast = True + + shape = self.append(ir.GetAttr(array_arg, "shape")) + + if not broadcast: + rhs_shape = self.append(ir.GetAttr(rhs, "shape")) + self._make_check( + self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)), + lambda: self.alloc_exn( + builtins.TException("ValueError"), + ir.Constant("operands could not be broadcast together", + builtins.TStr()))) + result = self._allocate_new_array(node.type.find()["elt"], shape) - - func = self._get_array_binop(node.op, node.type, node.left.type, node.right.type) + func = self._get_array_binop(node.op, node.type, lhs.type, rhs.type) self._invoke_arrayop(func, [result, lhs, rhs]) - return result elif builtins.is_numeric(node.type): lhs = self.visit(node.left) diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 7649682c2..445d64bce 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -480,10 +480,10 @@ class Inferencer(algorithm.Visitor): return typ.find()["num_dims"].value return 0 - # TODO: Broadcasting. left_dims = num_dims(left.type) right_dims = num_dims(right.type) - if left_dims != right_dims: + if left_dims != right_dims and left_dims != 0 and right_dims != 0: + # Mismatch (only scalar broadcast supported for now). note1 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}", {"num_dims": left_dims}, left.loc) note2 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}", @@ -495,16 +495,19 @@ class Inferencer(algorithm.Visitor): return def map_node_type(typ): - if builtins.is_array(typ): - return typ.find()["elt"] - else: - # This is (if later valid) a single value broadcast across the array. + if not builtins.is_array(typ): + # This is a single value broadcast across the array. return typ + return typ.find()["elt"] + # Figure out result type, handling broadcasts. + result_dims = left_dims if left_dims else right_dims def map_return(typ): elt = builtins.TFloat() if isinstance(op, ast.Div) else typ - a = builtins.TArray(elt=elt, num_dims=left_dims) - return (a, a, a) + result = builtins.TArray(elt=elt, num_dims=result_dims) + left = builtins.TArray(elt=elt, num_dims=left_dims) if left_dims else elt + right = builtins.TArray(elt=elt, num_dims=right_dims) if right_dims else elt + return (result, left, right) return self._coerce_numeric((left, right), map_return=map_return, diff --git a/artiq/test/lit/integration/array_broadcast.py b/artiq/test/lit/integration/array_broadcast.py new file mode 100644 index 000000000..d7cbc5998 --- /dev/null +++ b/artiq/test/lit/integration/array_broadcast.py @@ -0,0 +1,55 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s + +a = array([1, 2, 3]) + +c = a + 1 +assert c[0] == 2 +assert c[1] == 3 +assert c[2] == 4 + +c = 1 - a +assert c[0] == 0 +assert c[1] == -1 +assert c[2] == -2 + +c = a * 1 +assert c[0] == 1 +assert c[1] == 2 +assert c[2] == 3 + +c = a // 2 +assert c[0] == 0 +assert c[1] == 1 +assert c[2] == 1 + +c = a ** 2 +assert c[0] == 1 +assert c[1] == 4 +assert c[2] == 9 + +c = 2 ** a +assert c[0] == 2 +assert c[1] == 4 +assert c[2] == 8 + +c = a % 2 +assert c[0] == 1 +assert c[1] == 0 +assert c[2] == 1 + +cf = a / 2 +assert cf[0] == 0.5 +assert cf[1] == 1.0 +assert cf[2] == 1.5 + +cf2 = 2 / array([1, 2, 4]) +assert cf2[0] == 2.0 +assert cf2[1] == 1.0 +assert cf2[2] == 0.5 + +d = array([[1, 2], [3, 4]]) +e = d + 1 +assert e[0][0] == 2 +assert e[0][1] == 3 +assert e[1][0] == 4 +assert e[1][1] == 5