mirror of
https://github.com/m-labs/artiq.git
synced 2025-01-13 04:18:55 +08:00
compiler: Assume array()s are always rectangular
This commit is contained in:
parent
8eddb9194a
commit
5472e830f6
@ -1561,9 +1561,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
self.current_block = after_invoke
|
||||
|
||||
def _get_array_offset(self, shape, indices):
|
||||
last_stride = None
|
||||
result = indices[0]
|
||||
for dim, index in zip(shape[:-1], indices[1:]):
|
||||
for dim, index in zip(shape[1:], indices[1:]):
|
||||
result = self.append(ir.Arith(ast.Mult(loc=None), result, dim))
|
||||
result = self.append(ir.Arith(ast.Add(loc=None), result, index))
|
||||
return result
|
||||
@ -2090,9 +2089,8 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
result_elt = result_type["elt"].find()
|
||||
num_dims = result_type["num_dims"].value
|
||||
|
||||
# Derive shape from first element on each level (currently, type
|
||||
# inference make sure arrays are always rectangular; in the future, we
|
||||
# might want to insert a runtime check here).
|
||||
# Derive shape from first element on each level (and fail later if the
|
||||
# array is in fact jagged).
|
||||
first_elt = None
|
||||
lengths = []
|
||||
for dim_idx in range(num_dims):
|
||||
@ -2110,32 +2108,37 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
buffer = self.append(
|
||||
ir.Alloc([num_total_elts], result_type.attributes["buffer"]))
|
||||
|
||||
def body_gen(index):
|
||||
# TODO: This is hilariously inefficient; we really want to emit a
|
||||
# nested loop for the source and keep one running index for the
|
||||
# target buffer.
|
||||
indices = []
|
||||
mod_idx = index
|
||||
for dim_idx in reversed(range(1, num_dims)):
|
||||
dim_len = self.append(ir.GetAttr(shape, dim_idx))
|
||||
indices.append(
|
||||
self.append(ir.Arith(ast.Mod(loc=None), mod_idx, dim_len)))
|
||||
mod_idx = self.append(
|
||||
ir.Arith(ast.FloorDiv(loc=None), mod_idx, dim_len))
|
||||
indices.append(mod_idx)
|
||||
|
||||
elt = arg
|
||||
for idx in reversed(indices):
|
||||
elt = self.iterable_get(elt, idx)
|
||||
self.append(ir.SetElem(buffer, index, elt))
|
||||
return self.append(
|
||||
ir.Arith(ast.Add(loc=None), index,
|
||||
ir.Constant(1, self._size_type)))
|
||||
|
||||
self._make_loop(
|
||||
ir.Constant(0, self._size_type), lambda index: self.append(
|
||||
ir.Compare(ast.Lt(loc=None), index, num_total_elts)), body_gen)
|
||||
def assign_elems(outer_indices, indexed_arg):
|
||||
if len(outer_indices) == num_dims:
|
||||
dest_idx = self._get_array_offset(lengths, outer_indices)
|
||||
self.append(ir.SetElem(buffer, dest_idx, indexed_arg))
|
||||
else:
|
||||
this_level_len = self.iterable_len(indexed_arg)
|
||||
dim_idx = len(outer_indices)
|
||||
if dim_idx > 0:
|
||||
# Check for rectangularity (outermost index is never jagged,
|
||||
# by definition).
|
||||
result_len = self.append(ir.GetAttr(shape, dim_idx))
|
||||
self._make_check(
|
||||
self.append(ir.Compare(ast.Eq(loc=None), this_level_len, result_len)),
|
||||
lambda a, b: self.alloc_exn(
|
||||
builtins.TException("ValueError"),
|
||||
ir.Constant(
|
||||
"arrays must be rectangular (lengths were {0} vs. {1})",
|
||||
builtins.TStr()), a, b),
|
||||
params=[this_level_len, result_len],
|
||||
loc=node.loc)
|
||||
|
||||
def body_gen(index):
|
||||
elem = self.iterable_get(indexed_arg, index)
|
||||
assign_elems(outer_indices + [index], elem)
|
||||
return self.append(
|
||||
ir.Arith(ast.Add(loc=None), index,
|
||||
ir.Constant(1, self._size_type)))
|
||||
self._make_loop(
|
||||
ir.Constant(0, self._size_type), lambda index: self.append(
|
||||
ir.Compare(ast.Lt(loc=None), index, this_level_len)), body_gen)
|
||||
assign_elems([], arg)
|
||||
return self.append(ir.Alloc([buffer, shape], node.type))
|
||||
else:
|
||||
assert False
|
||||
|
@ -8,30 +8,6 @@ from .. import asttyped, types, builtins
|
||||
from .typedtree_printer import TypedtreePrinter
|
||||
|
||||
|
||||
def match_rectangular_list(elts):
|
||||
num_elts = None
|
||||
elt_type = None
|
||||
all_child_elts = []
|
||||
|
||||
for e in elts:
|
||||
if elt_type is None:
|
||||
elt_type = e.type.find()
|
||||
if not isinstance(e, asttyped.ListT):
|
||||
return elt_type, 0
|
||||
if num_elts is None:
|
||||
num_elts = len(e.elts)
|
||||
elif num_elts != len(e.elts):
|
||||
return elt_type, 0
|
||||
all_child_elts += e.elts
|
||||
|
||||
if not all_child_elts:
|
||||
# This ultimately turned out to be a list (of list, of ...) of empty lists.
|
||||
return elt_type["elt"], 1
|
||||
|
||||
elt, num_dims = match_rectangular_list(all_child_elts)
|
||||
return elt, num_dims + 1
|
||||
|
||||
|
||||
class Inferencer(algorithm.Visitor):
|
||||
"""
|
||||
:class:`Inferencer` infers types by recursively applying the unification
|
||||
@ -862,29 +838,42 @@ class Inferencer(algorithm.Visitor):
|
||||
if len(node.args) == 1 and len(node.keywords) == 0:
|
||||
arg, = node.args
|
||||
|
||||
if builtins.is_iterable(arg.type):
|
||||
# KLUDGE: Support multidimensional arary creation if lexically
|
||||
# specified as a rectangular array of lists.
|
||||
elt, num_dims = match_rectangular_list([arg])
|
||||
if num_dims == 0:
|
||||
# Not given as a list, so just default to 1 dimension.
|
||||
elt = builtins.get_iterable_elt(arg.type)
|
||||
num_dims = 1
|
||||
self._unify(node.type,
|
||||
builtins.TArray(elt, types.TValue(num_dims)),
|
||||
node.loc, arg.loc)
|
||||
elif types.is_var(arg.type):
|
||||
pass # undetermined yet
|
||||
else:
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"this expression has type {type}",
|
||||
{"type": types.TypePrinter().name(arg.type)},
|
||||
arg.loc)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
# In the absence of any other information (there currently isn't a way
|
||||
# to specify any), assume that all iterables are expandable into a
|
||||
# (runtime-checked) rectangular array of the innermost element type.
|
||||
elt = arg.type
|
||||
num_dims = 0
|
||||
result_dims = (node.type.find()["num_dims"].value
|
||||
if builtins.is_array(node.type) else -1)
|
||||
while True:
|
||||
if num_dims == result_dims:
|
||||
# If we already know the number of dimensions of the result,
|
||||
# stop so we can disambiguate the (innermost) element type of
|
||||
# the argument if it is still unknown (e.g. empty array).
|
||||
break
|
||||
if types.is_var(elt):
|
||||
return # undetermined yet
|
||||
if not builtins.is_iterable(elt):
|
||||
break
|
||||
num_dims += 1
|
||||
elt = builtins.get_iterable_elt(elt)
|
||||
|
||||
if num_dims == 0:
|
||||
note = diagnostic.Diagnostic(
|
||||
"note", "this expression has type {type}",
|
||||
{"type": types.TypePrinter().name(arg.type)}, arg.loc)
|
||||
diag = diagnostic.Diagnostic(
|
||||
"error",
|
||||
"the argument of {builtin}() must be of an iterable type",
|
||||
{"builtin": typ.find().name},
|
||||
node.func.loc, notes=[note])
|
||||
node.func.loc,
|
||||
notes=[note])
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
self._unify(node.type,
|
||||
builtins.TArray(elt, types.TValue(num_dims)),
|
||||
node.loc, arg.loc)
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
elif types.is_builtin(typ, "list"):
|
||||
|
@ -12,6 +12,7 @@ float_mat = array([[1.0, 2.0], [3.0, 4.0]])
|
||||
|
||||
@kernel
|
||||
def entrypoint():
|
||||
# TODO: These need to be runtime tests!
|
||||
assert int_vec.shape == (3, )
|
||||
assert int_vec[0] == 1
|
||||
assert int_vec[1] == 2
|
||||
@ -22,14 +23,14 @@ def entrypoint():
|
||||
assert float_vec[1] == 2.0
|
||||
assert float_vec[2] == 3.0
|
||||
|
||||
# assert int_mat.shape == (2, 2)
|
||||
# assert int_mat[0][0] == 1
|
||||
# assert int_mat[0][1] == 2
|
||||
# assert int_mat[1][0] == 3
|
||||
# assert int_mat[1][1] == 4
|
||||
assert int_mat.shape == (2, 2)
|
||||
assert int_mat[0][0] == 1
|
||||
assert int_mat[0][1] == 2
|
||||
assert int_mat[1][0] == 3
|
||||
assert int_mat[1][1] == 4
|
||||
|
||||
# assert float_mat.shape == (2, 2)
|
||||
# assert float_mat[0][0] == 1.0
|
||||
# assert float_mat[0][1] == 2.0
|
||||
# assert float_mat[1][0] == 3.0
|
||||
# assert float_mat[1][1] == 4.0
|
||||
assert float_mat.shape == (2, 2)
|
||||
assert float_mat[0][0] == 1.0
|
||||
assert float_mat[0][1] == 2.0
|
||||
assert float_mat[1][0] == 3.0
|
||||
assert float_mat[1][1] == 4.0
|
||||
|
@ -1,7 +1,9 @@
|
||||
# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t
|
||||
# RUN: OutputCheck %s --file-to-check=%t
|
||||
|
||||
# CHECK-L: numpy.array(elt='a, num_dims=1)
|
||||
# Nothing known, as there could be several more dimensions
|
||||
# hidden from view by the array being empty.
|
||||
# CHECK-L: ([]:list(elt='a)):'b
|
||||
array([])
|
||||
|
||||
# CHECK-L: numpy.array(elt=numpy.int?, num_dims=1)
|
||||
@ -9,5 +11,6 @@ array([1, 2, 3])
|
||||
# CHECK-L: numpy.array(elt=numpy.int?, num_dims=2)
|
||||
array([[1, 2, 3], [4, 5, 6]])
|
||||
|
||||
# CHECK-L: numpy.array(elt=list(elt=numpy.int?), num_dims=1)
|
||||
# Jagged arrays produce runtime failure:
|
||||
# CHECK-L: numpy.array(elt=numpy.int?, num_dims=2)
|
||||
array([[1, 2, 3], [4, 5]])
|
||||
|
@ -13,13 +13,12 @@ assert len(empty_array) == 0
|
||||
assert empty_array.shape == (0,)
|
||||
assert [x * x for x in empty_array] == []
|
||||
|
||||
# Creating a list from a generic iterable always generates an 1D array, as we can't
|
||||
# check for rectangularity at compile time. (This could be changed to *assume*
|
||||
# rectangularity and insert runtime checks instead.)
|
||||
# Creating arrays from generic iterables, rectangularity is assumed (and ensured
|
||||
# with runtime checks).
|
||||
list_of_lists = [[1, 2], [3, 4]]
|
||||
array_of_lists = array(list_of_lists)
|
||||
assert array_of_lists.shape == (2,)
|
||||
assert [x for x in array_of_lists] == list_of_lists
|
||||
assert array_of_lists.shape == (2, 2)
|
||||
assert [[y for y in x] for x in array_of_lists] == list_of_lists
|
||||
|
||||
matrix = array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
assert len(matrix) == 2
|
||||
|
Loading…
Reference in New Issue
Block a user