compiler: Implement array vs. scalar broadcasting

This commit is contained in:
David Nadlinger 2020-08-02 21:39:36 +01:00
parent 56a872ccc0
commit faea886c44
3 changed files with 107 additions and 27 deletions

View File

@ -1513,17 +1513,30 @@ class ARTIQIRGenerator(algorithm.Visitor):
# At this point, shapes are assumed to match; could just pass buffer
# pointer for two of the three arrays as well.
result_buffer = self.append(ir.GetAttr(result, "buffer"))
lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
shape = self.append(ir.GetAttr(result, "shape"))
num_total_elts = self._get_total_array_len(shape)
if builtins.is_array(lhs.type):
lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
def get_left(index):
return self.append(ir.GetElem(lhs_buffer, index))
else:
def get_left(index):
return lhs
if builtins.is_array(rhs.type):
rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
def get_right(index):
return self.append(ir.GetElem(rhs_buffer, index))
else:
def get_right(index):
return rhs
def loop_gen(index):
l = self.append(ir.GetElem(lhs_buffer, index))
r = self.append(ir.GetElem(rhs_buffer, index))
self.append(
ir.SetElem(result_buffer, index, self.append(ir.Arith(op, l,
r))))
l = get_left(index)
r = get_right(index)
result = self.append(ir.Arith(op, l, r))
self.append(ir.SetElem(result_buffer, index, result))
return self.append(
ir.Arith(ast.Add(loc=None), index,
ir.Constant(1, self._size_type)))
@ -1700,8 +1713,18 @@ class ARTIQIRGenerator(algorithm.Visitor):
lhs = self.visit(node.left)
rhs = self.visit(node.right)
shape = self.append(ir.GetAttr(lhs, "shape"))
# TODO: Broadcasts; select the widest shape.
# Broadcast scalars.
broadcast = False
array_arg = lhs
if not builtins.is_array(lhs.type):
broadcast = True
array_arg = rhs
elif not builtins.is_array(rhs.type):
broadcast = True
shape = self.append(ir.GetAttr(array_arg, "shape"))
if not broadcast:
rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
self._make_check(
self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)),
@ -1709,11 +1732,10 @@ class ARTIQIRGenerator(algorithm.Visitor):
builtins.TException("ValueError"),
ir.Constant("operands could not be broadcast together",
builtins.TStr())))
result = self._allocate_new_array(node.type.find()["elt"], shape)
func = self._get_array_binop(node.op, node.type, node.left.type, node.right.type)
func = self._get_array_binop(node.op, node.type, lhs.type, rhs.type)
self._invoke_arrayop(func, [result, lhs, rhs])
return result
elif builtins.is_numeric(node.type):
lhs = self.visit(node.left)

View File

@ -480,10 +480,10 @@ class Inferencer(algorithm.Visitor):
return typ.find()["num_dims"].value
return 0
# TODO: Broadcasting.
left_dims = num_dims(left.type)
right_dims = num_dims(right.type)
if left_dims != right_dims:
if left_dims != right_dims and left_dims != 0 and right_dims != 0:
# Mismatch (only scalar broadcast supported for now).
note1 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
{"num_dims": left_dims}, left.loc)
note2 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
@ -495,16 +495,19 @@ class Inferencer(algorithm.Visitor):
return
def map_node_type(typ):
if builtins.is_array(typ):
return typ.find()["elt"]
else:
# This is (if later valid) a single value broadcast across the array.
if not builtins.is_array(typ):
# This is a single value broadcast across the array.
return typ
return typ.find()["elt"]
# Figure out result type, handling broadcasts.
result_dims = left_dims if left_dims else right_dims
def map_return(typ):
elt = builtins.TFloat() if isinstance(op, ast.Div) else typ
a = builtins.TArray(elt=elt, num_dims=left_dims)
return (a, a, a)
result = builtins.TArray(elt=elt, num_dims=result_dims)
left = builtins.TArray(elt=elt, num_dims=left_dims) if left_dims else elt
right = builtins.TArray(elt=elt, num_dims=right_dims) if right_dims else elt
return (result, left, right)
return self._coerce_numeric((left, right),
map_return=map_return,

View File

@ -0,0 +1,55 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
a = array([1, 2, 3])
c = a + 1
assert c[0] == 2
assert c[1] == 3
assert c[2] == 4
c = 1 - a
assert c[0] == 0
assert c[1] == -1
assert c[2] == -2
c = a * 1
assert c[0] == 1
assert c[1] == 2
assert c[2] == 3
c = a // 2
assert c[0] == 0
assert c[1] == 1
assert c[2] == 1
c = a ** 2
assert c[0] == 1
assert c[1] == 4
assert c[2] == 9
c = 2 ** a
assert c[0] == 2
assert c[1] == 4
assert c[2] == 8
c = a % 2
assert c[0] == 1
assert c[1] == 0
assert c[2] == 1
cf = a / 2
assert cf[0] == 0.5
assert cf[1] == 1.0
assert cf[2] == 1.5
cf2 = 2 / array([1, 2, 4])
assert cf2[0] == 2.0
assert cf2[1] == 1.0
assert cf2[2] == 0.5
d = array([[1, 2], [3, 4]])
e = d + 1
assert e[0][0] == 2
assert e[0][1] == 3
assert e[1][0] == 4
assert e[1][1] == 5