diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index 730ea7f0a..afd0da1a9 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -510,7 +510,14 @@ class ARTIQIRGenerator(algorithm.Visitor): def iterable_get(self, value, index): # Assuming the value is within bounds. - if builtins.is_listish(value.type): + if builtins.is_array(value.type): + # Scalar indexing into ndarray. + if value.type.find()["num_dims"].value > 1: + raise NotImplementedError + else: + buffer = self.append(ir.GetAttr(value, "buffer")) + return self.append(ir.GetElem(buffer, index)) + elif builtins.is_listish(value.type): return self.append(ir.GetElem(value, index)) elif builtins.is_range(value.type): start = self.append(ir.GetAttr(value, "start")) diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 19ed2c85d..2838a9e14 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -208,6 +208,14 @@ class Inferencer(algorithm.Visitor): if builtins.is_bytes(collection.type) or builtins.is_bytearray(collection.type): self._unify(element.type, builtins.get_iterable_elt(collection.type), element.loc, None) + elif builtins.is_array(collection.type): + array_type = collection.type.find() + elem_dims = array_type["num_dims"].value - 1 + if elem_dims > 0: + elem_type = builtins.TArray(array_type["elt"], types.TValue(elem_dims)) + else: + elem_type = array_type["elt"] + self._unify(element.type, elem_type, element.loc, collection.loc) elif builtins.is_iterable(collection.type) and not builtins.is_str(collection.type): rhs_type = collection.type.find() rhs_wrapped_lhs_type = types.TMono(rhs_type.name, {"elt": element.type}) diff --git a/artiq/test/lit/integration/array.py b/artiq/test/lit/integration/array.py index c02728334..81763c29e 100644 --- a/artiq/test/lit/integration/array.py +++ b/artiq/test/lit/integration/array.py @@ -4,14 +4,14 @@ ary = array([1, 2, 3]) assert len(ary) == 3 assert ary.shape == (3,) -# FIXME: Implement ndarray indexing -# assert [x*x for x in ary] == [1, 4, 9] +assert [x * x for x in ary] == [1, 4, 9] # Reassign to an existing value to disambiguate type of empty array. empty_array = array([1]) empty_array = array([]) assert len(empty_array) == 0 assert empty_array.shape == (0,) +assert [x * x for x in empty_array] == [] matrix = array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) assert len(matrix) == 2