forked from M-Labs/artiq
compiler: Properly implement NumPy array slicing
Strided slicing of one-dimensional arrays (i.e. with non-trivial steps) might have previously been working, but would have had different semantics, as all slices were copies rather than a view into the original data. Fixing this in the future will require adding support for an index stride field/tuple to our array representation (and all the associated indexing logic). GitHub: Fixes #1627.
This commit is contained in:
parent
8a892af244
commit
925014689e
|
@ -1116,7 +1116,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
_readable_name(index))))
|
||||
if self.current_assign is None:
|
||||
return indexed
|
||||
else: # Slice
|
||||
else:
|
||||
# This is a slice. The endpoint checking logic is the same for both lists
|
||||
# and NumPy arrays, but the actual implementations differ – while slices of
|
||||
# built-in lists are always copies in Python, they are views sharing the
|
||||
# same backing storage in NumPy.
|
||||
length = self.iterable_len(value, node.slice.type)
|
||||
|
||||
if node.slice.lower is not None:
|
||||
|
@ -1141,6 +1145,42 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
mapped_stop_index = self._map_index(length, stop_index, one_past_the_end=True,
|
||||
loc=node.begin_loc)
|
||||
|
||||
if builtins.is_array(node.type):
|
||||
# To implement strided slicing with the proper NumPy reference
|
||||
# semantics, the pointer/length array representation will need to be
|
||||
# extended by another field to hold a variable stride.
|
||||
assert node.slice.step is None, (
|
||||
"array slices with non-trivial step "
|
||||
"should have been disallowed during type inference")
|
||||
|
||||
# One-dimensionally slicing an array only affects the outermost
|
||||
# dimension.
|
||||
shape = self.append(ir.GetAttr(value, "shape"))
|
||||
lengths = [
|
||||
self.append(ir.GetAttr(shape, i))
|
||||
for i in range(len(shape.type.elts))
|
||||
]
|
||||
|
||||
# Compute outermost length – zero for "backwards" indices.
|
||||
raw_len = self.append(
|
||||
ir.Arith(ast.Sub(loc=None), mapped_stop_index, mapped_start_index))
|
||||
is_neg_len = self.append(
|
||||
ir.Compare(ast.Lt(loc=None), raw_len, ir.Constant(0, raw_len.type)))
|
||||
outer_len = self.append(
|
||||
ir.Select(is_neg_len, ir.Constant(0, raw_len.type), raw_len))
|
||||
new_shape = self._make_array_shape([outer_len] + lengths[1:])
|
||||
|
||||
# Offset buffer pointer by start index (times stride for inner dims).
|
||||
stride = reduce(
|
||||
lambda l, r: self.append(ir.Arith(ast.Mult(loc=None), l, r)),
|
||||
lengths[1:], ir.Constant(1, lengths[0].type))
|
||||
offset = self.append(
|
||||
ir.Arith(ast.Mult(loc=None), stride, mapped_start_index))
|
||||
buffer = self.append(ir.GetAttr(value, "buffer"))
|
||||
new_buffer = self.append(ir.Offset(buffer, offset))
|
||||
|
||||
return self.append(ir.Alloc([new_buffer, new_shape], node.type))
|
||||
else:
|
||||
if node.slice.step is not None:
|
||||
try:
|
||||
old_assign, self.current_assign = self.current_assign, None
|
||||
|
|
|
@ -269,6 +269,14 @@ class Inferencer(algorithm.Visitor):
|
|||
else:
|
||||
self._unify_iterable(element=node, collection=node.value)
|
||||
elif isinstance(node.slice, ast.Slice):
|
||||
if builtins.is_array(node.value.type):
|
||||
if node.slice.step is not None:
|
||||
diag = diagnostic.Diagnostic(
|
||||
"error",
|
||||
"strided slicing not yet supported for NumPy arrays", {},
|
||||
node.slice.step.loc, [])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
self._unify(node.type, node.value.type, node.loc, node.value.loc)
|
||||
else: # ExtSlice
|
||||
pass # error emitted above
|
||||
|
|
|
@ -9,5 +9,8 @@ b = array([1, 2, 3])
|
|||
# CHECK-L: ${LINE:+1}: error: too many indices for array of dimension 1
|
||||
b[1, 2]
|
||||
|
||||
# CHECK-L: ${LINE:+1}: error: strided slicing not yet supported for NumPy arrays
|
||||
b[::-1]
|
||||
|
||||
# CHECK-L: ${LINE:+1}: error: array attributes cannot be assigned to
|
||||
b.shape = (5, )
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
# RUN: %python -m artiq.compiler.testbench.jit %s
|
||||
|
||||
a = array([0, 1, 2, 3])
|
||||
|
||||
b = a[2:3]
|
||||
assert b.shape == (1,)
|
||||
assert b[0] == 2
|
||||
b[0] = 5
|
||||
assert a[2] == 5
|
||||
|
||||
b = a[3:2]
|
||||
assert b.shape == (0,)
|
||||
|
||||
c = array([[0, 1], [2, 3]])
|
||||
|
||||
d = c[:1]
|
||||
assert d.shape == (1, 2)
|
||||
assert d[0, 0] == 0
|
||||
assert d[0, 1] == 1
|
||||
d[0, 0] = 5
|
||||
assert c[0, 0] == 5
|
||||
|
||||
d = c[1:0]
|
||||
assert d.shape == (0, 2)
|
|
@ -0,0 +1,13 @@
|
|||
# RUN: %python -m artiq.compiler.testbench.embedding %s
|
||||
|
||||
from artiq.language.core import *
|
||||
from artiq.language.types import *
|
||||
import numpy as np
|
||||
|
||||
n = 2
|
||||
data = np.zeros((n, n))
|
||||
|
||||
|
||||
@kernel
|
||||
def entrypoint():
|
||||
print(data[:n])
|
Loading…
Reference in New Issue