forked from M-Labs/artiq
parent
5b5db1433b
commit
96692791cf
|
@ -390,6 +390,14 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
def visit_AugAssign(self, node):
|
||||
lhs = self.visit(node.target)
|
||||
rhs = self.visit(node.value)
|
||||
|
||||
if builtins.is_array(lhs.type):
|
||||
name = type(node.op).__name__
|
||||
def make_op(l, r):
|
||||
return self.append(ir.Arith(node.op, l, r))
|
||||
self._broadcast_binop(name, make_op, lhs.type, lhs, rhs, assign_to_lhs=True)
|
||||
return
|
||||
|
||||
value = self.append(ir.Arith(node.op, lhs, rhs))
|
||||
try:
|
||||
self.current_assign = value
|
||||
|
@ -1715,7 +1723,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
return self.append(ir.Alloc([result_buffer, shape], node.type))
|
||||
return self.append(ir.GetElem(result_buffer, ir.Constant(0, self._size_type)))
|
||||
|
||||
def _broadcast_binop(self, name, make_op, result_type, lhs, rhs):
|
||||
def _broadcast_binop(self, name, make_op, result_type, lhs, rhs, assign_to_lhs):
|
||||
# Broadcast scalars (broadcasting higher dimensions is not yet allowed in the
|
||||
# language).
|
||||
broadcast = False
|
||||
|
@ -1736,7 +1744,9 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
builtins.TException("ValueError"),
|
||||
ir.Constant("operands could not be broadcast together",
|
||||
builtins.TStr())))
|
||||
|
||||
if assign_to_lhs:
|
||||
result = lhs
|
||||
else:
|
||||
elt = result_type.find()["elt"]
|
||||
result, _ = self._allocate_new_array(elt, shape)
|
||||
func = self._get_array_elementwise_binop(name, make_op, result_type, lhs.type,
|
||||
|
@ -1753,7 +1763,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
name = type(node.op).__name__
|
||||
def make_op(l, r):
|
||||
return self.append(ir.Arith(node.op, l, r))
|
||||
return self._broadcast_binop(name, make_op, node.type, lhs, rhs)
|
||||
return self._broadcast_binop(name, make_op, node.type, lhs, rhs,
|
||||
assign_to_lhs=False)
|
||||
elif builtins.is_numeric(node.type):
|
||||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.right)
|
||||
|
@ -2446,7 +2457,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
self._invoke_arrayop(func, [result, args[0]])
|
||||
insn = result
|
||||
elif len(args) == 2:
|
||||
insn = self._broadcast_binop(name, make_call, node.type, *args)
|
||||
insn = self._broadcast_binop(name, make_call, node.type, *args,
|
||||
assign_to_lhs=False)
|
||||
else:
|
||||
assert False, "Broadcasting for {} arguments not implemented".format(len)
|
||||
else:
|
||||
|
|
|
@ -8,31 +8,61 @@ assert c[0] == 5
|
|||
assert c[1] == 7
|
||||
assert c[2] == 9
|
||||
|
||||
c += a
|
||||
assert c[0] == 6
|
||||
assert c[1] == 9
|
||||
assert c[2] == 12
|
||||
|
||||
c = b - a
|
||||
assert c[0] == 3
|
||||
assert c[1] == 3
|
||||
assert c[2] == 3
|
||||
|
||||
c -= a
|
||||
assert c[0] == 2
|
||||
assert c[1] == 1
|
||||
assert c[2] == 0
|
||||
|
||||
c = a * b
|
||||
assert c[0] == 4
|
||||
assert c[1] == 10
|
||||
assert c[2] == 18
|
||||
|
||||
c *= a
|
||||
assert c[0] == 4
|
||||
assert c[1] == 20
|
||||
assert c[2] == 54
|
||||
|
||||
c = b // a
|
||||
assert c[0] == 4
|
||||
assert c[1] == 2
|
||||
assert c[2] == 2
|
||||
|
||||
c //= a
|
||||
assert c[0] == 4
|
||||
assert c[1] == 1
|
||||
assert c[2] == 0
|
||||
|
||||
c = a ** b
|
||||
assert c[0] == 1
|
||||
assert c[1] == 32
|
||||
assert c[2] == 729
|
||||
|
||||
c **= a
|
||||
assert c[0] == 1
|
||||
assert c[1] == 1024
|
||||
assert c[2] == 387420489
|
||||
|
||||
c = b % a
|
||||
assert c[0] == 0
|
||||
assert c[1] == 1
|
||||
assert c[2] == 0
|
||||
|
||||
c %= a
|
||||
assert c[0] == 0
|
||||
assert c[1] == 1
|
||||
assert c[2] == 0
|
||||
|
||||
cf = b / a
|
||||
assert cf[0] == 4.0
|
||||
assert cf[1] == 2.5
|
||||
|
@ -43,6 +73,16 @@ assert cf2[0] == 5.0
|
|||
assert cf2[1] == 4.5
|
||||
assert cf2[2] == 5.0
|
||||
|
||||
cf2 += a
|
||||
assert cf2[0] == 6.0
|
||||
assert cf2[1] == 6.5
|
||||
assert cf2[2] == 8.0
|
||||
|
||||
cf /= a
|
||||
assert cf[0] == 4.0
|
||||
assert cf[1] == 1.25
|
||||
assert cf[2] == 2.0 / 3.0
|
||||
|
||||
d = array([[1, 2], [3, 4]])
|
||||
e = array([[5, 6], [7, 8]])
|
||||
f = d + e
|
||||
|
@ -50,3 +90,9 @@ assert f[0][0] == 6
|
|||
assert f[0][1] == 8
|
||||
assert f[1][0] == 10
|
||||
assert f[1][1] == 12
|
||||
|
||||
f += d
|
||||
assert f[0][0] == 7
|
||||
assert f[0][1] == 10
|
||||
assert f[1][0] == 13
|
||||
assert f[1][1] == 16
|
||||
|
|
Loading…
Reference in New Issue