forked from M-Labs/artiq
1
0
Fork 0

compiler: Assume array()s are always rectangular

This commit is contained in:
David Nadlinger 2020-08-08 20:35:04 +01:00
parent 8eddb9194a
commit 5472e830f6
5 changed files with 86 additions and 91 deletions

View File

@ -1561,9 +1561,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.current_block = after_invoke
def _get_array_offset(self, shape, indices):
last_stride = None
result = indices[0]
for dim, index in zip(shape[:-1], indices[1:]):
for dim, index in zip(shape[1:], indices[1:]):
result = self.append(ir.Arith(ast.Mult(loc=None), result, dim))
result = self.append(ir.Arith(ast.Add(loc=None), result, index))
return result
@ -2090,9 +2089,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
result_elt = result_type["elt"].find()
num_dims = result_type["num_dims"].value
# Derive shape from first element on each level (currently, type
# inference make sure arrays are always rectangular; in the future, we
# might want to insert a runtime check here).
# Derive shape from first element on each level (and fail later if the
# array is in fact jagged).
first_elt = None
lengths = []
for dim_idx in range(num_dims):
@ -2110,32 +2108,37 @@ class ARTIQIRGenerator(algorithm.Visitor):
buffer = self.append(
ir.Alloc([num_total_elts], result_type.attributes["buffer"]))
def body_gen(index):
# TODO: This is hilariously inefficient; we really want to emit a
# nested loop for the source and keep one running index for the
# target buffer.
indices = []
mod_idx = index
for dim_idx in reversed(range(1, num_dims)):
dim_len = self.append(ir.GetAttr(shape, dim_idx))
indices.append(
self.append(ir.Arith(ast.Mod(loc=None), mod_idx, dim_len)))
mod_idx = self.append(
ir.Arith(ast.FloorDiv(loc=None), mod_idx, dim_len))
indices.append(mod_idx)
elt = arg
for idx in reversed(indices):
elt = self.iterable_get(elt, idx)
self.append(ir.SetElem(buffer, index, elt))
return self.append(
ir.Arith(ast.Add(loc=None), index,
ir.Constant(1, self._size_type)))
self._make_loop(
ir.Constant(0, self._size_type), lambda index: self.append(
ir.Compare(ast.Lt(loc=None), index, num_total_elts)), body_gen)
def assign_elems(outer_indices, indexed_arg):
if len(outer_indices) == num_dims:
dest_idx = self._get_array_offset(lengths, outer_indices)
self.append(ir.SetElem(buffer, dest_idx, indexed_arg))
else:
this_level_len = self.iterable_len(indexed_arg)
dim_idx = len(outer_indices)
if dim_idx > 0:
# Check for rectangularity (outermost index is never jagged,
# by definition).
result_len = self.append(ir.GetAttr(shape, dim_idx))
self._make_check(
self.append(ir.Compare(ast.Eq(loc=None), this_level_len, result_len)),
lambda a, b: self.alloc_exn(
builtins.TException("ValueError"),
ir.Constant(
"arrays must be rectangular (lengths were {0} vs. {1})",
builtins.TStr()), a, b),
params=[this_level_len, result_len],
loc=node.loc)
def body_gen(index):
elem = self.iterable_get(indexed_arg, index)
assign_elems(outer_indices + [index], elem)
return self.append(
ir.Arith(ast.Add(loc=None), index,
ir.Constant(1, self._size_type)))
self._make_loop(
ir.Constant(0, self._size_type), lambda index: self.append(
ir.Compare(ast.Lt(loc=None), index, this_level_len)), body_gen)
assign_elems([], arg)
return self.append(ir.Alloc([buffer, shape], node.type))
else:
assert False

View File

