forked from M-Labs/artiq
1
0
Fork 0

compiler: Properly implement NumPy array slicing

Strided slicing of one-dimensional arrays (i.e. with non-trivial
steps) might have previously been working, but would have had
different semantics, as all slices were copies rather than a view
into the original data.

Fixing this in the future will require adding support for an index
stride field/tuple to our array representation (and all the
associated indexing logic).

GitHub: Fixes #1627.
This commit is contained in:
David Nadlinger 2021-03-14 19:57:01 +00:00
parent 557671b7db
commit c707ccf7d7
5 changed files with 162 additions and 74 deletions

View File

@ -1116,7 +1116,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
_readable_name(index))))
if self.current_assign is None:
return indexed
else: # Slice
else:
# This is a slice. The endpoint checking logic is the same for both lists
# and NumPy arrays, but the actual implementations differ while slices of
# built-in lists are always copies in Python, they are views sharing the
# same backing storage in NumPy.
length = self.iterable_len(value, node.slice.type)
if node.slice.lower is not None:
@ -1141,6 +1145,42 @@ class ARTIQIRGenerator(algorithm.Visitor):
mapped_stop_index = self._map_index(length, stop_index, one_past_the_end=True,
loc=node.begin_loc)
if builtins.is_array(node.type):
# To implement strided slicing with the proper NumPy reference
# semantics, the pointer/length array representation will need to be
# extended by another field to hold a variable stride.
assert node.slice.step is None, (
"array slices with non-trivial step "
"should have been disallowed during type inference")
# One-dimensionally slicing an array only affects the outermost
# dimension.
shape = self.append(ir.GetAttr(value, "shape"))
lengths = [
self.append(ir.GetAttr(shape, i))
for i in range(len(shape.type.elts))
]
# Compute outermost length zero for "backwards" indices.
raw_len = self.append(
ir.Arith(ast.Sub(loc=None), mapped_stop_index, mapped_start_index))
is_neg_len = self.append(
ir.Compare(ast.Lt(loc=None), raw_len, ir.Constant(0, raw_len.type)))
outer_len = self.append(
ir.Select(is_neg_len, ir.Constant(0, raw_len.type), raw_len))
new_shape = self._make_array_shape([outer_len] + lengths[1:])
# Offset buffer pointer by start index (times stride for inner dims).
stride = reduce(
lambda l, r: self.append(ir.Arith(ast.Mult(loc=None), l, r)),
lengths[1:], ir.Constant(1, lengths[0].type))
offset = self.append(
ir.Arith(ast.Mult(loc=None), stride, mapped_start_index))
buffer = self.append(ir.GetAttr(value, "buffer"))
new_buffer = self.append(ir.Offset(buffer, offset))
return self.append(ir.Alloc([new_buffer, new_shape], node.type))
else:
if node.slice.step is not None:
try:
old_assign, self.current_assign = self.current_assign, None

View File

@ -269,6 +269,14 @@ class Inferencer(algorithm.Visitor):
else:
self._unify_iterable(element=node, collection=node.value)
elif isinstance(node.slice, ast.Slice):
if builtins.is_array(node.value.type):
if node.slice.step is not None:
diag = diagnostic.Diagnostic(
"error",
"strided slicing not yet supported for NumPy arrays", {},
node.slice.step.loc, [])
self.engine.process(diag)
return
self._unify(node.type, node.value.type, node.loc, node.value.loc)
else: # ExtSlice
pass # error emitted above

View File

@ -9,5 +9,8 @@ b = array([1, 2, 3])
# CHECK-L: ${LINE:+1}: error: too many indices for array of dimension 1
b[1, 2]
# CHECK-L: ${LINE:+1}: error: strided slicing not yet supported for NumPy arrays
b[::-1]
# CHECK-L: ${LINE:+1}: error: array attributes cannot be assigned to
b.shape = (5, )

View File

@ -0,0 +1,24 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
a = array([0, 1, 2, 3])
b = a[2:3]
assert b.shape == (1,)
assert b[0] == 2
b[0] = 5
assert a[2] == 5
b = a[3:2]
assert b.shape == (0,)
c = array([[0, 1], [2, 3]])
d = c[:1]
assert d.shape == (1, 2)
assert d[0, 0] == 0
assert d[0, 1] == 1
d[0, 0] = 5
assert c[0, 0] == 5
d = c[1:0]
assert d.shape == (0, 2)

View File

@ -0,0 +1,13 @@
# RUN: %python -m artiq.compiler.testbench.embedding %s
from artiq.language.core import *
from artiq.language.types import *
import numpy as np
n = 2
data = np.zeros((n, n))
@kernel
def entrypoint():
print(data[:n])