diff --git a/artiq/compiler/builtins.py b/artiq/compiler/builtins.py index b5662aec6..54b25e719 100644 --- a/artiq/compiler/builtins.py +++ b/artiq/compiler/builtins.py @@ -315,6 +315,9 @@ def is_iterable(typ): return is_listish(typ) or is_range(typ) def get_iterable_elt(typ): + # TODO: Arrays count as listish, but this returns the innermost element type for + # n-dimensional arrays, rather than the n-1 dimensional result of iterating over + # the first axis, which makes the name a bit misleading. if is_str(typ) or is_bytes(typ) or is_bytearray(typ): return TInt(types.TValue(8)) elif types._is_pointer(typ) or is_iterable(typ): diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index d77db91d4..f72f01127 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -8,6 +8,28 @@ from .. import asttyped, types, builtins from .typedtree_printer import TypedtreePrinter +def is_nested_empty_list(node): + """If the passed AST node is an empty list, or a regularly nested list thereof, + returns the number of nesting layers, or ``None`` otherwise. + + For instance, ``is_nested_empty_list([]) == 1`` and + ``is_nested_empty_list([[], []]) == 2``, but + ``is_nested_empty_list([[[]], []]) == None`` as the number of nesting layers doesn't + match. + """ + if not isinstance(node, ast.List): + return None + if not node.elts: + return 1 + result = is_nested_empty_list(node.elts[0]) + if result is None: + return None + for elt in node.elts[:1]: + if result != is_nested_empty_list(elt): + return None + return result + 1 + + class Inferencer(algorithm.Visitor): """ :class:`Inferencer` infers types by recursively applying the unification @@ -891,28 +913,45 @@ class Inferencer(algorithm.Visitor): if len(node.args) == 1 and keywords_acceptable: arg, = node.args - # In the absence of any other information (there currently isn't a way - # to specify any), assume that all iterables are expandable into a - # (runtime-checked) rectangular array of the innermost element type. - elt = arg.type - num_dims = 0 - result_dims = (node.type.find()["num_dims"].value - if builtins.is_array(node.type) else -1) - while True: - if num_dims == result_dims: - # If we already know the number of dimensions of the result, - # stop so we can disambiguate the (innermost) element type of - # the argument if it is still unknown (e.g. empty array). - break - if types.is_var(elt): - return # undetermined yet - if not builtins.is_iterable(elt) or builtins.is_str(elt): - break - if builtins.is_array(elt): - num_dims += elt.find()["num_dims"].value - else: - num_dims += 1 - elt = builtins.get_iterable_elt(elt) + num_empty_dims = is_nested_empty_list(arg) + if num_empty_dims is not None: + # As a special case, following the behaviour of numpy.array (and + # repr() on ndarrays), consider empty lists to be exactly of the + # number of dimensions given, instead of potentially containing an + # unknown number of extra dimensions. + num_dims = num_empty_dims + + # The ultimate element type will be TVar initially, but we might be + # able to resolve it from context. + elt = arg.type + for _ in range(num_dims): + assert builtins.is_list(elt) + elt = elt.find()["elt"] + else: + # In the absence of any other information (there currently isn't a way + # to specify any), assume that all iterables are expandable into a + # (runtime-checked) rectangular array of the innermost element type. + elt = arg.type + num_dims = 0 + expected_dims = (node.type.find()["num_dims"].value + if builtins.is_array(node.type) else -1) + while True: + if num_dims == expected_dims: + # If we already know the number of dimensions of the result, + # stop so we can disambiguate the (innermost) element type of + # the argument if it is still unknown. + break + if types.is_var(elt): + # Can't make progress here because we don't know how many more + # dimensions might be "hidden" inside. + return + if not builtins.is_iterable(elt) or builtins.is_str(elt): + break + if builtins.is_array(elt): + num_dims += elt.find()["num_dims"].value + else: + num_dims += 1 + elt = builtins.get_iterable_elt(elt) if explicit_dtype is not None: # TODO: Factor out type detection; support quoted type constructors diff --git a/artiq/test/lit/inferencer/array_creation.py b/artiq/test/lit/inferencer/array_creation.py index e3e00a254..824150c22 100644 --- a/artiq/test/lit/inferencer/array_creation.py +++ b/artiq/test/lit/inferencer/array_creation.py @@ -1,10 +1,10 @@ # RUN: %python -m artiq.compiler.testbench.inferencer %s >%t # RUN: OutputCheck %s --file-to-check=%t -# Nothing known, as there could be several more dimensions -# hidden from view by the array being empty. -# CHECK-L: ([]:list(elt='a)):'b +# CHECK-L: numpy.array(elt='a, num_dims=1) array([]) +# CHECK-L: numpy.array(elt='b, num_dims=2) +array([[], []]) # CHECK-L: numpy.array(elt=numpy.int?, num_dims=1) array([1, 2, 3])