forked from M-Labs/artiq
compiler: Implement array vs. scalar broadcasting
This commit is contained in:
parent
56a872ccc0
commit
faea886c44
|
@ -1513,17 +1513,30 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
# At this point, shapes are assumed to match; could just pass buffer
|
# At this point, shapes are assumed to match; could just pass buffer
|
||||||
# pointer for two of the three arrays as well.
|
# pointer for two of the three arrays as well.
|
||||||
result_buffer = self.append(ir.GetAttr(result, "buffer"))
|
result_buffer = self.append(ir.GetAttr(result, "buffer"))
|
||||||
lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
|
|
||||||
rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
|
|
||||||
shape = self.append(ir.GetAttr(result, "shape"))
|
shape = self.append(ir.GetAttr(result, "shape"))
|
||||||
num_total_elts = self._get_total_array_len(shape)
|
num_total_elts = self._get_total_array_len(shape)
|
||||||
|
|
||||||
|
if builtins.is_array(lhs.type):
|
||||||
|
lhs_buffer = self.append(ir.GetAttr(lhs, "buffer"))
|
||||||
|
def get_left(index):
|
||||||
|
return self.append(ir.GetElem(lhs_buffer, index))
|
||||||
|
else:
|
||||||
|
def get_left(index):
|
||||||
|
return lhs
|
||||||
|
|
||||||
|
if builtins.is_array(rhs.type):
|
||||||
|
rhs_buffer = self.append(ir.GetAttr(rhs, "buffer"))
|
||||||
|
def get_right(index):
|
||||||
|
return self.append(ir.GetElem(rhs_buffer, index))
|
||||||
|
else:
|
||||||
|
def get_right(index):
|
||||||
|
return rhs
|
||||||
|
|
||||||
def loop_gen(index):
|
def loop_gen(index):
|
||||||
l = self.append(ir.GetElem(lhs_buffer, index))
|
l = get_left(index)
|
||||||
r = self.append(ir.GetElem(rhs_buffer, index))
|
r = get_right(index)
|
||||||
self.append(
|
result = self.append(ir.Arith(op, l, r))
|
||||||
ir.SetElem(result_buffer, index, self.append(ir.Arith(op, l,
|
self.append(ir.SetElem(result_buffer, index, result))
|
||||||
r))))
|
|
||||||
return self.append(
|
return self.append(
|
||||||
ir.Arith(ast.Add(loc=None), index,
|
ir.Arith(ast.Add(loc=None), index,
|
||||||
ir.Constant(1, self._size_type)))
|
ir.Constant(1, self._size_type)))
|
||||||
|
@ -1700,20 +1713,29 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
lhs = self.visit(node.left)
|
lhs = self.visit(node.left)
|
||||||
rhs = self.visit(node.right)
|
rhs = self.visit(node.right)
|
||||||
|
|
||||||
shape = self.append(ir.GetAttr(lhs, "shape"))
|
# Broadcast scalars.
|
||||||
# TODO: Broadcasts; select the widest shape.
|
broadcast = False
|
||||||
rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
|
array_arg = lhs
|
||||||
self._make_check(
|
if not builtins.is_array(lhs.type):
|
||||||
self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)),
|
broadcast = True
|
||||||
lambda: self.alloc_exn(
|
array_arg = rhs
|
||||||
builtins.TException("ValueError"),
|
elif not builtins.is_array(rhs.type):
|
||||||
ir.Constant("operands could not be broadcast together",
|
broadcast = True
|
||||||
builtins.TStr())))
|
|
||||||
|
shape = self.append(ir.GetAttr(array_arg, "shape"))
|
||||||
|
|
||||||
|
if not broadcast:
|
||||||
|
rhs_shape = self.append(ir.GetAttr(rhs, "shape"))
|
||||||
|
self._make_check(
|
||||||
|
self.append(ir.Compare(ast.Eq(loc=None), shape, rhs_shape)),
|
||||||
|
lambda: self.alloc_exn(
|
||||||
|
builtins.TException("ValueError"),
|
||||||
|
ir.Constant("operands could not be broadcast together",
|
||||||
|
builtins.TStr())))
|
||||||
|
|
||||||
result = self._allocate_new_array(node.type.find()["elt"], shape)
|
result = self._allocate_new_array(node.type.find()["elt"], shape)
|
||||||
|
func = self._get_array_binop(node.op, node.type, lhs.type, rhs.type)
|
||||||
func = self._get_array_binop(node.op, node.type, node.left.type, node.right.type)
|
|
||||||
self._invoke_arrayop(func, [result, lhs, rhs])
|
self._invoke_arrayop(func, [result, lhs, rhs])
|
||||||
|
|
||||||
return result
|
return result
|
||||||
elif builtins.is_numeric(node.type):
|
elif builtins.is_numeric(node.type):
|
||||||
lhs = self.visit(node.left)
|
lhs = self.visit(node.left)
|
||||||
|
|
|
@ -480,10 +480,10 @@ class Inferencer(algorithm.Visitor):
|
||||||
return typ.find()["num_dims"].value
|
return typ.find()["num_dims"].value
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# TODO: Broadcasting.
|
|
||||||
left_dims = num_dims(left.type)
|
left_dims = num_dims(left.type)
|
||||||
right_dims = num_dims(right.type)
|
right_dims = num_dims(right.type)
|
||||||
if left_dims != right_dims:
|
if left_dims != right_dims and left_dims != 0 and right_dims != 0:
|
||||||
|
# Mismatch (only scalar broadcast supported for now).
|
||||||
note1 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
|
note1 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
|
||||||
{"num_dims": left_dims}, left.loc)
|
{"num_dims": left_dims}, left.loc)
|
||||||
note2 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
|
note2 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
|
||||||
|
@ -495,16 +495,19 @@ class Inferencer(algorithm.Visitor):
|
||||||
return
|
return
|
||||||
|
|
||||||
def map_node_type(typ):
|
def map_node_type(typ):
|
||||||
if builtins.is_array(typ):
|
if not builtins.is_array(typ):
|
||||||
return typ.find()["elt"]
|
# This is a single value broadcast across the array.
|
||||||
else:
|
|
||||||
# This is (if later valid) a single value broadcast across the array.
|
|
||||||
return typ
|
return typ
|
||||||
|
return typ.find()["elt"]
|
||||||
|
|
||||||
|
# Figure out result type, handling broadcasts.
|
||||||
|
result_dims = left_dims if left_dims else right_dims
|
||||||
def map_return(typ):
|
def map_return(typ):
|
||||||
elt = builtins.TFloat() if isinstance(op, ast.Div) else typ
|
elt = builtins.TFloat() if isinstance(op, ast.Div) else typ
|
||||||
a = builtins.TArray(elt=elt, num_dims=left_dims)
|
result = builtins.TArray(elt=elt, num_dims=result_dims)
|
||||||
return (a, a, a)
|
left = builtins.TArray(elt=elt, num_dims=left_dims) if left_dims else elt
|
||||||
|
right = builtins.TArray(elt=elt, num_dims=right_dims) if right_dims else elt
|
||||||
|
return (result, left, right)
|
||||||
|
|
||||||
return self._coerce_numeric((left, right),
|
return self._coerce_numeric((left, right),
|
||||||
map_return=map_return,
|
map_return=map_return,
|
||||||
|
|
|
@ -0,0 +1,55 @@
|
||||||
|
# RUN: %python -m artiq.compiler.testbench.jit %s
|
||||||
|
|
||||||
|
a = array([1, 2, 3])
|
||||||
|
|
||||||
|
c = a + 1
|
||||||
|
assert c[0] == 2
|
||||||
|
assert c[1] == 3
|
||||||
|
assert c[2] == 4
|
||||||
|
|
||||||
|
c = 1 - a
|
||||||
|
assert c[0] == 0
|
||||||
|
assert c[1] == -1
|
||||||
|
assert c[2] == -2
|
||||||
|
|
||||||
|
c = a * 1
|
||||||
|
assert c[0] == 1
|
||||||
|
assert c[1] == 2
|
||||||
|
assert c[2] == 3
|
||||||
|
|
||||||
|
c = a // 2
|
||||||
|
assert c[0] == 0
|
||||||
|
assert c[1] == 1
|
||||||
|
assert c[2] == 1
|
||||||
|
|
||||||
|
c = a ** 2
|
||||||
|
assert c[0] == 1
|
||||||
|
assert c[1] == 4
|
||||||
|
assert c[2] == 9
|
||||||
|
|
||||||
|
c = 2 ** a
|
||||||
|
assert c[0] == 2
|
||||||
|
assert c[1] == 4
|
||||||
|
assert c[2] == 8
|
||||||
|
|
||||||
|
c = a % 2
|
||||||
|
assert c[0] == 1
|
||||||
|
assert c[1] == 0
|
||||||
|
assert c[2] == 1
|
||||||
|
|
||||||
|
cf = a / 2
|
||||||
|
assert cf[0] == 0.5
|
||||||
|
assert cf[1] == 1.0
|
||||||
|
assert cf[2] == 1.5
|
||||||
|
|
||||||
|
cf2 = 2 / array([1, 2, 4])
|
||||||
|
assert cf2[0] == 2.0
|
||||||
|
assert cf2[1] == 1.0
|
||||||
|
assert cf2[2] == 0.5
|
||||||
|
|
||||||
|
d = array([[1, 2], [3, 4]])
|
||||||
|
e = d + 1
|
||||||
|
assert e[0][0] == 2
|
||||||
|
assert e[0][1] == 3
|
||||||
|
assert e[1][0] == 4
|
||||||
|
assert e[1][1] == 5
|
Loading…
Reference in New Issue