forked from M-Labs/artiq
compiler: Parametrize TArray in number of dimensions
This commit is contained in:
parent
632c5bc937
commit
bc17bb4d1a
|
@ -82,17 +82,24 @@ class TList(types.TMono):
|
||||||
super().__init__("list", {"elt": elt})
|
super().__init__("list", {"elt": elt})
|
||||||
|
|
||||||
class TArray(types.TMono):
|
class TArray(types.TMono):
|
||||||
def __init__(self, elt=None):
|
def __init__(self, elt=None, num_dims=types.TValue(1)):
|
||||||
if elt is None:
|
if elt is None:
|
||||||
elt = types.TVar()
|
elt = types.TVar()
|
||||||
super().__init__("array", {"elt": elt})
|
# For now, enforce number of dimensions to be known, as we'd otherwise
|
||||||
|
# need to implement custom unification logic for the type of `shape`.
|
||||||
|
# Default to 1 to keep compatibility with old user code from before
|
||||||
|
# multidimensional array support.
|
||||||
|
assert isinstance(num_dims.value, int), "Number of dimensions must be resolved"
|
||||||
|
|
||||||
|
super().__init__("array", {"elt": elt, "num_dims": num_dims})
|
||||||
self.attributes = OrderedDict([
|
self.attributes = OrderedDict([
|
||||||
("shape", TList(TInt32())),
|
("shape", types.TTuple([TInt32()] * num_dims.value)),
|
||||||
("buffer", TList(elt)),
|
("buffer", TList(elt)),
|
||||||
])
|
])
|
||||||
|
|
||||||
def _array_printer(typ, printer, depth, max_depth):
|
def _array_printer(typ, printer, depth, max_depth):
|
||||||
return "numpy.array(elt={})".format(printer.name(typ["elt"], depth, max_depth))
|
return "numpy.array(elt={}, num_dims={})".format(
|
||||||
|
printer.name(typ["elt"], depth, max_depth), typ["num_dims"].value)
|
||||||
types.TypePrinter.custom_printers["array"] = _array_printer
|
types.TypePrinter.custom_printers["array"] = _array_printer
|
||||||
|
|
||||||
class TRange(types.TMono):
|
class TRange(types.TMono):
|
||||||
|
|
|
@ -7,6 +7,7 @@ semantics explicitly.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
|
from functools import reduce
|
||||||
from pythonparser import algorithm, diagnostic, ast
|
from pythonparser import algorithm, diagnostic, ast
|
||||||
from .. import types, builtins, asttyped, ir, iodelay
|
from .. import types, builtins, asttyped, ir, iodelay
|
||||||
|
|
||||||
|
@ -1665,47 +1666,32 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
result_type = node.type.find()
|
result_type = node.type.find()
|
||||||
arg = self.visit(node.args[0])
|
arg = self.visit(node.args[0])
|
||||||
|
|
||||||
num_dims = 0
|
|
||||||
result_elt = result_type["elt"].find()
|
result_elt = result_type["elt"].find()
|
||||||
inner_type = arg.type.find()
|
num_dims = result_type["num_dims"].value
|
||||||
while True:
|
|
||||||
if inner_type == result_elt:
|
|
||||||
# TODO: What about types needing coercion (e.g. int32 to int64)?
|
|
||||||
break
|
|
||||||
assert builtins.is_iterable(inner_type)
|
|
||||||
num_dims += 1
|
|
||||||
inner_type = builtins.get_iterable_elt(inner_type)
|
|
||||||
|
|
||||||
# Derive shape from first element on each level (currently, type
|
# Derive shape from first element on each level (currently, type
|
||||||
# inference make sure arrays are always rectangular; in the future, we
|
# inference make sure arrays are always rectangular; in the future, we
|
||||||
# might want to insert a runtime check here).
|
# might want to insert a runtime check here).
|
||||||
#
|
first_elt = None
|
||||||
# While we are at it, also total up overall number of elements
|
lengths = []
|
||||||
shape = self.append(
|
for dim_idx in range(num_dims):
|
||||||
ir.Alloc([ir.Constant(num_dims, self._size_type)],
|
if first_elt is None:
|
||||||
result_type.attributes["shape"]))
|
first_elt = arg
|
||||||
first_elt = arg
|
|
||||||
dim_idx = 0
|
|
||||||
num_total_elts = None
|
|
||||||
while True:
|
|
||||||
length = self.iterable_len(first_elt)
|
|
||||||
self.append(
|
|
||||||
ir.SetElem(shape, ir.Constant(dim_idx, length.type), length))
|
|
||||||
if num_total_elts is None:
|
|
||||||
num_total_elts = length
|
|
||||||
else:
|
else:
|
||||||
num_total_elts = self.append(
|
first_elt = self.iterable_get(first_elt,
|
||||||
ir.Arith(ast.Mult(loc=None), num_total_elts, length))
|
ir.Constant(0, self._size_type))
|
||||||
|
lengths.append(self.iterable_len(first_elt))
|
||||||
|
|
||||||
dim_idx += 1
|
num_total_elts = reduce(
|
||||||
if dim_idx == num_dims:
|
lambda l, r: self.append(ir.Arith(ast.Mult(loc=None), l, r)),
|
||||||
break
|
lengths[1:], lengths[0])
|
||||||
first_elt = self.iterable_get(first_elt,
|
|
||||||
ir.Constant(0, length.type))
|
shape = self.append(ir.Alloc(lengths, result_type.attributes["shape"]))
|
||||||
|
|
||||||
# Assign buffer from nested iterables.
|
# Assign buffer from nested iterables.
|
||||||
buffer = self.append(
|
buffer = self.append(
|
||||||
ir.Alloc([num_total_elts], result_type.attributes["buffer"]))
|
ir.Alloc([num_total_elts], result_type.attributes["buffer"]))
|
||||||
|
|
||||||
def body_gen(index):
|
def body_gen(index):
|
||||||
# TODO: This is hilariously inefficient; we really want to emit a
|
# TODO: This is hilariously inefficient; we really want to emit a
|
||||||
# nested loop for the source and keep one running index for the
|
# nested loop for the source and keep one running index for the
|
||||||
|
@ -1713,9 +1699,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
indices = []
|
indices = []
|
||||||
mod_idx = index
|
mod_idx = index
|
||||||
for dim_idx in reversed(range(1, num_dims)):
|
for dim_idx in reversed(range(1, num_dims)):
|
||||||
dim_len = self.append(ir.GetElem(shape, ir.Constant(dim_idx, self._size_type)))
|
dim_len = self.append(ir.GetAttr(shape, dim_idx))
|
||||||
indices.append(self.append(ir.Arith(ast.Mod(loc=None), mod_idx, dim_len)))
|
indices.append(
|
||||||
mod_idx = self.append(ir.Arith(ast.FloorDiv(loc=None), mod_idx, dim_len))
|
self.append(ir.Arith(ast.Mod(loc=None), mod_idx, dim_len)))
|
||||||
|
mod_idx = self.append(
|
||||||
|
ir.Arith(ast.FloorDiv(loc=None), mod_idx, dim_len))
|
||||||
indices.append(mod_idx)
|
indices.append(mod_idx)
|
||||||
|
|
||||||
elt = arg
|
elt = arg
|
||||||
|
@ -1723,9 +1711,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||||
elt = self.iterable_get(elt, idx)
|
elt = self.iterable_get(elt, idx)
|
||||||
self.append(ir.SetElem(buffer, index, elt))
|
self.append(ir.SetElem(buffer, index, elt))
|
||||||
return self.append(
|
return self.append(
|
||||||
ir.Arith(ast.Add(loc=None), index, ir.Constant(1, length.type)))
|
ir.Arith(ast.Add(loc=None), index,
|
||||||
|
ir.Constant(1, self._size_type)))
|
||||||
|
|
||||||
self._make_loop(
|
self._make_loop(
|
||||||
ir.Constant(0, length.type), lambda index: self.append(
|
ir.Constant(0, self._size_type), lambda index: self.append(
|
||||||
ir.Compare(ast.Lt(loc=None), index, num_total_elts)), body_gen)
|
ir.Compare(ast.Lt(loc=None), index, num_total_elts)), body_gen)
|
||||||
|
|
||||||
return self.append(ir.Alloc([shape, buffer], node.type))
|
return self.append(ir.Alloc([shape, buffer], node.type))
|
||||||
|
|
|
@ -8,18 +8,28 @@ from .. import asttyped, types, builtins
|
||||||
from .typedtree_printer import TypedtreePrinter
|
from .typedtree_printer import TypedtreePrinter
|
||||||
|
|
||||||
|
|
||||||
def is_rectangular_2d_list(node):
|
def match_rectangular_list(elts):
|
||||||
if not isinstance(node, asttyped.ListT):
|
|
||||||
return False
|
|
||||||
num_elts = None
|
num_elts = None
|
||||||
for e in node.elts:
|
elt_type = None
|
||||||
|
all_child_elts = []
|
||||||
|
|
||||||
|
for e in elts:
|
||||||
|
if elt_type is None:
|
||||||
|
elt_type = e.type.find()
|
||||||
if not isinstance(e, asttyped.ListT):
|
if not isinstance(e, asttyped.ListT):
|
||||||
return False
|
return elt_type, 0
|
||||||
if num_elts is None:
|
if num_elts is None:
|
||||||
num_elts = len(e.elts)
|
num_elts = len(e.elts)
|
||||||
elif num_elts != len(e.elts):
|
elif num_elts != len(e.elts):
|
||||||
return False
|
return elt_type, 0
|
||||||
return True
|
all_child_elts += e.elts
|
||||||
|
|
||||||
|
if not all_child_elts:
|
||||||
|
# This ultimately turned out to be a list (of list, of ...) of empty lists.
|
||||||
|
return elt_type["elt"], 1
|
||||||
|
|
||||||
|
elt, num_dims = match_rectangular_list(all_child_elts)
|
||||||
|
return elt, num_dims + 1
|
||||||
|
|
||||||
|
|
||||||
class Inferencer(algorithm.Visitor):
|
class Inferencer(algorithm.Visitor):
|
||||||
|
@ -710,29 +720,45 @@ class Inferencer(algorithm.Visitor):
|
||||||
"strings currently cannot be constructed", {},
|
"strings currently cannot be constructed", {},
|
||||||
node.loc)
|
node.loc)
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
elif types.is_builtin(typ, "list") or types.is_builtin(typ, "array"):
|
elif types.is_builtin(typ, "array"):
|
||||||
if types.is_builtin(typ, "list"):
|
valid_forms = lambda: [
|
||||||
valid_forms = lambda: [
|
valid_form("array(x:'a) -> array(elt='b) where 'a is iterable")
|
||||||
valid_form("list() -> list(elt='a)"),
|
]
|
||||||
valid_form("list(x:'a) -> list(elt='b) where 'a is iterable")
|
|
||||||
]
|
|
||||||
|
|
||||||
self._unify(node.type, builtins.TList(),
|
if len(node.args) == 1 and len(node.keywords) == 0:
|
||||||
node.loc, None)
|
arg, = node.args
|
||||||
elif types.is_builtin(typ, "array"):
|
|
||||||
valid_forms = lambda: [
|
|
||||||
valid_form("array(x:'a) -> array(elt='b) where 'a is iterable")
|
|
||||||
]
|
|
||||||
|
|
||||||
self._unify(node.type, builtins.TArray(),
|
if builtins.is_iterable(arg.type):
|
||||||
node.loc, None)
|
# KLUDGE: Support multidimensional arary creation if lexically
|
||||||
|
# specified as a rectangular array of lists.
|
||||||
|
elt, num_dims = match_rectangular_list([arg])
|
||||||
|
self._unify(node.type,
|
||||||
|
builtins.TArray(elt, types.TValue(num_dims)),
|
||||||
|
node.loc, arg.loc)
|
||||||
|
elif types.is_var(arg.type):
|
||||||
|
pass # undetermined yet
|
||||||
|
else:
|
||||||
|
note = diagnostic.Diagnostic("note",
|
||||||
|
"this expression has type {type}",
|
||||||
|
{"type": types.TypePrinter().name(arg.type)},
|
||||||
|
arg.loc)
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"the argument of {builtin}() must be of an iterable type",
|
||||||
|
{"builtin": typ.find().name},
|
||||||
|
node.func.loc, notes=[note])
|
||||||
|
self.engine.process(diag)
|
||||||
else:
|
else:
|
||||||
assert False
|
diagnose(valid_forms())
|
||||||
|
elif types.is_builtin(typ, "list"):
|
||||||
|
valid_forms = lambda: [
|
||||||
|
valid_form("list() -> list(elt='a)"),
|
||||||
|
valid_form("list(x:'a) -> list(elt='b) where 'a is iterable")
|
||||||
|
]
|
||||||
|
|
||||||
if (types.is_builtin(typ, "list") and len(node.args) == 0 and
|
self._unify(node.type, builtins.TList(), node.loc, None)
|
||||||
len(node.keywords) == 0):
|
|
||||||
# Mimic numpy and don't allow array() (but []).
|
if len(node.args) == 0 and len(node.keywords) == 0:
|
||||||
pass
|
pass # []
|
||||||
elif len(node.args) == 1 and len(node.keywords) == 0:
|
elif len(node.args) == 1 and len(node.keywords) == 0:
|
||||||
arg, = node.args
|
arg, = node.args
|
||||||
|
|
||||||
|
@ -748,14 +774,8 @@ class Inferencer(algorithm.Visitor):
|
||||||
{"typeb": printer.name(typeb)},
|
{"typeb": printer.name(typeb)},
|
||||||
locb)
|
locb)
|
||||||
]
|
]
|
||||||
elt = arg.type.find().params["elt"]
|
|
||||||
if types.is_builtin(typ, "array") and builtins.is_listish(elt):
|
|
||||||
# KLUDGE: Support 2D arary creation if lexically specified
|
|
||||||
# as a rectangular array of lists.
|
|
||||||
if is_rectangular_2d_list(arg):
|
|
||||||
elt = elt.find().params["elt"]
|
|
||||||
self._unify(node.type.find().params["elt"],
|
self._unify(node.type.find().params["elt"],
|
||||||
elt,
|
arg.type.find().params["elt"],
|
||||||
node.loc, arg.loc, makenotes=makenotes)
|
node.loc, arg.loc, makenotes=makenotes)
|
||||||
elif types.is_var(arg.type):
|
elif types.is_var(arg.type):
|
||||||
pass # undetermined yet
|
pass # undetermined yet
|
||||||
|
|
|
@ -1173,7 +1173,7 @@ class LLVMIRGenerator:
|
||||||
if builtins.is_array(collection.type):
|
if builtins.is_array(collection.type):
|
||||||
# Return length of outermost dimension.
|
# Return length of outermost dimension.
|
||||||
shape = self.llbuilder.extract_value(self.map(collection), 0)
|
shape = self.llbuilder.extract_value(self.map(collection), 0)
|
||||||
return self.llbuilder.load(self.llbuilder.extract_value(shape, 0))
|
return self.llbuilder.extract_value(shape, 0)
|
||||||
return self.llbuilder.extract_value(self.map(collection), 1)
|
return self.llbuilder.extract_value(self.map(collection), 1)
|
||||||
elif insn.op in ("printf", "rtio_log"):
|
elif insn.op in ("printf", "rtio_log"):
|
||||||
# We only get integers, floats, pointers and strings here.
|
# We only get integers, floats, pointers and strings here.
|
||||||
|
|
|
@ -50,3 +50,9 @@ class ConstnessValidator(algorithm.Visitor):
|
||||||
node.loc)
|
node.loc)
|
||||||
self.engine.process(diag)
|
self.engine.process(diag)
|
||||||
return
|
return
|
||||||
|
if builtins.is_array(typ):
|
||||||
|
diag = diagnostic.Diagnostic("error",
|
||||||
|
"array attributes cannot be assigned to",
|
||||||
|
{}, node.loc)
|
||||||
|
self.engine.process(diag)
|
||||||
|
return
|
||||||
|
|
|
@ -3,3 +3,7 @@
|
||||||
|
|
||||||
# CHECK-L: ${LINE:+1}: error: array cannot be invoked with the arguments ()
|
# CHECK-L: ${LINE:+1}: error: array cannot be invoked with the arguments ()
|
||||||
a = array()
|
a = array()
|
||||||
|
|
||||||
|
b = array([1, 2, 3])
|
||||||
|
# CHECK-L: ${LINE:+1}: error: array attributes cannot be assigned to
|
||||||
|
b.shape = (5, )
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
ary = array([1, 2, 3])
|
ary = array([1, 2, 3])
|
||||||
assert len(ary) == 3
|
assert len(ary) == 3
|
||||||
assert ary.shape == [3]
|
assert ary.shape == (3,)
|
||||||
# FIXME: Implement ndarray indexing
|
# FIXME: Implement ndarray indexing
|
||||||
# assert [x*x for x in ary] == [1, 4, 9]
|
# assert [x*x for x in ary] == [1, 4, 9]
|
||||||
|
|
||||||
|
@ -11,8 +11,12 @@ assert ary.shape == [3]
|
||||||
empty_array = array([1])
|
empty_array = array([1])
|
||||||
empty_array = array([])
|
empty_array = array([])
|
||||||
assert len(empty_array) == 0
|
assert len(empty_array) == 0
|
||||||
assert empty_array.shape == [0]
|
assert empty_array.shape == (0,)
|
||||||
|
|
||||||
matrix = array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
matrix = array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||||
assert len(matrix) == 2
|
assert len(matrix) == 2
|
||||||
assert matrix.shape == [2, 3]
|
assert matrix.shape == (2, 3)
|
||||||
|
|
||||||
|
three_tensor = array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
|
||||||
|
assert len(three_tensor) == 1
|
||||||
|
assert three_tensor.shape == (1, 2, 3)
|
||||||
|
|
Loading…
Reference in New Issue