diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index fa5912999..a77969fc9 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -1510,21 +1510,12 @@ class ARTIQIRGenerator(algorithm.Visitor): if name not in self.array_binop_funcs: def body_gen(result, lhs, rhs): - # TODO: Move into caller for correct location information (or pass)? - shape = self.append(ir.GetAttr(lhs, "shape")) - rhs_shape = self.append(ir.GetAttr(rhs, "shape")) - self._make_check( - self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)), - lambda: self.alloc_exn( - builtins.TException("ValueError"), - ir.Constant("operands could not be broadcast together", - builtins.TStr()))) - # We assume result has correct shape; could just pass buffer pointer - # as well. - + # At this point, shapes are assumed to match; could just pass buffer + # pointer for two of the three arrays as well. result_buffer = self.append(ir.GetAttr(result, "buffer")) lhs_buffer = self.append(ir.GetAttr(lhs, "buffer")) rhs_buffer = self.append(ir.GetAttr(rhs, "buffer")) + shape = self.append(ir.GetAttr(result, "shape")) num_total_elts = self._get_total_array_len(shape) def loop_gen(index): @@ -1709,9 +1700,15 @@ class ARTIQIRGenerator(algorithm.Visitor): lhs = self.visit(node.left) rhs = self.visit(node.right) - # Array op implementation will check for matching shape. - # TODO: Broadcasts; select the widest shape. shape = self.append(ir.GetAttr(lhs, "shape")) + # TODO: Broadcasts; select the widest shape. + rhs_shape = self.append(ir.GetAttr(rhs, "shape")) + self._make_check( + self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)), + lambda: self.alloc_exn( + builtins.TException("ValueError"), + ir.Constant("operands could not be broadcast together", + builtins.TStr()))) result = self._allocate_new_array(node.type.find()["elt"], shape) func = self._get_array_binop(node.op, node.type, node.left.type, node.right.type)