compiler: Change type inference rules for empty array() calls

array([...]), the constructor for NumPy arrays, currently has the
status of some weird kind of macro in ARTIQ Python, as it needs
to determine the number of dimensions in the resulting array
type, which is a fixed type parameter on which inference cannot
be performed.

This leads to an ambiguity for empty lists, which could contain
elements of arbitrary type, including other lists (which would
add to the number of dimensions).

Previously, I had chosen to make array([]) to be of completely
indeterminate type for this reason. However, this is different
to how the call behaves in host NumPy, where this is a well-formed
call creating an empty 1D array (or 2D for array([[], []]), etc.).

This commit adds special matching for (recursive lists of) empty
ListT AST nodes to treat them as scalar dimensions, with the
element type still unknown.

This also happens to fix type inference for embedding empty 1D
NumPy arrays from host object attributes, although multi-dimensional
arrays will still require work (see GitHub #1633).

GitHub: Fixes #1626.
This commit is contained in:
David Nadlinger 2021-03-14 22:36:40 +00:00 committed by Sebastien Bourdeauducq
parent 925014689e
commit c1413a9945
3 changed files with 67 additions and 25 deletions

View File

@ -315,6 +315,9 @@ def is_iterable(typ):
return is_listish(typ) or is_range(typ) return is_listish(typ) or is_range(typ)
def get_iterable_elt(typ): def get_iterable_elt(typ):
# TODO: Arrays count as listish, but this returns the innermost element type for
# n-dimensional arrays, rather than the n-1 dimensional result of iterating over
# the first axis, which makes the name a bit misleading.
if is_str(typ) or is_bytes(typ) or is_bytearray(typ): if is_str(typ) or is_bytes(typ) or is_bytearray(typ):
return TInt(types.TValue(8)) return TInt(types.TValue(8))
elif types._is_pointer(typ) or is_iterable(typ): elif types._is_pointer(typ) or is_iterable(typ):

View File

@ -8,6 +8,28 @@ from .. import asttyped, types, builtins
from .typedtree_printer import TypedtreePrinter from .typedtree_printer import TypedtreePrinter
def is_nested_empty_list(node):
"""If the passed AST node is an empty list, or a regularly nested list thereof,
returns the number of nesting layers, or ``None`` otherwise.
For instance, ``is_nested_empty_list([]) == 1`` and
``is_nested_empty_list([[], []]) == 2``, but
``is_nested_empty_list([[[]], []]) == None`` as the number of nesting layers doesn't
match.
"""
if not isinstance(node, ast.List):
return None
if not node.elts:
return 1
result = is_nested_empty_list(node.elts[0])
if result is None:
return None
for elt in node.elts[:1]:
if result != is_nested_empty_list(elt):
return None
return result + 1
class Inferencer(algorithm.Visitor): class Inferencer(algorithm.Visitor):
""" """
:class:`Inferencer` infers types by recursively applying the unification :class:`Inferencer` infers types by recursively applying the unification
@ -891,28 +913,45 @@ class Inferencer(algorithm.Visitor):
if len(node.args) == 1 and keywords_acceptable: if len(node.args) == 1 and keywords_acceptable:
arg, = node.args arg, = node.args
# In the absence of any other information (there currently isn't a way num_empty_dims = is_nested_empty_list(arg)
# to specify any), assume that all iterables are expandable into a if num_empty_dims is not None:
# (runtime-checked) rectangular array of the innermost element type. # As a special case, following the behaviour of numpy.array (and
elt = arg.type # repr() on ndarrays), consider empty lists to be exactly of the
num_dims = 0 # number of dimensions given, instead of potentially containing an
result_dims = (node.type.find()["num_dims"].value # unknown number of extra dimensions.
if builtins.is_array(node.type) else -1) num_dims = num_empty_dims
while True:
if num_dims == result_dims: # The ultimate element type will be TVar initially, but we might be
# If we already know the number of dimensions of the result, # able to resolve it from context.
# stop so we can disambiguate the (innermost) element type of elt = arg.type
# the argument if it is still unknown (e.g. empty array). for _ in range(num_dims):
break assert builtins.is_list(elt)
if types.is_var(elt): elt = elt.find()["elt"]
return # undetermined yet else:
if not builtins.is_iterable(elt) or builtins.is_str(elt): # In the absence of any other information (there currently isn't a way
break # to specify any), assume that all iterables are expandable into a
if builtins.is_array(elt): # (runtime-checked) rectangular array of the innermost element type.
num_dims += elt.find()["num_dims"].value elt = arg.type
else: num_dims = 0
num_dims += 1 expected_dims = (node.type.find()["num_dims"].value
elt = builtins.get_iterable_elt(elt) if builtins.is_array(node.type) else -1)
while True:
if num_dims == expected_dims:
# If we already know the number of dimensions of the result,
# stop so we can disambiguate the (innermost) element type of
# the argument if it is still unknown.
break
if types.is_var(elt):
# Can't make progress here because we don't know how many more
# dimensions might be "hidden" inside.
return
if not builtins.is_iterable(elt) or builtins.is_str(elt):
break
if builtins.is_array(elt):
num_dims += elt.find()["num_dims"].value
else:
num_dims += 1
elt = builtins.get_iterable_elt(elt)
if explicit_dtype is not None: if explicit_dtype is not None:
# TODO: Factor out type detection; support quoted type constructors # TODO: Factor out type detection; support quoted type constructors

View File

@ -1,10 +1,10 @@
# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t # RUN: %python -m artiq.compiler.testbench.inferencer %s >%t
# RUN: OutputCheck %s --file-to-check=%t # RUN: OutputCheck %s --file-to-check=%t
# Nothing known, as there could be several more dimensions # CHECK-L: numpy.array(elt='a, num_dims=1)
# hidden from view by the array being empty.
# CHECK-L: ([]:list(elt='a)):'b
array([]) array([])
# CHECK-L: numpy.array(elt='b, num_dims=2)
array([[], []])
# CHECK-L: numpy.array(elt=numpy.int?, num_dims=1) # CHECK-L: numpy.array(elt=numpy.int?, num_dims=1)
array([1, 2, 3]) array([1, 2, 3])