forked from M-Labs/artiq
compiler: Properly expand dimensions for array([]) with ndarray elements
This matches host NumPy behaviour (and, in either case, was previously broken, as it still continued past the array element type).
This commit is contained in:
parent
94489f9183
commit
d161fd5d84
|
@ -892,6 +892,9 @@ class Inferencer(algorithm.Visitor):
|
|||
return # undetermined yet
|
||||
if not builtins.is_iterable(elt) or builtins.is_str(elt):
|
||||
break
|
||||
if builtins.is_array(elt):
|
||||
num_dims += elt.find()["num_dims"].value
|
||||
else:
|
||||
num_dims += 1
|
||||
elt = builtins.get_iterable_elt(elt)
|
||||
|
||||
|
|
|
@ -42,6 +42,9 @@ assert matrix[1, 0] == 4.0
|
|||
assert matrix[1, 1] == 8.0
|
||||
assert matrix[1, 2] == 6.0
|
||||
|
||||
array_of_matrices = array([matrix, matrix])
|
||||
assert array_of_matrices.shape == (2, 2, 3)
|
||||
|
||||
three_tensor = array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
|
||||
assert len(three_tensor) == 1
|
||||
assert three_tensor.shape == (1, 2, 3)
|
||||
|
|
Loading…
Reference in New Issue