forked from M-Labs/artiq
compiler: Iteration for 1D ndarrays
This commit is contained in:
parent
bc17bb4d1a
commit
c95a978ab6
|
@ -510,7 +510,14 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
|
||||
def iterable_get(self, value, index):
|
||||
# Assuming the value is within bounds.
|
||||
if builtins.is_listish(value.type):
|
||||
if builtins.is_array(value.type):
|
||||
# Scalar indexing into ndarray.
|
||||
if value.type.find()["num_dims"].value > 1:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
buffer = self.append(ir.GetAttr(value, "buffer"))
|
||||
return self.append(ir.GetElem(buffer, index))
|
||||
elif builtins.is_listish(value.type):
|
||||
return self.append(ir.GetElem(value, index))
|
||||
elif builtins.is_range(value.type):
|
||||
start = self.append(ir.GetAttr(value, "start"))
|
||||
|
|
|
@ -208,6 +208,14 @@ class Inferencer(algorithm.Visitor):
|
|||
if builtins.is_bytes(collection.type) or builtins.is_bytearray(collection.type):
|
||||
self._unify(element.type, builtins.get_iterable_elt(collection.type),
|
||||
element.loc, None)
|
||||
elif builtins.is_array(collection.type):
|
||||
array_type = collection.type.find()
|
||||
elem_dims = array_type["num_dims"].value - 1
|
||||
if elem_dims > 0:
|
||||
elem_type = builtins.TArray(array_type["elt"], types.TValue(elem_dims))
|
||||
else:
|
||||
elem_type = array_type["elt"]
|
||||
self._unify(element.type, elem_type, element.loc, collection.loc)
|
||||
elif builtins.is_iterable(collection.type) and not builtins.is_str(collection.type):
|
||||
rhs_type = collection.type.find()
|
||||
rhs_wrapped_lhs_type = types.TMono(rhs_type.name, {"elt": element.type})
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
ary = array([1, 2, 3])
|
||||
assert len(ary) == 3
|
||||
assert ary.shape == (3,)
|
||||
# FIXME: Implement ndarray indexing
|
||||
# assert [x*x for x in ary] == [1, 4, 9]
|
||||
assert [x * x for x in ary] == [1, 4, 9]
|
||||
|
||||
# Reassign to an existing value to disambiguate type of empty array.
|
||||
empty_array = array([1])
|
||||
empty_array = array([])
|
||||
assert len(empty_array) == 0
|
||||
assert empty_array.shape == (0,)
|
||||
assert [x * x for x in empty_array] == []
|
||||
|
||||
matrix = array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
assert len(matrix) == 2
|
||||
|
|
Loading…
Reference in New Issue