From d5f90f6c9faa1d52e5a1999db1a9eaf0b812d38b Mon Sep 17 00:00:00 2001 From: David Nadlinger Date: Tue, 20 Oct 2020 01:33:14 +0200 Subject: [PATCH] compiler: Fix quoting of multi-dimensional arrays GitHub: Fixes m-labs/artiq#1523. --- .../compiler/transforms/llvm_ir_generator.py | 33 +++++++++++------ artiq/test/coredevice/test_embedding.py | 37 +++++++++++++++++++ 2 files changed, 58 insertions(+), 12 deletions(-) diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 31384e578..003de3c1a 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -1503,6 +1503,17 @@ class LLVMIRGenerator: return llcall + def _quote_listish_to_llglobal(self, value, elt_type, path, kind_name): + llelts = [self._quote(value[i], elt_type, lambda: path() + [str(i)]) + for i in range(len(value))] + lleltsary = ll.Constant(ll.ArrayType(self.llty_of_type(elt_type), len(llelts)), + list(llelts)) + name = self.llmodule.scope.deduplicate("quoted.{}".format(kind_name)) + llglobal = ll.GlobalVariable(self.llmodule, lleltsary.type, name) + llglobal.initializer = lleltsary + llglobal.linkage = "private" + return llglobal.bitcast(lleltsary.type.element.as_pointer()) + def _quote(self, value, typ, path): value_id = id(value) if value_id in self.llobject_map: @@ -1579,21 +1590,19 @@ class LLVMIRGenerator: llstr = self.llstr_of_str(as_bytes) llconst = ll.Constant(llty, [llstr, ll.Constant(lli32, len(as_bytes))]) return llconst + elif builtins.is_array(typ): + assert isinstance(value, numpy.ndarray), fail_msg + typ = typ.find() + assert len(value.shape) == typ["num_dims"].find().value + flattened = value.reshape((-1,)) + lleltsptr = self._quote_listish_to_llglobal(flattened, typ["elt"], path, "array") + llshape = ll.Constant.literal_struct([ll.Constant(lli32, s) for s in value.shape]) + return ll.Constant(llty, [lleltsptr, llshape]) elif builtins.is_listish(typ): assert isinstance(value, (list, numpy.ndarray)), fail_msg elt_type = builtins.get_iterable_elt(typ) - llelts = [self._quote(value[i], elt_type, lambda: path() + [str(i)]) - for i in range(len(value))] - lleltsary = ll.Constant(ll.ArrayType(self.llty_of_type(elt_type), len(llelts)), - list(llelts)) - - name = self.llmodule.scope.deduplicate("quoted.{}".format(typ.name)) - llglobal = ll.GlobalVariable(self.llmodule, lleltsary.type, name) - llglobal.initializer = lleltsary - llglobal.linkage = "private" - - lleltsptr = llglobal.bitcast(lleltsary.type.element.as_pointer()) - llconst = ll.Constant(llty, [lleltsptr, ll.Constant(lli32, len(llelts))]) + lleltsptr = self._quote_listish_to_llglobal(value, elt_type, path, typ.find().name) + llconst = ll.Constant(llty, [lleltsptr, ll.Constant(lli32, len(value))]) return llconst elif types.is_tuple(typ): assert isinstance(value, tuple), fail_msg diff --git a/artiq/test/coredevice/test_embedding.py b/artiq/test/coredevice/test_embedding.py index c32ba2c55..d944198c9 100644 --- a/artiq/test/coredevice/test_embedding.py +++ b/artiq/test/coredevice/test_embedding.py @@ -423,3 +423,40 @@ class ListTupleTest(ExperimentCase): def test_empty_list(self): self.create(_EmptyList).run() + + +class _ArrayQuoting(EnvExperiment): + def build(self): + self.setattr_device("core") + self.vec_i32 = np.array([0, 1], dtype=np.int32) + self.mat_i64 = np.array([[0, 1], [2, 3]], dtype=np.int64) + self.arr_f64 = np.array([[[0.0, 1.0], [2.0, 3.0]], + [[4.0, 5.0], [6.0, 7.0]]]) + self.strs = np.array(["foo", "bar"]) + + @kernel + def run(self): + assert self.vec_i32[0] == 0 + assert self.vec_i32[1] == 1 + + assert self.mat_i64[0, 0] == 0 + assert self.mat_i64[0, 1] == 1 + assert self.mat_i64[1, 0] == 2 + assert self.mat_i64[1, 1] == 3 + + assert self.arr_f64[0, 0, 0] == 0.0 + assert self.arr_f64[0, 0, 1] == 1.0 + assert self.arr_f64[0, 1, 0] == 2.0 + assert self.arr_f64[0, 1, 1] == 3.0 + assert self.arr_f64[1, 0, 0] == 4.0 + assert self.arr_f64[1, 0, 1] == 5.0 + assert self.arr_f64[1, 1, 0] == 6.0 + assert self.arr_f64[1, 1, 1] == 7.0 + + assert self.strs[0] == "foo" + assert self.strs[1] == "bar" + + +class ArrayQuotingTest(ExperimentCase): + def test_quoting(self): + self.create(_ArrayQuoting).run()