1
0
forked from M-Labs/artiq

compiler: Fix quoting of multi-dimensional arrays

GitHub: Fixes m-labs/artiq#1523.
This commit is contained in:
David Nadlinger 2020-10-20 01:33:14 +02:00
parent d161fd5d84
commit d5f90f6c9f
2 changed files with 58 additions and 12 deletions

View File

@ -1503,6 +1503,17 @@ class LLVMIRGenerator:
return llcall
def _quote_listish_to_llglobal(self, value, elt_type, path, kind_name):
llelts = [self._quote(value[i], elt_type, lambda: path() + [str(i)])
for i in range(len(value))]
lleltsary = ll.Constant(ll.ArrayType(self.llty_of_type(elt_type), len(llelts)),
list(llelts))
name = self.llmodule.scope.deduplicate("quoted.{}".format(kind_name))
llglobal = ll.GlobalVariable(self.llmodule, lleltsary.type, name)
llglobal.initializer = lleltsary
llglobal.linkage = "private"
return llglobal.bitcast(lleltsary.type.element.as_pointer())
def _quote(self, value, typ, path):
value_id = id(value)
if value_id in self.llobject_map:
@ -1579,21 +1590,19 @@ class LLVMIRGenerator:
llstr = self.llstr_of_str(as_bytes)
llconst = ll.Constant(llty, [llstr, ll.Constant(lli32, len(as_bytes))])
return llconst
elif builtins.is_array(typ):
assert isinstance(value, numpy.ndarray), fail_msg
typ = typ.find()
assert len(value.shape) == typ["num_dims"].find().value
flattened = value.reshape((-1,))
lleltsptr = self._quote_listish_to_llglobal(flattened, typ["elt"], path, "array")
llshape = ll.Constant.literal_struct([ll.Constant(lli32, s) for s in value.shape])
return ll.Constant(llty, [lleltsptr, llshape])
elif builtins.is_listish(typ):
assert isinstance(value, (list, numpy.ndarray)), fail_msg
elt_type = builtins.get_iterable_elt(typ)
llelts = [self._quote(value[i], elt_type, lambda: path() + [str(i)])
for i in range(len(value))]
lleltsary = ll.Constant(ll.ArrayType(self.llty_of_type(elt_type), len(llelts)),
list(llelts))
name = self.llmodule.scope.deduplicate("quoted.{}".format(typ.name))
llglobal = ll.GlobalVariable(self.llmodule, lleltsary.type, name)
llglobal.initializer = lleltsary
llglobal.linkage = "private"
lleltsptr = llglobal.bitcast(lleltsary.type.element.as_pointer())
llconst = ll.Constant(llty, [lleltsptr, ll.Constant(lli32, len(llelts))])
lleltsptr = self._quote_listish_to_llglobal(value, elt_type, path, typ.find().name)
llconst = ll.Constant(llty, [lleltsptr, ll.Constant(lli32, len(value))])
return llconst
elif types.is_tuple(typ):
assert isinstance(value, tuple), fail_msg

View File

@ -423,3 +423,40 @@ class ListTupleTest(ExperimentCase):
def test_empty_list(self):
self.create(_EmptyList).run()
class _ArrayQuoting(EnvExperiment):
def build(self):
self.setattr_device("core")
self.vec_i32 = np.array([0, 1], dtype=np.int32)
self.mat_i64 = np.array([[0, 1], [2, 3]], dtype=np.int64)
self.arr_f64 = np.array([[[0.0, 1.0], [2.0, 3.0]],
[[4.0, 5.0], [6.0, 7.0]]])
self.strs = np.array(["foo", "bar"])
@kernel
def run(self):
assert self.vec_i32[0] == 0
assert self.vec_i32[1] == 1
assert self.mat_i64[0, 0] == 0
assert self.mat_i64[0, 1] == 1
assert self.mat_i64[1, 0] == 2
assert self.mat_i64[1, 1] == 3
assert self.arr_f64[0, 0, 0] == 0.0
assert self.arr_f64[0, 0, 1] == 1.0
assert self.arr_f64[0, 1, 0] == 2.0
assert self.arr_f64[0, 1, 1] == 3.0
assert self.arr_f64[1, 0, 0] == 4.0
assert self.arr_f64[1, 0, 1] == 5.0
assert self.arr_f64[1, 1, 0] == 6.0
assert self.arr_f64[1, 1, 1] == 7.0
assert self.strs[0] == "foo"
assert self.strs[1] == "bar"
class ArrayQuotingTest(ExperimentCase):
def test_quoting(self):
self.create(_ArrayQuoting).run()