forked from M-Labs/artiq
compiler: Properly expand dimensions for array([]) with ndarray elements
This matches host NumPy behaviour (and, in either case, was previously broken, as it still continued past the array element type).
This commit is contained in:
parent
94489f9183
commit
d161fd5d84
|
@ -892,6 +892,9 @@ class Inferencer(algorithm.Visitor):
|
||||||
return # undetermined yet
|
return # undetermined yet
|
||||||
if not builtins.is_iterable(elt) or builtins.is_str(elt):
|
if not builtins.is_iterable(elt) or builtins.is_str(elt):
|
||||||
break
|
break
|
||||||
|
if builtins.is_array(elt):
|
||||||
|
num_dims += elt.find()["num_dims"].value
|
||||||
|
else:
|
||||||
num_dims += 1
|
num_dims += 1
|
||||||
elt = builtins.get_iterable_elt(elt)
|
elt = builtins.get_iterable_elt(elt)
|
||||||
|
|
||||||
|
|
|
@ -42,6 +42,9 @@ assert matrix[1, 0] == 4.0
|
||||||
assert matrix[1, 1] == 8.0
|
assert matrix[1, 1] == 8.0
|
||||||
assert matrix[1, 2] == 6.0
|
assert matrix[1, 2] == 6.0
|
||||||
|
|
||||||
|
array_of_matrices = array([matrix, matrix])
|
||||||
|
assert array_of_matrices.shape == (2, 2, 3)
|
||||||
|
|
||||||
three_tensor = array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
|
three_tensor = array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
|
||||||
assert len(three_tensor) == 1
|
assert len(three_tensor) == 1
|
||||||
assert three_tensor.shape == (1, 2, 3)
|
assert three_tensor.shape == (1, 2, 3)
|
||||||
|
|
Loading…
Reference in New Issue