From d161fd5d84d9b724eb27bb21dabcf41f61fd2c62 Mon Sep 17 00:00:00 2001 From: David Nadlinger Date: Tue, 20 Oct 2020 01:29:51 +0200 Subject: [PATCH] compiler: Properly expand dimensions for array([]) with ndarray elements This matches host NumPy behaviour (and, in either case, was previously broken, as it still continued past the array element type). --- artiq/compiler/transforms/inferencer.py | 5 ++++- artiq/test/lit/integration/array.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index ddbd0bf24..ce0a269e1 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -892,7 +892,10 @@ class Inferencer(algorithm.Visitor): return # undetermined yet if not builtins.is_iterable(elt) or builtins.is_str(elt): break - num_dims += 1 + if builtins.is_array(elt): + num_dims += elt.find()["num_dims"].value + else: + num_dims += 1 elt = builtins.get_iterable_elt(elt) if explicit_dtype is not None: diff --git a/artiq/test/lit/integration/array.py b/artiq/test/lit/integration/array.py index 039e76372..512382cda 100644 --- a/artiq/test/lit/integration/array.py +++ b/artiq/test/lit/integration/array.py @@ -42,6 +42,9 @@ assert matrix[1, 0] == 4.0 assert matrix[1, 1] == 8.0 assert matrix[1, 2] == 6.0 +array_of_matrices = array([matrix, matrix]) +assert array_of_matrices.shape == (2, 2, 3) + three_tensor = array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]) assert len(three_tensor) == 1 assert three_tensor.shape == (1, 2, 3)