forked from M-Labs/artiq
compiler: Properly implement NumPy array slicing
Strided slicing of one-dimensional arrays (i.e. with non-trivial steps) might have previously been working, but would have had different semantics, as all slices were copies rather than a view into the original data. Fixing this in the future will require adding support for an index stride field/tuple to our array representation (and all the associated indexing logic). GitHub: Fixes #1627.
This commit is contained in:
parent
557671b7db
commit
c707ccf7d7
|
@ -1116,7 +1116,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
_readable_name(index))))
|
||||
if self.current_assign is None:
|
||||
return indexed
|
||||
else: # Slice
|
||||
else:
|
||||
# This is a slice. The endpoint checking logic is the same for both lists
|
||||
# and NumPy arrays, but the actual implementations differ – while slices of
|
||||
# built-in lists are always copies in Python, they are views sharing the
|
||||
# same backing storage in NumPy.
|
||||
length = self.iterable_len(value, node.slice.type)
|
||||
|
||||
if node.slice.lower is not None:
|
||||
|
@ -1141,91 +1145,127 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
mapped_stop_index = self._map_index(length, stop_index, one_past_the_end=True,
|
||||
loc=node.begin_loc)
|
||||
|
||||
if node.slice.step is not None:
|
||||
try:
|
||||
old_assign, self.current_assign = self.current_assign, None
|
||||
step = self.visit(node.slice.step)
|
||||
finally:
|
||||
self.current_assign = old_assign
|
||||
if builtins.is_array(node.type):
|
||||
# To implement strided slicing with the proper NumPy reference
|
||||
# semantics, the pointer/length array representation will need to be
|
||||
# extended by another field to hold a variable stride.
|
||||
assert node.slice.step is None, (
|
||||
"array slices with non-trivial step "
|
||||
"should have been disallowed during type inference")
|
||||
|
||||
# One-dimensionally slicing an array only affects the outermost
|
||||
# dimension.
|
||||
shape = self.append(ir.GetAttr(value, "shape"))
|
||||
lengths = [
|
||||
self.append(ir.GetAttr(shape, i))
|
||||
for i in range(len(shape.type.elts))
|
||||
]
|
||||
|
||||
# Compute outermost length – zero for "backwards" indices.
|
||||
raw_len = self.append(
|
||||
ir.Arith(ast.Sub(loc=None), mapped_stop_index, mapped_start_index))
|
||||
is_neg_len = self.append(
|
||||
ir.Compare(ast.Lt(loc=None), raw_len, ir.Constant(0, raw_len.type)))
|
||||
outer_len = self.append(
|
||||
ir.Select(is_neg_len, ir.Constant(0, raw_len.type), raw_len))
|
||||
new_shape = self._make_array_shape([outer_len] + lengths[1:])
|
||||
|
||||
# Offset buffer pointer by start index (times stride for inner dims).
|
||||
stride = reduce(
|
||||
lambda l, r: self.append(ir.Arith(ast.Mult(loc=None), l, r)),
|
||||
lengths[1:], ir.Constant(1, lengths[0].type))
|
||||
offset = self.append(
|
||||
ir.Arith(ast.Mult(loc=None), stride, mapped_start_index))
|
||||
buffer = self.append(ir.GetAttr(value, "buffer"))
|
||||
new_buffer = self.append(ir.Offset(buffer, offset))
|
||||
|
||||
return self.append(ir.Alloc([new_buffer, new_shape], node.type))
|
||||
else:
|
||||
if node.slice.step is not None:
|
||||
try:
|
||||
old_assign, self.current_assign = self.current_assign, None
|
||||
step = self.visit(node.slice.step)
|
||||
finally:
|
||||
self.current_assign = old_assign
|
||||
|
||||
self._make_check(
|
||||
self.append(ir.Compare(ast.NotEq(loc=None), step, ir.Constant(0, step.type))),
|
||||
lambda: self.alloc_exn(builtins.TException("ValueError"),
|
||||
ir.Constant("step cannot be zero", builtins.TStr())),
|
||||
loc=node.slice.step.loc)
|
||||
else:
|
||||
step = ir.Constant(1, node.slice.type)
|
||||
counting_up = self.append(ir.Compare(ast.Gt(loc=None), step,
|
||||
ir.Constant(0, step.type)))
|
||||
|
||||
unstepped_size = self.append(ir.Arith(ast.Sub(loc=None),
|
||||
mapped_stop_index, mapped_start_index))
|
||||
slice_size_a = self.append(ir.Arith(ast.FloorDiv(loc=None), unstepped_size, step))
|
||||
slice_size_b = self.append(ir.Arith(ast.Mod(loc=None), unstepped_size, step))
|
||||
rem_not_empty = self.append(ir.Compare(ast.NotEq(loc=None), slice_size_b,
|
||||
ir.Constant(0, slice_size_b.type)))
|
||||
slice_size_c = self.append(ir.Arith(ast.Add(loc=None), slice_size_a,
|
||||
ir.Constant(1, slice_size_a.type)))
|
||||
slice_size = self.append(ir.Select(rem_not_empty,
|
||||
slice_size_c, slice_size_a,
|
||||
name="slice.size"))
|
||||
self._make_check(
|
||||
self.append(ir.Compare(ast.NotEq(loc=None), step, ir.Constant(0, step.type))),
|
||||
lambda: self.alloc_exn(builtins.TException("ValueError"),
|
||||
ir.Constant("step cannot be zero", builtins.TStr())),
|
||||
loc=node.slice.step.loc)
|
||||
else:
|
||||
step = ir.Constant(1, node.slice.type)
|
||||
counting_up = self.append(ir.Compare(ast.Gt(loc=None), step,
|
||||
ir.Constant(0, step.type)))
|
||||
self.append(ir.Compare(ast.LtE(loc=None), slice_size, length)),
|
||||
lambda slice_size, length: self.alloc_exn(builtins.TException("ValueError"),
|
||||
ir.Constant("slice size {0} is larger than iterable length {1}",
|
||||
builtins.TStr()),
|
||||
slice_size, length),
|
||||
params=[slice_size, length],
|
||||
loc=node.slice.loc)
|
||||
|
||||
unstepped_size = self.append(ir.Arith(ast.Sub(loc=None),
|
||||
mapped_stop_index, mapped_start_index))
|
||||
slice_size_a = self.append(ir.Arith(ast.FloorDiv(loc=None), unstepped_size, step))
|
||||
slice_size_b = self.append(ir.Arith(ast.Mod(loc=None), unstepped_size, step))
|
||||
rem_not_empty = self.append(ir.Compare(ast.NotEq(loc=None), slice_size_b,
|
||||
ir.Constant(0, slice_size_b.type)))
|
||||
slice_size_c = self.append(ir.Arith(ast.Add(loc=None), slice_size_a,
|
||||
ir.Constant(1, slice_size_a.type)))
|
||||
slice_size = self.append(ir.Select(rem_not_empty,
|
||||
slice_size_c, slice_size_a,
|
||||
name="slice.size"))
|
||||
self._make_check(
|
||||
self.append(ir.Compare(ast.LtE(loc=None), slice_size, length)),
|
||||
lambda slice_size, length: self.alloc_exn(builtins.TException("ValueError"),
|
||||
ir.Constant("slice size {0} is larger than iterable length {1}",
|
||||
builtins.TStr()),
|
||||
slice_size, length),
|
||||
params=[slice_size, length],
|
||||
loc=node.slice.loc)
|
||||
if self.current_assign is None:
|
||||
is_neg_size = self.append(ir.Compare(ast.Lt(loc=None),
|
||||
slice_size, ir.Constant(0, slice_size.type)))
|
||||
abs_slice_size = self.append(ir.Select(is_neg_size,
|
||||
ir.Constant(0, slice_size.type), slice_size))
|
||||
other_value = self.append(ir.Alloc([abs_slice_size], value.type,
|
||||
name="slice.result"))
|
||||
else:
|
||||
other_value = self.current_assign
|
||||
|
||||
if self.current_assign is None:
|
||||
is_neg_size = self.append(ir.Compare(ast.Lt(loc=None),
|
||||
slice_size, ir.Constant(0, slice_size.type)))
|
||||
abs_slice_size = self.append(ir.Select(is_neg_size,
|
||||
ir.Constant(0, slice_size.type), slice_size))
|
||||
other_value = self.append(ir.Alloc([abs_slice_size], value.type,
|
||||
name="slice.result"))
|
||||
else:
|
||||
other_value = self.current_assign
|
||||
prehead = self.current_block
|
||||
|
||||
prehead = self.current_block
|
||||
head = self.current_block = self.add_block("slice.head")
|
||||
prehead.append(ir.Branch(head))
|
||||
|
||||
head = self.current_block = self.add_block("slice.head")
|
||||
prehead.append(ir.Branch(head))
|
||||
index = self.append(ir.Phi(node.slice.type,
|
||||
name="slice.index"))
|
||||
index.add_incoming(mapped_start_index, prehead)
|
||||
other_index = self.append(ir.Phi(node.slice.type,
|
||||
name="slice.resindex"))
|
||||
other_index.add_incoming(ir.Constant(0, node.slice.type), prehead)
|
||||
|
||||
index = self.append(ir.Phi(node.slice.type,
|
||||
name="slice.index"))
|
||||
index.add_incoming(mapped_start_index, prehead)
|
||||
other_index = self.append(ir.Phi(node.slice.type,
|
||||
name="slice.resindex"))
|
||||
other_index.add_incoming(ir.Constant(0, node.slice.type), prehead)
|
||||
# Still within bounds?
|
||||
bounded_up = self.append(ir.Compare(ast.Lt(loc=None), index, mapped_stop_index))
|
||||
bounded_down = self.append(ir.Compare(ast.Gt(loc=None), index, mapped_stop_index))
|
||||
within_bounds = self.append(ir.Select(counting_up, bounded_up, bounded_down))
|
||||
|
||||
# Still within bounds?
|
||||
bounded_up = self.append(ir.Compare(ast.Lt(loc=None), index, mapped_stop_index))
|
||||
bounded_down = self.append(ir.Compare(ast.Gt(loc=None), index, mapped_stop_index))
|
||||
within_bounds = self.append(ir.Select(counting_up, bounded_up, bounded_down))
|
||||
body = self.current_block = self.add_block("slice.body")
|
||||
|
||||
body = self.current_block = self.add_block("slice.body")
|
||||
if self.current_assign is None:
|
||||
elem = self.iterable_get(value, index)
|
||||
self.append(ir.SetElem(other_value, other_index, elem))
|
||||
else:
|
||||
elem = self.append(ir.GetElem(self.current_assign, other_index))
|
||||
self.append(ir.SetElem(value, index, elem))
|
||||
|
||||
if self.current_assign is None:
|
||||
elem = self.iterable_get(value, index)
|
||||
self.append(ir.SetElem(other_value, other_index, elem))
|
||||
else:
|
||||
elem = self.append(ir.GetElem(self.current_assign, other_index))
|
||||
self.append(ir.SetElem(value, index, elem))
|
||||
next_index = self.append(ir.Arith(ast.Add(loc=None), index, step))
|
||||
index.add_incoming(next_index, body)
|
||||
next_other_index = self.append(ir.Arith(ast.Add(loc=None), other_index,
|
||||
ir.Constant(1, node.slice.type)))
|
||||
other_index.add_incoming(next_other_index, body)
|
||||
self.append(ir.Branch(head))
|
||||
|
||||
next_index = self.append(ir.Arith(ast.Add(loc=None), index, step))
|
||||
index.add_incoming(next_index, body)
|
||||
next_other_index = self.append(ir.Arith(ast.Add(loc=None), other_index,
|
||||
ir.Constant(1, node.slice.type)))
|
||||
other_index.add_incoming(next_other_index, body)
|
||||
self.append(ir.Branch(head))
|
||||
tail = self.current_block = self.add_block("slice.tail")
|
||||
head.append(ir.BranchIf(within_bounds, body, tail))
|
||||
|
||||
tail = self.current_block = self.add_block("slice.tail")
|
||||
head.append(ir.BranchIf(within_bounds, body, tail))
|
||||
|
||||
if self.current_assign is None:
|
||||
return other_value
|
||||
if self.current_assign is None:
|
||||
return other_value
|
||||
|
||||
def visit_TupleT(self, node):
|
||||
if self.current_assign is None:
|
||||
|
|
|
@ -269,6 +269,14 @@ class Inferencer(algorithm.Visitor):
|
|||
else:
|
||||
self._unify_iterable(element=node, collection=node.value)
|
||||
elif isinstance(node.slice, ast.Slice):
|
||||
if builtins.is_array(node.value.type):
|
||||
if node.slice.step is not None:
|
||||
diag = diagnostic.Diagnostic(
|
||||
"error",
|
||||
"strided slicing not yet supported for NumPy arrays", {},
|
||||
node.slice.step.loc, [])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
self._unify(node.type, node.value.type, node.loc, node.value.loc)
|
||||
else: # ExtSlice
|
||||
pass # error emitted above
|
||||
|
|
|
@ -9,5 +9,8 @@ b = array([1, 2, 3])
|
|||
# CHECK-L: ${LINE:+1}: error: too many indices for array of dimension 1
|
||||
b[1, 2]
|
||||
|
||||
# CHECK-L: ${LINE:+1}: error: strided slicing not yet supported for NumPy arrays
|
||||
b[::-1]
|
||||
|
||||
# CHECK-L: ${LINE:+1}: error: array attributes cannot be assigned to
|
||||
b.shape = (5, )
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
# RUN: %python -m artiq.compiler.testbench.jit %s
|
||||
|
||||
a = array([0, 1, 2, 3])
|
||||
|
||||
b = a[2:3]
|
||||
assert b.shape == (1,)
|
||||
assert b[0] == 2
|
||||
b[0] = 5
|
||||
assert a[2] == 5
|
||||
|
||||
b = a[3:2]
|
||||
assert b.shape == (0,)
|
||||
|
||||
c = array([[0, 1], [2, 3]])
|
||||
|
||||
d = c[:1]
|
||||
assert d.shape == (1, 2)
|
||||
assert d[0, 0] == 0
|
||||
assert d[0, 1] == 1
|
||||
d[0, 0] = 5
|
||||
assert c[0, 0] == 5
|
||||
|
||||
d = c[1:0]
|
||||
assert d.shape == (0, 2)
|
|
@ -0,0 +1,13 @@
|
|||
# RUN: %python -m artiq.compiler.testbench.embedding %s
|
||||
|
||||
from artiq.language.core import *
|
||||
from artiq.language.types import *
|
||||
import numpy as np
|
||||
|
||||
n = 2
|
||||
data = np.zeros((n, n))
|
||||
|
||||
|
||||
@kernel
|
||||
def entrypoint():
|
||||
print(data[:n])
|
Loading…
Reference in New Issue