From 933ea53c775538efebbb2f6105766a88d2671ce6 Mon Sep 17 00:00:00 2001 From: whitequark Date: Wed, 6 Jul 2016 09:51:57 +0000 Subject: [PATCH] compiler: add basic numpy array support (#424). --- artiq/compiler/builtins.py | 28 +++++++++++++++-- artiq/compiler/embedding.py | 12 ++++++++ artiq/compiler/prelude.py | 1 + .../compiler/transforms/artiq_ir_generator.py | 28 ++++++++++------- artiq/compiler/transforms/inferencer.py | 30 +++++++++++++------ .../compiler/transforms/llvm_ir_generator.py | 13 ++++---- artiq/compiler/types.py | 2 +- artiq/coredevice/comm_generic.py | 3 ++ artiq/runtime/session.c | 8 +++-- artiq/test/coredevice/test_embedding.py | 3 ++ artiq/test/lit/inferencer/unify.py | 3 ++ artiq/test/lit/integration/array.py | 5 ++++ artiq/test/lit/integration/print.py | 3 ++ 13 files changed, 109 insertions(+), 30 deletions(-) create mode 100644 artiq/test/lit/integration/array.py diff --git a/artiq/compiler/builtins.py b/artiq/compiler/builtins.py index 77703ed3c..eae7f9337 100644 --- a/artiq/compiler/builtins.py +++ b/artiq/compiler/builtins.py @@ -44,12 +44,12 @@ def TInt32(): def TInt64(): return TInt(types.TValue(64)) -def _int_printer(typ, depth, max_depth): +def _int_printer(typ, printer, depth, max_depth): if types.is_var(typ["width"]): return "numpy.int?" else: return "numpy.int{}".format(types.get_value(typ.find()["width"])) -types.TypePrinter.custom_printers['int'] = _int_printer +types.TypePrinter.custom_printers["int"] = _int_printer class TFloat(types.TMono): def __init__(self): @@ -73,6 +73,16 @@ class TList(types.TMono): elt = types.TVar() super().__init__("list", {"elt": elt}) +class TArray(types.TMono): + def __init__(self, elt=None): + if elt is None: + elt = types.TVar() + super().__init__("array", {"elt": elt}) + +def _array_printer(typ, printer, depth, max_depth): + return "numpy.array(elt={})".format(printer.name(typ["elt"], depth, max_depth)) +types.TypePrinter.custom_printers["array"] = _array_printer + class TRange(types.TMono): def __init__(self, elt=None): if elt is None: @@ -124,6 +134,9 @@ def fn_str(): def fn_list(): return types.TConstructor(TList()) +def fn_array(): + return types.TConstructor(TArray()) + def fn_Exception(): return types.TExceptionConstructor(TException("Exception")) @@ -231,6 +244,15 @@ def is_list(typ, elt=None): else: return types.is_mono(typ, "list") +def is_array(typ, elt=None): + if elt is not None: + return types.is_mono(typ, "array", elt=elt) + else: + return types.is_mono(typ, "array") + +def is_listish(typ, elt=None): + return is_list(typ, elt) or is_array(typ, elt) + def is_range(typ, elt=None): if elt is not None: return types.is_mono(typ, "range", {"elt": elt}) @@ -247,7 +269,7 @@ def is_exception(typ, name=None): def is_iterable(typ): typ = typ.find() return isinstance(typ, types.TMono) and \ - typ.name in ('list', 'range') + typ.name in ('list', 'array', 'range') def get_iterable_elt(typ): if is_iterable(typ): diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 0bf9effd1..8941ea51f 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -187,6 +187,18 @@ class ASTSynthesizer: return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(), begin_loc=begin_loc, end_loc=end_loc, loc=begin_loc.join(end_loc)) + elif isinstance(value, numpy.ndarray): + begin_loc = self._add("numpy.array([") + elts = [] + for index, elt in enumerate(value): + elts.append(self.quote(elt)) + if index < len(value) - 1: + self._add(", ") + end_loc = self._add("])") + + return asttyped.ListT(elts=elts, ctx=None, type=builtins.TArray(), + begin_loc=begin_loc, end_loc=end_loc, + loc=begin_loc.join(end_loc)) elif inspect.isfunction(value) or inspect.ismethod(value) or \ isinstance(value, pytypes.BuiltinFunctionType) or \ isinstance(value, SpecializedFunction): diff --git a/artiq/compiler/prelude.py b/artiq/compiler/prelude.py index c8dfb3561..8c46176f9 100644 --- a/artiq/compiler/prelude.py +++ b/artiq/compiler/prelude.py @@ -12,6 +12,7 @@ def globals(): "int": builtins.fn_int(), "float": builtins.fn_float(), "list": builtins.fn_list(), + "array": builtins.fn_array(), "range": builtins.fn_range(), # Exception constructors diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index 0bea60020..549bbef21 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -477,7 +477,7 @@ class ARTIQIRGenerator(algorithm.Visitor): self.continue_target = old_continue def iterable_len(self, value, typ=_size_type): - if builtins.is_list(value.type): + if builtins.is_listish(value.type): return self.append(ir.Builtin("len", [value], typ, name="{}.len".format(value.name))) elif builtins.is_range(value.type): @@ -492,7 +492,7 @@ class ARTIQIRGenerator(algorithm.Visitor): def iterable_get(self, value, index): # Assuming the value is within bounds. - if builtins.is_list(value.type): + if builtins.is_listish(value.type): return self.append(ir.GetElem(value, index)) elif builtins.is_range(value.type): start = self.append(ir.GetAttr(value, "start")) @@ -1322,7 +1322,7 @@ class ARTIQIRGenerator(algorithm.Visitor): for index, elt in enumerate(node.right.type.elts): elts.append(self.append(ir.GetAttr(rhs, index))) return self.append(ir.Alloc(elts, node.type)) - elif builtins.is_list(node.left.type) and builtins.is_list(node.right.type): + elif builtins.is_listish(node.left.type) and builtins.is_listish(node.right.type): lhs_length = self.iterable_len(lhs) rhs_length = self.iterable_len(rhs) @@ -1355,9 +1355,9 @@ class ARTIQIRGenerator(algorithm.Visitor): assert False elif isinstance(node.op, ast.Mult): # list * int, int * list lhs, rhs = self.visit(node.left), self.visit(node.right) - if builtins.is_list(lhs.type) and builtins.is_int(rhs.type): + if builtins.is_listish(lhs.type) and builtins.is_int(rhs.type): lst, num = lhs, rhs - elif builtins.is_int(lhs.type) and builtins.is_list(rhs.type): + elif builtins.is_int(lhs.type) and builtins.is_listish(rhs.type): lst, num = rhs, lhs else: assert False @@ -1412,7 +1412,7 @@ class ARTIQIRGenerator(algorithm.Visitor): result = self.append(ir.Select(result, elt_result, ir.Constant(False, builtins.TBool()))) return result - elif builtins.is_list(lhs.type) and builtins.is_list(rhs.type): + elif builtins.is_listish(lhs.type) and builtins.is_listish(rhs.type): head = self.current_block lhs_length = self.iterable_len(lhs) rhs_length = self.iterable_len(rhs) @@ -1606,7 +1606,7 @@ class ARTIQIRGenerator(algorithm.Visitor): return self.append(ir.Coerce(arg, node.type)) else: assert False - elif types.is_builtin(typ, "list"): + elif types.is_builtin(typ, "list") or types.is_builtin(typ, "array"): if len(node.args) == 0 and len(node.keywords) == 0: length = ir.Constant(0, builtins.TInt32()) return self.append(ir.Alloc([length], node.type)) @@ -1968,8 +1968,13 @@ class ARTIQIRGenerator(algorithm.Visitor): else: format_string += "%s" args.append(value) - elif builtins.is_list(value.type): - format_string += "["; flush() + elif builtins.is_listish(value.type): + if builtins.is_list(value.type): + format_string += "["; flush() + elif builtins.is_array(value.type): + format_string += "array(["; flush() + else: + assert False length = self.iterable_len(value) last = self.append(ir.Arith(ast.Sub(loc=None), length, ir.Constant(1, length.type))) @@ -1992,7 +1997,10 @@ class ARTIQIRGenerator(algorithm.Visitor): lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, length)), body_gen) - format_string += "]" + if builtins.is_list(value.type): + format_string += "]" + elif builtins.is_array(value.type): + format_string += "])" elif builtins.is_range(value.type): format_string += "range("; flush() diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index f10552cbf..b93e1528b 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -671,14 +671,25 @@ class Inferencer(algorithm.Visitor): pass else: diagnose(valid_forms()) - elif types.is_builtin(typ, "list"): - valid_forms = lambda: [ - valid_form("list() -> list(elt='a)"), - valid_form("list(x:'a) -> list(elt='b) where 'a is iterable") - ] + elif types.is_builtin(typ, "list") or types.is_builtin(typ, "array"): + if types.is_builtin(typ, "list"): + valid_forms = lambda: [ + valid_form("list() -> list(elt='a)"), + valid_form("list(x:'a) -> list(elt='b) where 'a is iterable") + ] - self._unify(node.type, builtins.TList(), - node.loc, None) + self._unify(node.type, builtins.TList(), + node.loc, None) + elif types.is_builtin(typ, "array"): + valid_forms = lambda: [ + valid_form("array() -> array(elt='a)"), + valid_form("array(x:'a) -> array(elt='b) where 'a is iterable") + ] + + self._unify(node.type, builtins.TArray(), + node.loc, None) + else: + assert False if len(node.args) == 0 and len(node.keywords) == 0: pass # [] @@ -708,7 +719,8 @@ class Inferencer(algorithm.Visitor): {"type": types.TypePrinter().name(arg.type)}, arg.loc) diag = diagnostic.Diagnostic("error", - "the argument of list() must be of an iterable type", {}, + "the argument of {builtin}() must be of an iterable type", + {"builtin": typ.find().name}, node.func.loc, notes=[note]) self.engine.process(diag) else: @@ -743,7 +755,7 @@ class Inferencer(algorithm.Visitor): if builtins.is_range(arg.type): self._unify(node.type, builtins.get_iterable_elt(arg.type), node.loc, None) - elif builtins.is_list(arg.type): + elif builtins.is_listish(arg.type): # TODO: should be ssize_t-sized self._unify(node.type, builtins.TInt32(), node.loc, None) diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 7e8a67e78..ebcd9fc06 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -218,7 +218,7 @@ class LLVMIRGenerator: return lldouble elif builtins.is_str(typ) or ir.is_exn_typeinfo(typ): return llptr - elif builtins.is_list(typ): + elif builtins.is_listish(typ): lleltty = self.llty_of_type(builtins.get_iterable_elt(typ)) return ll.LiteralStructType([lli32, lleltty.as_pointer()]) elif builtins.is_range(typ): @@ -610,7 +610,7 @@ class LLVMIRGenerator: name=insn.name) else: assert False - elif builtins.is_list(insn.type): + elif builtins.is_listish(insn.type): llsize = self.map(insn.operands[0]) llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined) llvalue = self.llbuilder.insert_value(llvalue, llsize, 0) @@ -1162,6 +1162,9 @@ class LLVMIRGenerator: elif builtins.is_list(typ): return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ), error_handler) + elif builtins.is_array(typ): + return b"a" + self._rpc_tag(builtins.get_iterable_elt(typ), + error_handler) elif builtins.is_range(typ): return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ), error_handler) @@ -1405,13 +1408,13 @@ class LLVMIRGenerator: elif builtins.is_str(typ): assert isinstance(value, (str, bytes)) return self.llstr_of_str(value) - elif builtins.is_list(typ): - assert isinstance(value, list) + elif builtins.is_listish(typ): + assert isinstance(value, (list, numpy.ndarray)) elt_type = builtins.get_iterable_elt(typ) llelts = [self._quote(value[i], elt_type, lambda: path() + [str(i)]) for i in range(len(value))] lleltsary = ll.Constant(ll.ArrayType(self.llty_of_type(elt_type), len(llelts)), - llelts) + list(llelts)) llglobal = ll.GlobalVariable(self.llmodule, lleltsary.type, self.llmodule.scope.deduplicate("quoted.list")) diff --git a/artiq/compiler/types.py b/artiq/compiler/types.py index 04ba46f75..e1781116e 100644 --- a/artiq/compiler/types.py +++ b/artiq/compiler/types.py @@ -697,7 +697,7 @@ class TypePrinter(object): return "".format(typ.name) elif isinstance(typ, TMono): if typ.name in self.custom_printers: - return self.custom_printers[typ.name](typ, depth + 1, max_depth) + return self.custom_printers[typ.name](typ, self, depth + 1, max_depth) elif typ.params == {}: return typ.name else: diff --git a/artiq/coredevice/comm_generic.py b/artiq/coredevice/comm_generic.py index 09a5500db..24bccbd40 100644 --- a/artiq/coredevice/comm_generic.py +++ b/artiq/coredevice/comm_generic.py @@ -331,6 +331,9 @@ class CommGeneric: elif tag == "l": length = self._read_int32() return [self._receive_rpc_value(embedding_map) for _ in range(length)] + elif tag == "a": + length = self._read_int32() + return numpy.array([self._receive_rpc_value(embedding_map) for _ in range(length)]) elif tag == "r": start = self._receive_rpc_value(embedding_map) stop = self._receive_rpc_value(embedding_map) diff --git a/artiq/runtime/session.c b/artiq/runtime/session.c index 392566358..c2158d048 100644 --- a/artiq/runtime/session.c +++ b/artiq/runtime/session.c @@ -607,6 +607,7 @@ static void skip_rpc_value(const char **tag) { } case 'l': + case 'a': skip_rpc_value(tag); break; @@ -650,6 +651,7 @@ static int sizeof_rpc_value(const char **tag) return sizeof(char *); case 'l': // list(elt='a) + case 'a': // array(elt='a) skip_rpc_value(tag); return sizeof(struct { int32_t length; struct {} *elements; }); @@ -733,7 +735,8 @@ static int receive_rpc_value(const char **tag, void **slot) break; } - case 'l': { // list(elt='a) + case 'l': // list(elt='a) + case 'a': { // array(elt='a) struct { int32_t length; struct {} *elements; } *list = *slot; list->length = in_packet_int32(); @@ -824,7 +827,8 @@ static int send_rpc_value(const char **tag, void **value) return out_packet_string(*((*(const char***)value)++)); } - case 'l': { // list(elt='a) + case 'l': // list(elt='a) + case 'a': { // array(elt='a) struct { uint32_t length; struct {} *elements; } *list = *value; void *element = list->elements; diff --git a/artiq/test/coredevice/test_embedding.py b/artiq/test/coredevice/test_embedding.py index 8ae19fd47..583b2ca0d 100644 --- a/artiq/test/coredevice/test_embedding.py +++ b/artiq/test/coredevice/test_embedding.py @@ -41,6 +41,9 @@ class RoundtripTest(ExperimentCase): def test_list(self): self.assertRoundtrip([10]) + def test_array(self): + self.assertRoundtrip(numpy.array([10])) + def test_object(self): obj = object() self.assertRoundtrip(obj) diff --git a/artiq/test/lit/inferencer/unify.py b/artiq/test/lit/inferencer/unify.py index a0dadaf1e..bb7fd9241 100644 --- a/artiq/test/lit/inferencer/unify.py +++ b/artiq/test/lit/inferencer/unify.py @@ -57,6 +57,9 @@ lambda x, y=1: x k = "x" # CHECK-L: k:str +l = array([1]) +# CHECK-L: l:numpy.array(elt=numpy.int?) + IndexError() # CHECK-L: IndexError:():IndexError diff --git a/artiq/test/lit/integration/array.py b/artiq/test/lit/integration/array.py new file mode 100644 index 000000000..aec874497 --- /dev/null +++ b/artiq/test/lit/integration/array.py @@ -0,0 +1,5 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# REQUIRES: exceptions + +ary = array([1, 2, 3]) +assert [x*x for x in ary] == [1, 4, 9] diff --git a/artiq/test/lit/integration/print.py b/artiq/test/lit/integration/print.py index 06653d43b..eb7fda126 100644 --- a/artiq/test/lit/integration/print.py +++ b/artiq/test/lit/integration/print.py @@ -30,3 +30,6 @@ print([[1, 2], [3]]) # CHECK-L: range(0, 10, 1) print(range(10)) + +# CHECK-L: array([1, 2]) +print(array([1, 2]))