forked from M-Labs/artiq
compiler: Implement multi-dimensional indexing of arrays
This generates rather more code than necessary, but has the advantage of automatically handling incomplete multi-dimensional subscripts which still leave arrays behind.
This commit is contained in:
parent
b00ba5ece1
commit
33d931a5b7
|
@ -1092,16 +1092,31 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
finally:
|
||||
self.current_assign = old_assign
|
||||
|
||||
length = self.iterable_len(value, index.type)
|
||||
mapped_index = self._map_index(length, index,
|
||||
loc=node.begin_loc)
|
||||
if self.current_assign is None:
|
||||
result = self.iterable_get(value, mapped_index)
|
||||
result.set_name("{}.at.{}".format(value.name, _readable_name(index)))
|
||||
return result
|
||||
# For multi-dimensional indexes, just apply them sequentially. This
|
||||
# works, as they are only supported for types where we do not
|
||||
# immediately need to distinguish between the Get and Set cases
|
||||
# (i.e. arrays, which are reference types).
|
||||
if types.is_tuple(index.type):
|
||||
num_idxs = len(index.type.find().elts)
|
||||
indices = [
|
||||
self.append(ir.GetAttr(index, i)) for i in range(num_idxs)
|
||||
]
|
||||
else:
|
||||
self.append(ir.SetElem(value, mapped_index, self.current_assign,
|
||||
name="{}.at.{}".format(value.name, _readable_name(index))))
|
||||
indices = [index]
|
||||
indexed = value
|
||||
for i, idx in enumerate(indices):
|
||||
length = self.iterable_len(indexed, idx.type)
|
||||
mapped_index = self._map_index(length, idx, loc=node.begin_loc)
|
||||
if self.current_assign is None or i < len(indices) - 1:
|
||||
indexed = self.iterable_get(indexed, mapped_index)
|
||||
indexed.set_name("{}.at.{}".format(indexed.name,
|
||||
_readable_name(idx)))
|
||||
else:
|
||||
self.append(ir.SetElem(indexed, mapped_index, self.current_assign,
|
||||
name="{}.at.{}".format(value.name,
|
||||
_readable_name(index))))
|
||||
if self.current_assign is None:
|
||||
return indexed
|
||||
else: # Slice
|
||||
length = self.iterable_len(value, node.slice.type)
|
||||
|
||||
|
|
|
@ -208,10 +208,9 @@ class Inferencer(algorithm.Visitor):
|
|||
self.generic_visit(node)
|
||||
value = node.value
|
||||
if types.is_tuple(value.type):
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"multi-dimensional slices are not supported", {},
|
||||
node.loc, [])
|
||||
self.engine.process(diag)
|
||||
for elt in value.type.find().elts:
|
||||
self._unify(elt, builtins.TInt(),
|
||||
value.loc, None)
|
||||
else:
|
||||
self._unify(value.type, builtins.TInt(),
|
||||
value.loc, None)
|
||||
|
@ -237,10 +236,37 @@ class Inferencer(algorithm.Visitor):
|
|||
def visit_SubscriptT(self, node):
|
||||
self.generic_visit(node)
|
||||
if isinstance(node.slice, ast.Index):
|
||||
if types.is_tuple(node.slice.value.type):
|
||||
if not builtins.is_array(node.value.type):
|
||||
diag = diagnostic.Diagnostic(
|
||||
"error",
|
||||
"multi-dimensional slices only supported for arrays, not {type}",
|
||||
{"type": types.TypePrinter().name(node.value.type)},
|
||||
node.loc, [])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
num_idxs = len(node.slice.value.type.find().elts)
|
||||
array_type = node.value.type.find()
|
||||
num_dims = array_type["num_dims"].value
|
||||
remaining_dims = num_dims - num_idxs
|
||||
if remaining_dims < 0:
|
||||
diag = diagnostic.Diagnostic(
|
||||
"error",
|
||||
"too many indices for array of dimension {num_dims}",
|
||||
{"num_dims": num_dims}, node.slice.loc, [])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
if remaining_dims == 0:
|
||||
self._unify(node.type, array_type["elt"], node.loc,
|
||||
node.value.loc)
|
||||
else:
|
||||
self._unify(
|
||||
node.type,
|
||||
builtins.TArray(array_type["elt"], remaining_dims))
|
||||
else:
|
||||
self._unify_iterable(element=node, collection=node.value)
|
||||
elif isinstance(node.slice, ast.Slice):
|
||||
self._unify(node.type, node.value.type,
|
||||
node.loc, node.value.loc)
|
||||
self._unify(node.type, node.value.type, node.loc, node.value.loc)
|
||||
else: # ExtSlice
|
||||
pass # error emitted above
|
||||
|
||||
|
|
|
@ -5,5 +5,9 @@
|
|||
a = array()
|
||||
|
||||
b = array([1, 2, 3])
|
||||
|
||||
# CHECK-L: ${LINE:+1}: error: too many indices for array of dimension 1
|
||||
b[1, 2]
|
||||
|
||||
# CHECK-L: ${LINE:+1}: error: array attributes cannot be assigned to
|
||||
b.shape = (5, )
|
||||
|
|
|
@ -26,21 +26,21 @@ assert matrix.shape == (2, 3)
|
|||
# FIXME: Need to decide on a solution for array comparisons —
|
||||
# NumPy returns an array of bools!
|
||||
# assert [x for x in matrix] == [array([1.0, 2.0, 3.0]), array([4.0, 5.0, 6.0])]
|
||||
assert matrix[0][0] == 1.0
|
||||
assert matrix[0][1] == 2.0
|
||||
assert matrix[0][2] == 3.0
|
||||
assert matrix[1][0] == 4.0
|
||||
assert matrix[1][1] == 5.0
|
||||
assert matrix[1][2] == 6.0
|
||||
assert matrix[0, 0] == 1.0
|
||||
assert matrix[0, 1] == 2.0
|
||||
assert matrix[0, 2] == 3.0
|
||||
assert matrix[1, 0] == 4.0
|
||||
assert matrix[1, 1] == 5.0
|
||||
assert matrix[1, 2] == 6.0
|
||||
|
||||
matrix[0][0] = 7.0
|
||||
matrix[1][1] = 8.0
|
||||
assert matrix[0][0] == 7.0
|
||||
assert matrix[0][1] == 2.0
|
||||
assert matrix[0][2] == 3.0
|
||||
assert matrix[1][0] == 4.0
|
||||
assert matrix[1][1] == 8.0
|
||||
assert matrix[1][2] == 6.0
|
||||
matrix[0, 0] = 7.0
|
||||
matrix[1, 1] = 8.0
|
||||
assert matrix[0, 0] == 7.0
|
||||
assert matrix[0, 1] == 2.0
|
||||
assert matrix[0, 2] == 3.0
|
||||
assert matrix[1, 0] == 4.0
|
||||
assert matrix[1, 1] == 8.0
|
||||
assert matrix[1, 2] == 6.0
|
||||
|
||||
three_tensor = array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
|
||||
assert len(three_tensor) == 1
|
||||
|
|
Loading…
Reference in New Issue