forked from M-Labs/artiq
compiler: Insert array binop shape check in caller for location information
This commit is contained in:
parent
ef260adca8
commit
56a872ccc0
|
@ -1510,21 +1510,12 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
if name not in self.array_binop_funcs:
|
||||
|
||||
def body_gen(result, lhs, rhs):
|
||||
# TODO: Move into caller for correct location information (or pass)?
|
||||
shape = self.append(ir.GetAttr(lhs, "shape"))
|
||||
rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
|
||||
self._make_check(
|
||||
self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)),
|
||||
lambda: self.alloc_exn(
|
||||
builtins.TException("ValueError"),
|
||||
ir.Constant("operands could not be broadcast together",
|
||||
builtins.TStr())))
|
||||
# We assume result has correct shape; could just pass buffer pointer
|
||||
# as well.
|
||||
|
||||
# At this point, shapes are assumed to match; could just pass buffer
|
||||
# pointer for two of the three arrays as well.
|
||||
result_buffer = self.append(ir.GetAttr(result, "buffer"))
|
||||
lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
|
||||
rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
|
||||
shape = self.append(ir.GetAttr(result, "shape"))
|
||||
num_total_elts = self._get_total_array_len(shape)
|
||||
|
||||
def loop_gen(index):
|
||||
|
@ -1709,9 +1700,15 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.right)
|
||||
|
||||
# Array op implementation will check for matching shape.
|
||||
# TODO: Broadcasts; select the widest shape.
|
||||
shape = self.append(ir.GetAttr(lhs, "shape"))
|
||||
# TODO: Broadcasts; select the widest shape.
|
||||
rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
|
||||
self._make_check(
|
||||
self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)),
|
||||
lambda: self.alloc_exn(
|
||||
builtins.TException("ValueError"),
|
||||
ir.Constant("operands could not be broadcast together",
|
||||
builtins.TStr())))
|
||||
result = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||
|
||||
func = self._get_array_binop(node.op, node.type, node.left.type, node.right.type)
|
||||
|
|
Loading…
Reference in New Issue