@ -8,30 +8,6 @@ from .. import asttyped, types, builtins
from .typedtree_printer import TypedtreePrinter
def match_rectangular_list(elts):
num_elts = None
elt_type = None
all_child_elts = []
for e in elts:
if elt_type is None:
elt_type = e.type.find()
if not isinstance(e, asttyped.ListT):
return elt_type, 0
if num_elts is None:
num_elts = len(e.elts)
elif num_elts != len(e.elts):
return elt_type, 0
all_child_elts += e.elts
if not all_child_elts:
# This ultimately turned out to be a list (of list, of ...) of empty lists.
return elt_type["elt"], 1
elt, num_dims = match_rectangular_list(all_child_elts)
return elt, num_dims + 1
class Inferencer(algorithm.Visitor):
"""
:class:`Inferencer` infers types by recursively applying the unification
@ -862,29 +838,42 @@ class Inferencer(algorithm.Visitor):
if len(node.args) == 1 and len(node.keywords) == 0:
arg, = node.args
if builtins.is_iterable(arg.type):
# KLUDGE: Support multidimensional arary creation if lexically
# specified as a rectangular array of lists.
elt, num_dims = match_rectangular_list([arg])
if num_dims == 0:
# Not given as a list, so just default to 1 dimension.
elt = builtins.get_iterable_elt(arg.type)
num_dims = 1
self._unify(node.type,
builtins.TArray(elt, types.TValue(num_dims)),
node.loc, arg.loc)
elif types.is_var(arg.type):
pass # undetermined yet
else:
note = diagnostic.Diagnostic("note",
"this expression has type {type}",
{"type": types.TypePrinter().name(arg.type)},
arg.loc)
diag = diagnostic.Diagnostic("error",
# In the absence of any other information (there currently isn't a way
# to specify any), assume that all iterables are expandable into a
# (runtime-checked) rectangular array of the innermost element type.
elt = arg.type
num_dims = 0
result_dims = (node.type.find()["num_dims"].value
if builtins.is_array(node.type) else -1)
while True:
if num_dims == result_dims:
# If we already know the number of dimensions of the result,
# stop so we can disambiguate the (innermost) element type of
# the argument if it is still unknown (e.g. empty array).
break
if types.is_var(elt):
return # undetermined yet
if not builtins.is_iterable(elt):
break
num_dims += 1
elt = builtins.get_iterable_elt(elt)
if num_dims == 0:
note = diagnostic.Diagnostic(
"note", "this expression has type {type}",
{"type": types.TypePrinter().name(arg.type)}, arg.loc)
diag = diagnostic.Diagnostic(
"error",
"the argument of {builtin}() must be of an iterable type",
{"builtin": typ.find().name},
node.func.loc, notes=[note])
node.func.loc,
notes=[note])
self.engine.process(diag)
return
self._unify(node.type,
builtins.TArray(elt, types.TValue(num_dims)),
node.loc, arg.loc)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "list"):

View File

@ -12,6 +12,7 @@ float_mat = array([[1.0, 2.0], [3.0, 4.0]])
@kernel
def entrypoint():
# TODO: These need to be runtime tests!
assert int_vec.shape == (3, )
assert int_vec[0] == 1
assert int_vec[1] == 2
@ -22,14 +23,14 @@ def entrypoint():
assert float_vec[1] == 2.0
assert float_vec[2] == 3.0
# assert int_mat.shape == (2, 2)
# assert int_mat[0][0] == 1
# assert int_mat[0][1] == 2
# assert int_mat[1][0] == 3
# assert int_mat[1][1] == 4
assert int_mat.shape == (2, 2)
assert int_mat[0][0] == 1
assert int_mat[0][1] == 2
assert int_mat[1][0] == 3
assert int_mat[1][1] == 4
# assert float_mat.shape == (2, 2)
# assert float_mat[0][0] == 1.0
# assert float_mat[0][1] == 2.0
# assert float_mat[1][0] == 3.0
# assert float_mat[1][1] == 4.0
assert float_mat.shape == (2, 2)
assert float_mat[0][0] == 1.0
assert float_mat[0][1] == 2.0
assert float_mat[1][0] == 3.0
assert float_mat[1][1] == 4.0

View File

@ -1,7 +1,9 @@
# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t
# RUN: OutputCheck %s --file-to-check=%t
# CHECK-L: numpy.array(elt='a, num_dims=1)
# Nothing known, as there could be several more dimensions
# hidden from view by the array being empty.
# CHECK-L: ([]:list(elt='a)):'b
array([])
# CHECK-L: numpy.array(elt=numpy.int?, num_dims=1)
@ -9,5 +11,6 @@ array([1, 2, 3])
# CHECK-L: numpy.array(elt=numpy.int?, num_dims=2)
array([[1, 2, 3], [4, 5, 6]])
# CHECK-L: numpy.array(elt=list(elt=numpy.int?), num_dims=1)
# Jagged arrays produce runtime failure:
# CHECK-L: numpy.array(elt=numpy.int?, num_dims=2)
array([[1, 2, 3], [4, 5]])

View File

@ -13,13 +13,12 @@ assert len(empty_array) == 0
assert empty_array.shape == (0,)
assert [x * x for x in empty_array] == []
# Creating a list from a generic iterable always generates an 1D array, as we can't
# check for rectangularity at compile time. (This could be changed to *assume*
# rectangularity and insert runtime checks instead.)
# Creating arrays from generic iterables, rectangularity is assumed (and ensured
# with runtime checks).
list_of_lists = [[1, 2], [3, 4]]
array_of_lists = array(list_of_lists)
assert array_of_lists.shape == (2,)
assert [x for x in array_of_lists] == list_of_lists
assert array_of_lists.shape == (2, 2)
assert [[y for y in x] for x in array_of_lists] == list_of_lists
matrix = array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
assert len(matrix) == 2