compiler: Properly expand dimensions for array([]) with ndarray elements

This matches host NumPy behaviour (and, in either case, was
previously broken, as it still continued past the array element
type).
This commit is contained in:
David Nadlinger 2020-10-20 01:29:51 +02:00
parent 94489f9183
commit d161fd5d84
2 changed files with 7 additions and 1 deletions

View File

@ -892,7 +892,10 @@ class Inferencer(algorithm.Visitor):
return # undetermined yet
if not builtins.is_iterable(elt) or builtins.is_str(elt):
break
num_dims += 1
if builtins.is_array(elt):
num_dims += elt.find()["num_dims"].value
else:
num_dims += 1
elt = builtins.get_iterable_elt(elt)
if explicit_dtype is not None:

View File

@ -42,6 +42,9 @@ assert matrix[1, 0] == 4.0
assert matrix[1, 1] == 8.0
assert matrix[1, 2] == 6.0
array_of_matrices = array([matrix, matrix])
assert array_of_matrices.shape == (2, 2, 3)
three_tensor = array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
assert len(three_tensor) == 1
assert three_tensor.shape == (1, 2, 3)