compiler/inferencer: Detect rectangular array()s

Still needs support through all the rest of the compiler, and
support for higher-dimensional arrays.

Alternatively, we could always assume ndarrays of ndarrays
are rectangular (i.e. ban array/list element types), and
detect mismatch at runtime. This might turn out to be
preferrable to be able to construct matrices from rows/columns.

`array()` is disallowed for no particularly good reason but
numpy API compatibility.
This commit is contained in:
David Nadlinger 2020-07-25 20:09:42 +01:00
parent 6ea836183d
commit 56010c49fb
3 changed files with 44 additions and 4 deletions

View File

@ -7,6 +7,21 @@ from pythonparser import algorithm, diagnostic, ast
from .. import asttyped, types, builtins from .. import asttyped, types, builtins
from .typedtree_printer import TypedtreePrinter from .typedtree_printer import TypedtreePrinter
def is_rectangular_2d_list(node):
if not isinstance(node, asttyped.ListT):
return False
num_elts = None
for e in node.elts:
if not isinstance(e, asttyped.ListT):
return False
if num_elts is None:
num_elts = len(e.elts)
elif num_elts != len(e.elts):
return False
return True
class Inferencer(algorithm.Visitor): class Inferencer(algorithm.Visitor):
""" """
:class:`Inferencer` infers types by recursively applying the unification :class:`Inferencer` infers types by recursively applying the unification
@ -706,7 +721,6 @@ class Inferencer(algorithm.Visitor):
node.loc, None) node.loc, None)
elif types.is_builtin(typ, "array"): elif types.is_builtin(typ, "array"):
valid_forms = lambda: [ valid_forms = lambda: [
valid_form("array() -> array(elt='a)"),
valid_form("array(x:'a) -> array(elt='b) where 'a is iterable") valid_form("array(x:'a) -> array(elt='b) where 'a is iterable")
] ]
@ -715,8 +729,10 @@ class Inferencer(algorithm.Visitor):
else: else:
assert False assert False
if len(node.args) == 0 and len(node.keywords) == 0: if (types.is_builtin(typ, "list") and len(node.args) == 0 and
pass # [] len(node.keywords) == 0):
# Mimic numpy and don't allow array() (but []).
pass
elif len(node.args) == 1 and len(node.keywords) == 0: elif len(node.args) == 1 and len(node.keywords) == 0:
arg, = node.args arg, = node.args
@ -732,8 +748,14 @@ class Inferencer(algorithm.Visitor):
{"typeb": printer.name(typeb)}, {"typeb": printer.name(typeb)},
locb) locb)
] ]
elt = arg.type.find().params["elt"]
if types.is_builtin(typ, "array") and builtins.is_listish(elt):
# KLUDGE: Support 2D arary creation if lexically specified
# as a rectangular array of lists.
if is_rectangular_2d_list(arg):
elt = elt.find().params["elt"]
self._unify(node.type.find().params["elt"], self._unify(node.type.find().params["elt"],
arg.type.find().params["elt"], elt,
node.loc, arg.loc, makenotes=makenotes) node.loc, arg.loc, makenotes=makenotes)
elif types.is_var(arg.type): elif types.is_var(arg.type):
pass # undetermined yet pass # undetermined yet

View File

@ -0,0 +1,13 @@
# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t
# RUN: OutputCheck %s --file-to-check=%t
# CHECK-L: numpy.array(elt='a)
array([])
# CHECK-L: numpy.array(elt=numpy.int?)
array([1, 2, 3])
# CHECK-L: numpy.array(elt=numpy.int?)
array([[1, 2, 3], [4, 5, 6]])
# CHECK-L: numpy.array(elt=list(elt=numpy.int?))
array([[1, 2, 3], [4, 5]])

View File

@ -0,0 +1,5 @@
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
# RUN: OutputCheck %s --file-to-check=%t
# CHECK-L: ${LINE:+1}: error: array cannot be invoked with the arguments ()
a = array()