forked from M-Labs/artiq
1
0
Fork 0

compiler: Iteration for 1D ndarrays

This commit is contained in:
David Nadlinger 2020-07-26 01:33:52 +01:00
parent bc17bb4d1a
commit c95a978ab6
3 changed files with 18 additions and 3 deletions

View File

@ -510,7 +510,14 @@ class ARTIQIRGenerator(algorithm.Visitor):
def iterable_get(self, value, index):
# Assuming the value is within bounds.
if builtins.is_listish(value.type):
if builtins.is_array(value.type):
# Scalar indexing into ndarray.
if value.type.find()["num_dims"].value > 1:
raise NotImplementedError
else:
buffer = self.append(ir.GetAttr(value, "buffer"))
return self.append(ir.GetElem(buffer, index))
elif builtins.is_listish(value.type):
return self.append(ir.GetElem(value, index))
elif builtins.is_range(value.type):
start = self.append(ir.GetAttr(value, "start"))

View File

@ -208,6 +208,14 @@ class Inferencer(algorithm.Visitor):
if builtins.is_bytes(collection.type) or builtins.is_bytearray(collection.type):
self._unify(element.type, builtins.get_iterable_elt(collection.type),
element.loc, None)
elif builtins.is_array(collection.type):
array_type = collection.type.find()
elem_dims = array_type["num_dims"].value - 1
if elem_dims > 0:
elem_type = builtins.TArray(array_type["elt"], types.TValue(elem_dims))
else:
elem_type = array_type["elt"]
self._unify(element.type, elem_type, element.loc, collection.loc)
elif builtins.is_iterable(collection.type) and not builtins.is_str(collection.type):
rhs_type = collection.type.find()
rhs_wrapped_lhs_type = types.TMono(rhs_type.name, {"elt": element.type})

View File

@ -4,14 +4,14 @@
ary = array([1, 2, 3])
assert len(ary) == 3
assert ary.shape == (3,)
# FIXME: Implement ndarray indexing
# assert [x*x for x in ary] == [1, 4, 9]
assert [x * x for x in ary] == [1, 4, 9]
# Reassign to an existing value to disambiguate type of empty array.
empty_array = array([1])
empty_array = array([])
assert len(empty_array) == 0
assert empty_array.shape == (0,)
assert [x * x for x in empty_array] == []
matrix = array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
assert len(matrix) == 2