forked from M-Labs/artiq
compiler: Implement array vs. scalar broadcasting
This commit is contained in:
parent
56a872ccc0
commit
faea886c44
|
@ -1513,17 +1513,30 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
# At this point, shapes are assumed to match; could just pass buffer
|
||||
# pointer for two of the three arrays as well.
|
||||
result_buffer = self.append(ir.GetAttr(result, "buffer"))
|
||||
lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
|
||||
rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
|
||||
shape = self.append(ir.GetAttr(result, "shape"))
|
||||
num_total_elts = self._get_total_array_len(shape)
|
||||
|
||||
if builtins.is_array(lhs.type):
|
||||
lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
|
||||
def get_left(index):
|
||||
return self.append(ir.GetElem(lhs_buffer, index))
|
||||
else:
|
||||
def get_left(index):
|
||||
return lhs
|
||||
|
||||
if builtins.is_array(rhs.type):
|
||||
rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
|
||||
def get_right(index):
|
||||
return self.append(ir.GetElem(rhs_buffer, index))
|
||||
else:
|
||||
def get_right(index):
|
||||
return rhs
|
||||
|
||||
def loop_gen(index):
|
||||
l = self.append(ir.GetElem(lhs_buffer, index))
|
||||
r = self.append(ir.GetElem(rhs_buffer, index))
|
||||
self.append(
|
||||
ir.SetElem(result_buffer, index, self.append(ir.Arith(op, l,
|
||||
r))))
|
||||
l = get_left(index)
|
||||
r = get_right(index)
|
||||
result = self.append(ir.Arith(op, l, r))
|
||||
self.append(ir.SetElem(result_buffer, index, result))
|
||||
return self.append(
|
||||
ir.Arith(ast.Add(loc=None), index,
|
||||
ir.Constant(1, self._size_type)))
|
||||
|
@ -1700,20 +1713,29 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.right)
|
||||
|
||||
shape = self.append(ir.GetAttr(lhs, "shape"))
|
||||
# TODO: Broadcasts; select the widest shape.
|
||||
rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
|
||||
self._make_check(
|
||||
self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)),
|
||||
lambda: self.alloc_exn(
|
||||
builtins.TException("ValueError"),
|
||||
ir.Constant("operands could not be broadcast together",
|
||||
builtins.TStr())))
|
||||
# Broadcast scalars.
|
||||
broadcast = False
|
||||
array_arg = lhs
|
||||
if not builtins.is_array(lhs.type):
|
||||
broadcast = True
|
||||
array_arg = rhs
|
||||
elif not builtins.is_array(rhs.type):
|
||||
broadcast = True
|
||||
|
||||
shape = self.append(ir.GetAttr(array_arg, "shape"))
|
||||
|
||||
if not broadcast:
|
||||
rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
|
||||
self._make_check(
|
||||
self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)),
|
||||
lambda: self.alloc_exn(
|
||||
builtins.TException("ValueError"),
|
||||
ir.Constant("operands could not be broadcast together",
|
||||
builtins.TStr())))
|
||||
|
||||
result = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||
|
||||
func = self._get_array_binop(node.op, node.type, node.left.type, node.right.type)
|
||||
func = self._get_array_binop(node.op, node.type, lhs.type, rhs.type)
|
||||
self._invoke_arrayop(func, [result, lhs, rhs])
|
||||
|
||||
return result
|
||||
elif builtins.is_numeric(node.type):
|
||||
lhs = self.visit(node.left)
|
||||
|
|
|
@ -480,10 +480,10 @@ class Inferencer(algorithm.Visitor):
|
|||
return typ.find()["num_dims"].value
|
||||
return 0
|
||||
|
||||
# TODO: Broadcasting.
|
||||
left_dims = num_dims(left.type)
|
||||
right_dims = num_dims(right.type)
|
||||
if left_dims != right_dims:
|
||||
if left_dims != right_dims and left_dims != 0 and right_dims != 0:
|
||||
# Mismatch (only scalar broadcast supported for now).
|
||||
note1 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
|
||||
{"num_dims": left_dims}, left.loc)
|
||||
note2 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
|
||||
|
@ -495,16 +495,19 @@ class Inferencer(algorithm.Visitor):
|
|||
return
|
||||
|
||||
def map_node_type(typ):
|
||||
if builtins.is_array(typ):
|
||||
return typ.find()["elt"]
|
||||
else:
|
||||
# This is (if later valid) a single value broadcast across the array.
|
||||
if not builtins.is_array(typ):
|
||||
# This is a single value broadcast across the array.
|
||||
return typ
|
||||
return typ.find()["elt"]
|
||||
|
||||
# Figure out result type, handling broadcasts.
|
||||
result_dims = left_dims if left_dims else right_dims
|
||||
def map_return(typ):
|
||||
elt = builtins.TFloat() if isinstance(op, ast.Div) else typ
|
||||
a = builtins.TArray(elt=elt, num_dims=left_dims)
|
||||
return (a, a, a)
|
||||
result = builtins.TArray(elt=elt, num_dims=result_dims)
|
||||
left = builtins.TArray(elt=elt, num_dims=left_dims) if left_dims else elt
|
||||
right = builtins.TArray(elt=elt, num_dims=right_dims) if right_dims else elt
|
||||
return (result, left, right)
|
||||
|
||||
return self._coerce_numeric((left, right),
|
||||
map_return=map_return,
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
# RUN: %python -m artiq.compiler.testbench.jit %s
|
||||
|
||||
a = array([1, 2, 3])
|
||||
|
||||
c = a + 1
|
||||
assert c[0] == 2
|
||||
assert c[1] == 3
|
||||
assert c[2] == 4
|
||||
|
||||
c = 1 - a
|
||||
assert c[0] == 0
|
||||
assert c[1] == -1
|
||||
assert c[2] == -2
|
||||
|
||||
c = a * 1
|
||||
assert c[0] == 1
|
||||
assert c[1] == 2
|
||||
assert c[2] == 3
|
||||
|
||||
c = a // 2
|
||||
assert c[0] == 0
|
||||
assert c[1] == 1
|
||||
assert c[2] == 1
|
||||
|
||||
c = a ** 2
|
||||
assert c[0] == 1
|
||||
assert c[1] == 4
|
||||
assert c[2] == 9
|
||||
|
||||
c = 2 ** a
|
||||
assert c[0] == 2
|
||||
assert c[1] == 4
|
||||
assert c[2] == 8
|
||||
|
||||
c = a % 2
|
||||
assert c[0] == 1
|
||||
assert c[1] == 0
|
||||
assert c[2] == 1
|
||||
|
||||
cf = a / 2
|
||||
assert cf[0] == 0.5
|
||||
assert cf[1] == 1.0
|
||||
assert cf[2] == 1.5
|
||||
|
||||
cf2 = 2 / array([1, 2, 4])
|
||||
assert cf2[0] == 2.0
|
||||
assert cf2[1] == 1.0
|
||||
assert cf2[2] == 0.5
|
||||
|
||||
d = array([[1, 2], [3, 4]])
|
||||
e = d + 1
|
||||
assert e[0][0] == 2
|
||||
assert e[0][1] == 3
|
||||
assert e[1][0] == 4
|
||||
assert e[1][1] == 5
|
Loading…
Reference in New Issue