mirror of
https://github.com/m-labs/artiq.git
synced 2024-12-25 03:08:27 +08:00
compiler: add basic numpy array support (#424).
This commit is contained in:
parent
906db876a6
commit
933ea53c77
@ -44,12 +44,12 @@ def TInt32():
|
||||
def TInt64():
|
||||
return TInt(types.TValue(64))
|
||||
|
||||
def _int_printer(typ, depth, max_depth):
|
||||
def _int_printer(typ, printer, depth, max_depth):
|
||||
if types.is_var(typ["width"]):
|
||||
return "numpy.int?"
|
||||
else:
|
||||
return "numpy.int{}".format(types.get_value(typ.find()["width"]))
|
||||
types.TypePrinter.custom_printers['int'] = _int_printer
|
||||
types.TypePrinter.custom_printers["int"] = _int_printer
|
||||
|
||||
class TFloat(types.TMono):
|
||||
def __init__(self):
|
||||
@ -73,6 +73,16 @@ class TList(types.TMono):
|
||||
elt = types.TVar()
|
||||
super().__init__("list", {"elt": elt})
|
||||
|
||||
class TArray(types.TMono):
|
||||
def __init__(self, elt=None):
|
||||
if elt is None:
|
||||
elt = types.TVar()
|
||||
super().__init__("array", {"elt": elt})
|
||||
|
||||
def _array_printer(typ, printer, depth, max_depth):
|
||||
return "numpy.array(elt={})".format(printer.name(typ["elt"], depth, max_depth))
|
||||
types.TypePrinter.custom_printers["array"] = _array_printer
|
||||
|
||||
class TRange(types.TMono):
|
||||
def __init__(self, elt=None):
|
||||
if elt is None:
|
||||
@ -124,6 +134,9 @@ def fn_str():
|
||||
def fn_list():
|
||||
return types.TConstructor(TList())
|
||||
|
||||
def fn_array():
|
||||
return types.TConstructor(TArray())
|
||||
|
||||
def fn_Exception():
|
||||
return types.TExceptionConstructor(TException("Exception"))
|
||||
|
||||
@ -231,6 +244,15 @@ def is_list(typ, elt=None):
|
||||
else:
|
||||
return types.is_mono(typ, "list")
|
||||
|
||||
def is_array(typ, elt=None):
|
||||
if elt is not None:
|
||||
return types.is_mono(typ, "array", elt=elt)
|
||||
else:
|
||||
return types.is_mono(typ, "array")
|
||||
|
||||
def is_listish(typ, elt=None):
|
||||
return is_list(typ, elt) or is_array(typ, elt)
|
||||
|
||||
def is_range(typ, elt=None):
|
||||
if elt is not None:
|
||||
return types.is_mono(typ, "range", {"elt": elt})
|
||||
@ -247,7 +269,7 @@ def is_exception(typ, name=None):
|
||||
def is_iterable(typ):
|
||||
typ = typ.find()
|
||||
return isinstance(typ, types.TMono) and \
|
||||
typ.name in ('list', 'range')
|
||||
typ.name in ('list', 'array', 'range')
|
||||
|
||||
def get_iterable_elt(typ):
|
||||
if is_iterable(typ):
|
||||
|
@ -187,6 +187,18 @@ class ASTSynthesizer:
|
||||
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(),
|
||||
begin_loc=begin_loc, end_loc=end_loc,
|
||||
loc=begin_loc.join(end_loc))
|
||||
elif isinstance(value, numpy.ndarray):
|
||||
begin_loc = self._add("numpy.array([")
|
||||
elts = []
|
||||
for index, elt in enumerate(value):
|
||||
elts.append(self.quote(elt))
|
||||
if index < len(value) - 1:
|
||||
self._add(", ")
|
||||
end_loc = self._add("])")
|
||||
|
||||
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TArray(),
|
||||
begin_loc=begin_loc, end_loc=end_loc,
|
||||
loc=begin_loc.join(end_loc))
|
||||
elif inspect.isfunction(value) or inspect.ismethod(value) or \
|
||||
isinstance(value, pytypes.BuiltinFunctionType) or \
|
||||
isinstance(value, SpecializedFunction):
|
||||
|
@ -12,6 +12,7 @@ def globals():
|
||||
"int": builtins.fn_int(),
|
||||
"float": builtins.fn_float(),
|
||||
"list": builtins.fn_list(),
|
||||
"array": builtins.fn_array(),
|
||||
"range": builtins.fn_range(),
|
||||
|
||||
# Exception constructors
|
||||
|
@ -477,7 +477,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
self.continue_target = old_continue
|
||||
|
||||
def iterable_len(self, value, typ=_size_type):
|
||||
if builtins.is_list(value.type):
|
||||
if builtins.is_listish(value.type):
|
||||
return self.append(ir.Builtin("len", [value], typ,
|
||||
name="{}.len".format(value.name)))
|
||||
elif builtins.is_range(value.type):
|
||||
@ -492,7 +492,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
|
||||
def iterable_get(self, value, index):
|
||||
# Assuming the value is within bounds.
|
||||
if builtins.is_list(value.type):
|
||||
if builtins.is_listish(value.type):
|
||||
return self.append(ir.GetElem(value, index))
|
||||
elif builtins.is_range(value.type):
|
||||
start = self.append(ir.GetAttr(value, "start"))
|
||||
@ -1322,7 +1322,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
for index, elt in enumerate(node.right.type.elts):
|
||||
elts.append(self.append(ir.GetAttr(rhs, index)))
|
||||
return self.append(ir.Alloc(elts, node.type))
|
||||
elif builtins.is_list(node.left.type) and builtins.is_list(node.right.type):
|
||||
elif builtins.is_listish(node.left.type) and builtins.is_listish(node.right.type):
|
||||
lhs_length = self.iterable_len(lhs)
|
||||
rhs_length = self.iterable_len(rhs)
|
||||
|
||||
@ -1355,9 +1355,9 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
assert False
|
||||
elif isinstance(node.op, ast.Mult): # list * int, int * list
|
||||
lhs, rhs = self.visit(node.left), self.visit(node.right)
|
||||
if builtins.is_list(lhs.type) and builtins.is_int(rhs.type):
|
||||
if builtins.is_listish(lhs.type) and builtins.is_int(rhs.type):
|
||||
lst, num = lhs, rhs
|
||||
elif builtins.is_int(lhs.type) and builtins.is_list(rhs.type):
|
||||
elif builtins.is_int(lhs.type) and builtins.is_listish(rhs.type):
|
||||
lst, num = rhs, lhs
|
||||
else:
|
||||
assert False
|
||||
@ -1412,7 +1412,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
result = self.append(ir.Select(result, elt_result,
|
||||
ir.Constant(False, builtins.TBool())))
|
||||
return result
|
||||
elif builtins.is_list(lhs.type) and builtins.is_list(rhs.type):
|
||||
elif builtins.is_listish(lhs.type) and builtins.is_listish(rhs.type):
|
||||
head = self.current_block
|
||||
lhs_length = self.iterable_len(lhs)
|
||||
rhs_length = self.iterable_len(rhs)
|
||||
@ -1606,7 +1606,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
return self.append(ir.Coerce(arg, node.type))
|
||||
else:
|
||||
assert False
|
||||
elif types.is_builtin(typ, "list"):
|
||||
elif types.is_builtin(typ, "list") or types.is_builtin(typ, "array"):
|
||||
if len(node.args) == 0 and len(node.keywords) == 0:
|
||||
length = ir.Constant(0, builtins.TInt32())
|
||||
return self.append(ir.Alloc([length], node.type))
|
||||
@ -1968,8 +1968,13 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
else:
|
||||
format_string += "%s"
|
||||
args.append(value)
|
||||
elif builtins.is_list(value.type):
|
||||
format_string += "["; flush()
|
||||
elif builtins.is_listish(value.type):
|
||||
if builtins.is_list(value.type):
|
||||
format_string += "["; flush()
|
||||
elif builtins.is_array(value.type):
|
||||
format_string += "array(["; flush()
|
||||
else:
|
||||
assert False
|
||||
|
||||
length = self.iterable_len(value)
|
||||
last = self.append(ir.Arith(ast.Sub(loc=None), length, ir.Constant(1, length.type)))
|
||||
@ -1992,7 +1997,10 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, length)),
|
||||
body_gen)
|
||||
|
||||
format_string += "]"
|
||||
if builtins.is_list(value.type):
|
||||
format_string += "]"
|
||||
elif builtins.is_array(value.type):
|
||||
format_string += "])"
|
||||
elif builtins.is_range(value.type):
|
||||
format_string += "range("; flush()
|
||||
|
||||
|
@ -671,14 +671,25 @@ class Inferencer(algorithm.Visitor):
|
||||
pass
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
elif types.is_builtin(typ, "list"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("list() -> list(elt='a)"),
|
||||
valid_form("list(x:'a) -> list(elt='b) where 'a is iterable")
|
||||
]
|
||||
elif types.is_builtin(typ, "list") or types.is_builtin(typ, "array"):
|
||||
if types.is_builtin(typ, "list"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("list() -> list(elt='a)"),
|
||||
valid_form("list(x:'a) -> list(elt='b) where 'a is iterable")
|
||||
]
|
||||
|
||||
self._unify(node.type, builtins.TList(),
|
||||
node.loc, None)
|
||||
self._unify(node.type, builtins.TList(),
|
||||
node.loc, None)
|
||||
elif types.is_builtin(typ, "array"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("array() -> array(elt='a)"),
|
||||
valid_form("array(x:'a) -> array(elt='b) where 'a is iterable")
|
||||
]
|
||||
|
||||
self._unify(node.type, builtins.TArray(),
|
||||
node.loc, None)
|
||||
else:
|
||||
assert False
|
||||
|
||||
if len(node.args) == 0 and len(node.keywords) == 0:
|
||||
pass # []
|
||||
@ -708,7 +719,8 @@ class Inferencer(algorithm.Visitor):
|
||||
{"type": types.TypePrinter().name(arg.type)},
|
||||
arg.loc)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"the argument of list() must be of an iterable type", {},
|
||||
"the argument of {builtin}() must be of an iterable type",
|
||||
{"builtin": typ.find().name},
|
||||
node.func.loc, notes=[note])
|
||||
self.engine.process(diag)
|
||||
else:
|
||||
@ -743,7 +755,7 @@ class Inferencer(algorithm.Visitor):
|
||||
if builtins.is_range(arg.type):
|
||||
self._unify(node.type, builtins.get_iterable_elt(arg.type),
|
||||
node.loc, None)
|
||||
elif builtins.is_list(arg.type):
|
||||
elif builtins.is_listish(arg.type):
|
||||
# TODO: should be ssize_t-sized
|
||||
self._unify(node.type, builtins.TInt32(),
|
||||
node.loc, None)
|
||||
|
@ -218,7 +218,7 @@ class LLVMIRGenerator:
|
||||
return lldouble
|
||||
elif builtins.is_str(typ) or ir.is_exn_typeinfo(typ):
|
||||
return llptr
|
||||
elif builtins.is_list(typ):
|
||||
elif builtins.is_listish(typ):
|
||||
lleltty = self.llty_of_type(builtins.get_iterable_elt(typ))
|
||||
return ll.LiteralStructType([lli32, lleltty.as_pointer()])
|
||||
elif builtins.is_range(typ):
|
||||
@ -610,7 +610,7 @@ class LLVMIRGenerator:
|
||||
name=insn.name)
|
||||
else:
|
||||
assert False
|
||||
elif builtins.is_list(insn.type):
|
||||
elif builtins.is_listish(insn.type):
|
||||
llsize = self.map(insn.operands[0])
|
||||
llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
|
||||
llvalue = self.llbuilder.insert_value(llvalue, llsize, 0)
|
||||
@ -1162,6 +1162,9 @@ class LLVMIRGenerator:
|
||||
elif builtins.is_list(typ):
|
||||
return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ),
|
||||
error_handler)
|
||||
elif builtins.is_array(typ):
|
||||
return b"a" + self._rpc_tag(builtins.get_iterable_elt(typ),
|
||||
error_handler)
|
||||
elif builtins.is_range(typ):
|
||||
return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ),
|
||||
error_handler)
|
||||
@ -1405,13 +1408,13 @@ class LLVMIRGenerator:
|
||||
elif builtins.is_str(typ):
|
||||
assert isinstance(value, (str, bytes))
|
||||
return self.llstr_of_str(value)
|
||||
elif builtins.is_list(typ):
|
||||
assert isinstance(value, list)
|
||||
elif builtins.is_listish(typ):
|
||||
assert isinstance(value, (list, numpy.ndarray))
|
||||
elt_type = builtins.get_iterable_elt(typ)
|
||||
llelts = [self._quote(value[i], elt_type, lambda: path() + [str(i)])
|
||||
for i in range(len(value))]
|
||||
lleltsary = ll.Constant(ll.ArrayType(self.llty_of_type(elt_type), len(llelts)),
|
||||
llelts)
|
||||
list(llelts))
|
||||
|
||||
llglobal = ll.GlobalVariable(self.llmodule, lleltsary.type,
|
||||
self.llmodule.scope.deduplicate("quoted.list"))
|
||||
|
@ -697,7 +697,7 @@ class TypePrinter(object):
|
||||
return "<instance {} {{}}>".format(typ.name)
|
||||
elif isinstance(typ, TMono):
|
||||
if typ.name in self.custom_printers:
|
||||
return self.custom_printers[typ.name](typ, depth + 1, max_depth)
|
||||
return self.custom_printers[typ.name](typ, self, depth + 1, max_depth)
|
||||
elif typ.params == {}:
|
||||
return typ.name
|
||||
else:
|
||||
|
@ -331,6 +331,9 @@ class CommGeneric:
|
||||
elif tag == "l":
|
||||
length = self._read_int32()
|
||||
return [self._receive_rpc_value(embedding_map) for _ in range(length)]
|
||||
elif tag == "a":
|
||||
length = self._read_int32()
|
||||
return numpy.array([self._receive_rpc_value(embedding_map) for _ in range(length)])
|
||||
elif tag == "r":
|
||||
start = self._receive_rpc_value(embedding_map)
|
||||
stop = self._receive_rpc_value(embedding_map)
|
||||
|
@ -607,6 +607,7 @@ static void skip_rpc_value(const char **tag) {
|
||||
}
|
||||
|
||||
case 'l':
|
||||
case 'a':
|
||||
skip_rpc_value(tag);
|
||||
break;
|
||||
|
||||
@ -650,6 +651,7 @@ static int sizeof_rpc_value(const char **tag)
|
||||
return sizeof(char *);
|
||||
|
||||
case 'l': // list(elt='a)
|
||||
case 'a': // array(elt='a)
|
||||
skip_rpc_value(tag);
|
||||
return sizeof(struct { int32_t length; struct {} *elements; });
|
||||
|
||||
@ -733,7 +735,8 @@ static int receive_rpc_value(const char **tag, void **slot)
|
||||
break;
|
||||
}
|
||||
|
||||
case 'l': { // list(elt='a)
|
||||
case 'l': // list(elt='a)
|
||||
case 'a': { // array(elt='a)
|
||||
struct { int32_t length; struct {} *elements; } *list = *slot;
|
||||
list->length = in_packet_int32();
|
||||
|
||||
@ -824,7 +827,8 @@ static int send_rpc_value(const char **tag, void **value)
|
||||
return out_packet_string(*((*(const char***)value)++));
|
||||
}
|
||||
|
||||
case 'l': { // list(elt='a)
|
||||
case 'l': // list(elt='a)
|
||||
case 'a': { // array(elt='a)
|
||||
struct { uint32_t length; struct {} *elements; } *list = *value;
|
||||
void *element = list->elements;
|
||||
|
||||
|
@ -41,6 +41,9 @@ class RoundtripTest(ExperimentCase):
|
||||
def test_list(self):
|
||||
self.assertRoundtrip([10])
|
||||
|
||||
def test_array(self):
|
||||
self.assertRoundtrip(numpy.array([10]))
|
||||
|
||||
def test_object(self):
|
||||
obj = object()
|
||||
self.assertRoundtrip(obj)
|
||||
|
@ -57,6 +57,9 @@ lambda x, y=1: x
|
||||
k = "x"
|
||||
# CHECK-L: k:str
|
||||
|
||||
l = array([1])
|
||||
# CHECK-L: l:numpy.array(elt=numpy.int?)
|
||||
|
||||
IndexError()
|
||||
# CHECK-L: IndexError:<constructor IndexError {}>():IndexError
|
||||
|
||||
|
5
artiq/test/lit/integration/array.py
Normal file
5
artiq/test/lit/integration/array.py
Normal file
@ -0,0 +1,5 @@
|
||||
# RUN: %python -m artiq.compiler.testbench.jit %s
|
||||
# REQUIRES: exceptions
|
||||
|
||||
ary = array([1, 2, 3])
|
||||
assert [x*x for x in ary] == [1, 4, 9]
|
@ -30,3 +30,6 @@ print([[1, 2], [3]])
|
||||
|
||||
# CHECK-L: range(0, 10, 1)
|
||||
print(range(10))
|
||||
|
||||
# CHECK-L: array([1, 2])
|
||||
print(array([1, 2]))
|
||||
|
Loading…
Reference in New Issue
Block a user