forked from M-Labs/artiq
compiler: Parametrize TArray in number of dimensions
This commit is contained in:
parent
632c5bc937
commit
bc17bb4d1a
|
@ -82,17 +82,24 @@ class TList(types.TMono):
|
|||
super().__init__("list", {"elt": elt})
|
||||
|
||||
class TArray(types.TMono):
|
||||
def __init__(self, elt=None):
|
||||
def __init__(self, elt=None, num_dims=types.TValue(1)):
|
||||
if elt is None:
|
||||
elt = types.TVar()
|
||||
super().__init__("array", {"elt": elt})
|
||||
# For now, enforce number of dimensions to be known, as we'd otherwise
|
||||
# need to implement custom unification logic for the type of `shape`.
|
||||
# Default to 1 to keep compatibility with old user code from before
|
||||
# multidimensional array support.
|
||||
assert isinstance(num_dims.value, int), "Number of dimensions must be resolved"
|
||||
|
||||
super().__init__("array", {"elt": elt, "num_dims": num_dims})
|
||||
self.attributes = OrderedDict([
|
||||
("shape", TList(TInt32())),
|
||||
("shape", types.TTuple([TInt32()] * num_dims.value)),
|
||||
("buffer", TList(elt)),
|
||||
])
|
||||
|
||||
def _array_printer(typ, printer, depth, max_depth):
|
||||
return "numpy.array(elt={})".format(printer.name(typ["elt"], depth, max_depth))
|
||||
return "numpy.array(elt={}, num_dims={})".format(
|
||||
printer.name(typ["elt"], depth, max_depth), typ["num_dims"].value)
|
||||
types.TypePrinter.custom_printers["array"] = _array_printer
|
||||
|
||||
class TRange(types.TMono):
|
||||
|
|
|
@ -7,6 +7,7 @@ semantics explicitly.
|
|||
"""
|
||||
|
||||
from collections import OrderedDict, defaultdict
|
||||
from functools import reduce
|
||||
from pythonparser import algorithm, diagnostic, ast
|
||||
from .. import types, builtins, asttyped, ir, iodelay
|
||||
|
||||
|
@ -1665,47 +1666,32 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
result_type = node.type.find()
|
||||
arg = self.visit(node.args[0])
|
||||
|
||||
num_dims = 0
|
||||
result_elt = result_type["elt"].find()
|
||||
inner_type = arg.type.find()
|
||||
while True:
|
||||
if inner_type == result_elt:
|
||||
# TODO: What about types needing coercion (e.g. int32 to int64)?
|
||||
break
|
||||
assert builtins.is_iterable(inner_type)
|
||||
num_dims += 1
|
||||
inner_type = builtins.get_iterable_elt(inner_type)
|
||||
num_dims = result_type["num_dims"].value
|
||||
|
||||
# Derive shape from first element on each level (currently, type
|
||||
# inference make sure arrays are always rectangular; in the future, we
|
||||
# might want to insert a runtime check here).
|
||||
#
|
||||
# While we are at it, also total up overall number of elements
|
||||
shape = self.append(
|
||||
ir.Alloc([ir.Constant(num_dims, self._size_type)],
|
||||
result_type.attributes["shape"]))
|
||||
first_elt = arg
|
||||
dim_idx = 0
|
||||
num_total_elts = None
|
||||
while True:
|
||||
length = self.iterable_len(first_elt)
|
||||
self.append(
|
||||
ir.SetElem(shape, ir.Constant(dim_idx, length.type), length))
|
||||
if num_total_elts is None:
|
||||
num_total_elts = length
|
||||
first_elt = None
|
||||
lengths = []
|
||||
for dim_idx in range(num_dims):
|
||||
if first_elt is None:
|
||||
first_elt = arg
|
||||
else:
|
||||
num_total_elts = self.append(
|
||||
ir.Arith(ast.Mult(loc=None), num_total_elts, length))
|
||||
first_elt = self.iterable_get(first_elt,
|
||||
ir.Constant(0, self._size_type))
|
||||
lengths.append(self.iterable_len(first_elt))
|
||||
|
||||
dim_idx += 1
|
||||
if dim_idx == num_dims:
|
||||
break
|
||||
first_elt = self.iterable_get(first_elt,
|
||||
ir.Constant(0, length.type))
|
||||
num_total_elts = reduce(
|
||||
lambda l, r: self.append(ir.Arith(ast.Mult(loc=None), l, r)),
|
||||
lengths[1:], lengths[0])
|
||||
|
||||
shape = self.append(ir.Alloc(lengths, result_type.attributes["shape"]))
|
||||
|
||||
# Assign buffer from nested iterables.
|
||||
buffer = self.append(
|
||||
ir.Alloc([num_total_elts], result_type.attributes["buffer"]))
|
||||
|
||||
def body_gen(index):
|
||||
# TODO: This is hilariously inefficient; we really want to emit a
|
||||
# nested loop for the source and keep one running index for the
|
||||
|
@ -1713,9 +1699,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
indices = []
|
||||
mod_idx = index
|
||||
for dim_idx in reversed(range(1, num_dims)):
|
||||
dim_len = self.append(ir.GetElem(shape, ir.Constant(dim_idx, self._size_type)))
|
||||
indices.append(self.append(ir.Arith(ast.Mod(loc=None), mod_idx, dim_len)))
|
||||
mod_idx = self.append(ir.Arith(ast.FloorDiv(loc=None), mod_idx, dim_len))
|
||||
dim_len = self.append(ir.GetAttr(shape, dim_idx))
|
||||
indices.append(
|
||||
self.append(ir.Arith(ast.Mod(loc=None), mod_idx, dim_len)))
|
||||
mod_idx = self.append(
|
||||
ir.Arith(ast.FloorDiv(loc=None), mod_idx, dim_len))
|
||||
indices.append(mod_idx)
|
||||
|
||||
elt = arg
|
||||
|
@ -1723,9 +1711,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
|||
elt = self.iterable_get(elt, idx)
|
||||
self.append(ir.SetElem(buffer, index, elt))
|
||||
return self.append(
|
||||
ir.Arith(ast.Add(loc=None), index, ir.Constant(1, length.type)))
|
||||
ir.Arith(ast.Add(loc=None), index,
|
||||
ir.Constant(1, self._size_type)))
|
||||
|
||||
self._make_loop(
|
||||
ir.Constant(0, length.type), lambda index: self.append(
|
||||
ir.Constant(0, self._size_type), lambda index: self.append(
|
||||
ir.Compare(ast.Lt(loc=None), index, num_total_elts)), body_gen)
|
||||
|
||||
return self.append(ir.Alloc([shape, buffer], node.type))
|
||||
|
|
|
@ -8,18 +8,28 @@ from .. import asttyped, types, builtins
|
|||
from .typedtree_printer import TypedtreePrinter
|
||||
|
||||
|
||||
def is_rectangular_2d_list(node):
|
||||
if not isinstance(node, asttyped.ListT):
|
||||
return False
|
||||
def match_rectangular_list(elts):
|
||||
num_elts = None
|
||||
for e in node.elts:
|
||||
elt_type = None
|
||||
all_child_elts = []
|
||||
|
||||
for e in elts:
|
||||
if elt_type is None:
|
||||
elt_type = e.type.find()
|
||||
if not isinstance(e, asttyped.ListT):
|
||||
return False
|
||||
return elt_type, 0
|
||||
if num_elts is None:
|
||||
num_elts = len(e.elts)
|
||||
elif num_elts != len(e.elts):
|
||||
return False
|
||||
return True
|
||||
return elt_type, 0
|
||||
all_child_elts += e.elts
|
||||
|
||||
if not all_child_elts:
|
||||
# This ultimately turned out to be a list (of list, of ...) of empty lists.
|
||||
return elt_type["elt"], 1
|
||||
|
||||
elt, num_dims = match_rectangular_list(all_child_elts)
|
||||
return elt, num_dims + 1
|
||||
|
||||
|
||||
class Inferencer(algorithm.Visitor):
|
||||
|
@ -710,29 +720,45 @@ class Inferencer(algorithm.Visitor):
|
|||
"strings currently cannot be constructed", {},
|
||||
node.loc)
|
||||
self.engine.process(diag)
|
||||
elif types.is_builtin(typ, "list") or types.is_builtin(typ, "array"):
|
||||
if types.is_builtin(typ, "list"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("list() -> list(elt='a)"),
|
||||
valid_form("list(x:'a) -> list(elt='b) where 'a is iterable")
|
||||
]
|
||||
elif types.is_builtin(typ, "array"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("array(x:'a) -> array(elt='b) where 'a is iterable")
|
||||
]
|
||||
|
||||
self._unify(node.type, builtins.TList(),
|
||||
node.loc, None)
|
||||
elif types.is_builtin(typ, "array"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("array(x:'a) -> array(elt='b) where 'a is iterable")
|
||||
]
|
||||
if len(node.args) == 1 and len(node.keywords) == 0:
|
||||
arg, = node.args
|
||||
|
||||
self._unify(node.type, builtins.TArray(),
|
||||
node.loc, None)
|
||||
if builtins.is_iterable(arg.type):
|
||||
# KLUDGE: Support multidimensional arary creation if lexically
|
||||
# specified as a rectangular array of lists.
|
||||
elt, num_dims = match_rectangular_list([arg])
|
||||
self._unify(node.type,
|
||||
builtins.TArray(elt, types.TValue(num_dims)),
|
||||
node.loc, arg.loc)
|
||||
elif types.is_var(arg.type):
|
||||
pass # undetermined yet
|
||||
else:
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"this expression has type {type}",
|
||||
{"type": types.TypePrinter().name(arg.type)},
|
||||
arg.loc)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"the argument of {builtin}() must be of an iterable type",
|
||||
{"builtin": typ.find().name},
|
||||
node.func.loc, notes=[note])
|
||||
self.engine.process(diag)
|
||||
else:
|
||||
assert False
|
||||
diagnose(valid_forms())
|
||||
elif types.is_builtin(typ, "list"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("list() -> list(elt='a)"),
|
||||
valid_form("list(x:'a) -> list(elt='b) where 'a is iterable")
|
||||
]
|
||||
|
||||
if (types.is_builtin(typ, "list") and len(node.args) == 0 and
|
||||
len(node.keywords) == 0):
|
||||
# Mimic numpy and don't allow array() (but []).
|
||||
pass
|
||||
self._unify(node.type, builtins.TList(), node.loc, None)
|
||||
|
||||
if len(node.args) == 0 and len(node.keywords) == 0:
|
||||
pass # []
|
||||
elif len(node.args) == 1 and len(node.keywords) == 0:
|
||||
arg, = node.args
|
||||
|
||||
|
@ -748,14 +774,8 @@ class Inferencer(algorithm.Visitor):
|
|||
{"typeb": printer.name(typeb)},
|
||||
locb)
|
||||
]
|
||||
elt = arg.type.find().params["elt"]
|
||||
if types.is_builtin(typ, "array") and builtins.is_listish(elt):
|
||||
# KLUDGE: Support 2D arary creation if lexically specified
|
||||
# as a rectangular array of lists.
|
||||
if is_rectangular_2d_list(arg):
|
||||
elt = elt.find().params["elt"]
|
||||
self._unify(node.type.find().params["elt"],
|
||||
elt,
|
||||
arg.type.find().params["elt"],
|
||||
node.loc, arg.loc, makenotes=makenotes)
|
||||
elif types.is_var(arg.type):
|
||||
pass # undetermined yet
|
||||
|
|
|
@ -1173,7 +1173,7 @@ class LLVMIRGenerator:
|
|||
if builtins.is_array(collection.type):
|
||||
# Return length of outermost dimension.
|
||||
shape = self.llbuilder.extract_value(self.map(collection), 0)
|
||||
return self.llbuilder.load(self.llbuilder.extract_value(shape, 0))
|
||||
return self.llbuilder.extract_value(shape, 0)
|
||||
return self.llbuilder.extract_value(self.map(collection), 1)
|
||||
elif insn.op in ("printf", "rtio_log"):
|
||||
# We only get integers, floats, pointers and strings here.
|
||||
|
|
|
@ -50,3 +50,9 @@ class ConstnessValidator(algorithm.Visitor):
|
|||
node.loc)
|
||||
self.engine.process(diag)
|
||||
return
|
||||
if builtins.is_array(typ):
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"array attributes cannot be assigned to",
|
||||
{}, node.loc)
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
|
|
@ -3,3 +3,7 @@
|
|||
|
||||
# CHECK-L: ${LINE:+1}: error: array cannot be invoked with the arguments ()
|
||||
a = array()
|
||||
|
||||
b = array([1, 2, 3])
|
||||
# CHECK-L: ${LINE:+1}: error: array attributes cannot be assigned to
|
||||
b.shape = (5, )
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
ary = array([1, 2, 3])
|
||||
assert len(ary) == 3
|
||||
assert ary.shape == [3]
|
||||
assert ary.shape == (3,)
|
||||
# FIXME: Implement ndarray indexing
|
||||
# assert [x*x for x in ary] == [1, 4, 9]
|
||||
|
||||
|
@ -11,8 +11,12 @@ assert ary.shape == [3]
|
|||
empty_array = array([1])
|
||||
empty_array = array([])
|
||||
assert len(empty_array) == 0
|
||||
assert empty_array.shape == [0]
|
||||
assert empty_array.shape == (0,)
|
||||
|
||||
matrix = array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
assert len(matrix) == 2
|
||||
assert matrix.shape == [2, 3]
|
||||
assert matrix.shape == (2, 3)
|
||||
|
||||
three_tensor = array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
|
||||
assert len(three_tensor) == 1
|
||||
assert three_tensor.shape == (1, 2, 3)
|
||||
|
|
Loading…
Reference in New Issue