forked from M-Labs/artiq
compiler: Implement multi-dimensional indexing of arrays
This generates rather more code than necessary, but has the advantage of automatically handling incomplete multi-dimensional subscripts which still leave arrays behind.
This commit is contained in:
parent
b00ba5ece1
commit
33d931a5b7
|
@ -1092,16 +1092,31 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
finally:
|
finally:
|
||||||
self.current_assign = old_assign
|
self.current_assign = old_assign
|
||||||
|
|
||||||
length = self.iterable_len(value, index.type)
|
# For multi-dimensional indexes, just apply them sequentially. This
|
||||||
mapped_index = self._map_index(length, index,
|
# works, as they are only supported for types where we do not
|
||||||
loc=node.begin_loc)
|
# immediately need to distinguish between the Get and Set cases
|
||||||
if self.current_assign is None:
|
# (i.e. arrays, which are reference types).
|
||||||
result = self.iterable_get(value, mapped_index)
|
if types.is_tuple(index.type):
|
||||||
result.set_name("{}.at.{}".format(value.name, _readable_name(index)))
|
num_idxs = len(index.type.find().elts)
|
||||||
return result
|
indices = [
|
||||||
|
self.append(ir.GetAttr(index, i)) for i in range(num_idxs)
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
self.append(ir.SetElem(value, mapped_index, self.current_assign,
|
indices = [index]
|
||||||
name="{}.at.{}".format(value.name, _readable_name(index))))
|
indexed = value
|
||||||
|
for i, idx in enumerate(indices):
|
||||||
|
length = self.iterable_len(indexed, idx.type)
|
||||||
|
mapped_index = self._map_index(length, idx, loc=node.begin_loc)
|
||||||
|
if self.current_assign is None or i < len(indices) - 1:
|
||||||
|
indexed = self.iterable_get(indexed, mapped_index)
|
||||||
|
indexed.set_name("{}.at.{}".format(indexed.name,
|
||||||
|
_readable_name(idx)))
|
||||||
|
else:
|
||||||
|
self.append(ir.SetElem(indexed, mapped_index, self.current_assign,
|
||||||
|
name="{}.at.{}".format(value.name,
|
||||||
|
_readable_name(index))))
|
||||||
|
if self.current_assign is None:
|
||||||
|
return indexed
|
||||||
else: # Slice
|
else: # Slice
|
||||||
length = self.iterable_len(value, node.slice.type)
|
length = self.iterable_len(value, node.slice.type)
|
||||||
|
|
||||||
|
|
|
@ -208,10 +208,9 @@ class Inferencer(algorithm.Visitor):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
value = node.value
|
value = node.value
|
||||||
if types.is_tuple(value.type):
|
if types.is_tuple(value.type):
|
||||||
diag = diagnostic.Diagnostic("error",
|
for elt in value.type.find().elts:
|
||||||
"multi-dimensional slices are not supported", {},
|
self._unify(elt, builtins.TInt(),
|
||||||
node.loc, [])
|
value.loc, None)
|
||||||
self.engine.process(diag)
|
|
||||||
else:
|
else:
|
||||||
self._unify(value.type, builtins.TInt(),
|
self._unify(value.type, builtins.TInt(),
|
||||||
value.loc, None)
|
value.loc, None)
|
||||||
|
@ -237,12 +236,39 @@ class Inferencer(algorithm.Visitor):
|
||||||
def visit_SubscriptT(self, node):
|
def visit_SubscriptT(self, node):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
if isinstance(node.slice, ast.Index):
|
if isinstance(node.slice, ast.Index):
|
||||||
self._unify_iterable(element=node, collection=node.value)
|
if types.is_tuple(node.slice.value.type):
|
||||||
|
if not builtins.is_array(node.value.type):
|
||||||
|
diag = diagnostic.Diagnostic(
|
||||||
|
"error",
|
||||||
|
"multi-dimensional slices only supported for arrays, not {type}",
|
||||||
|
{"type": types.TypePrinter().name(node.value.type)},
|
||||||
|
node.loc, [])
|
||||||
|
self.engine.process(diag)
|
||||||
|
return
|
||||||
|
num_idxs = len(node.slice.value.type.find().elts)
|
||||||
|
array_type = node.value.type.find()
|
||||||
|
num_dims = array_type["num_dims"].value
|
||||||
|
remaining_dims = num_dims - num_idxs
|
||||||
|
if remaining_dims < 0:
|
||||||
|
diag = diagnostic.Diagnostic(
|
||||||
|
"error",
|
||||||
|
"too many indices for array of dimension {num_dims}",
|
||||||
|
{"num_dims": num_dims}, node.slice.loc, [])
|
||||||
|
self.engine.process(diag)
|
||||||
|
return
|
||||||
|
if remaining_dims == 0:
|
||||||
|
self._unify(node.type, array_type["elt"], node.loc,
|
||||||
|
node.value.loc)
|
||||||
|
else:
|
||||||
|
self._unify(
|
||||||
|
node.type,
|
||||||
|
builtins.TArray(array_type["elt"], remaining_dims))
|
||||||
|
else:
|
||||||
|
self._unify_iterable(element=node, collection=node.value)
|
||||||
elif isinstance(node.slice, ast.Slice):
|
elif isinstance(node.slice, ast.Slice):
|
||||||
self._unify(node.type, node.value.type,
|
self._unify(node.type, node.value.type, node.loc, node.value.loc)
|
||||||
node.loc, node.value.loc)
|
else: # ExtSlice
|
||||||
else: # ExtSlice
|
pass # error emitted above
|
||||||
pass # error emitted above
|
|
||||||
|
|
||||||
def visit_IfExpT(self, node):
|
def visit_IfExpT(self, node):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
|
|
|
@ -5,5 +5,9 @@
|
||||||
a = array()
|
a = array()
|
||||||
|
|
||||||
b = array([1, 2, 3])
|
b = array([1, 2, 3])
|
||||||
|
|
||||||
|
# CHECK-L: ${LINE:+1}: error: too many indices for array of dimension 1
|
||||||
|
b[1, 2]
|
||||||
|
|
||||||
# CHECK-L: ${LINE:+1}: error: array attributes cannot be assigned to
|
# CHECK-L: ${LINE:+1}: error: array attributes cannot be assigned to
|
||||||
b.shape = (5, )
|
b.shape = (5, )
|
||||||
|
|
|
@ -26,21 +26,21 @@ assert matrix.shape == (2, 3)
|
||||||
# FIXME: Need to decide on a solution for array comparisons —
|
# FIXME: Need to decide on a solution for array comparisons —
|
||||||
# NumPy returns an array of bools!
|
# NumPy returns an array of bools!
|
||||||
# assert [x for x in matrix] == [array([1.0, 2.0, 3.0]), array([4.0, 5.0, 6.0])]
|
# assert [x for x in matrix] == [array([1.0, 2.0, 3.0]), array([4.0, 5.0, 6.0])]
|
||||||
assert matrix[0][0] == 1.0
|
assert matrix[0, 0] == 1.0
|
||||||
assert matrix[0][1] == 2.0
|
assert matrix[0, 1] == 2.0
|
||||||
assert matrix[0][2] == 3.0
|
assert matrix[0, 2] == 3.0
|
||||||
assert matrix[1][0] == 4.0
|
assert matrix[1, 0] == 4.0
|
||||||
assert matrix[1][1] == 5.0
|
assert matrix[1, 1] == 5.0
|
||||||
assert matrix[1][2] == 6.0
|
assert matrix[1, 2] == 6.0
|
||||||
|
|
||||||
matrix[0][0] = 7.0
|
matrix[0, 0] = 7.0
|
||||||
matrix[1][1] = 8.0
|
matrix[1, 1] = 8.0
|
||||||
assert matrix[0][0] == 7.0
|
assert matrix[0, 0] == 7.0
|
||||||
assert matrix[0][1] == 2.0
|
assert matrix[0, 1] == 2.0
|
||||||
assert matrix[0][2] == 3.0
|
assert matrix[0, 2] == 3.0
|
||||||
assert matrix[1][0] == 4.0
|
assert matrix[1, 0] == 4.0
|
||||||
assert matrix[1][1] == 8.0
|
assert matrix[1, 1] == 8.0
|
||||||
assert matrix[1][2] == 6.0
|
assert matrix[1, 2] == 6.0
|
||||||
|
|
||||||
three_tensor = array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
|
three_tensor = array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
|
||||||
assert len(three_tensor) == 1
|
assert len(three_tensor) == 1
|
||||||
|
|
Loading…
Reference in New Issue