forked from M-Labs/artiq
1
0
Fork 0

Fixup 4359a437 (tuples of lists), add regression tests

This commit is contained in:
David Nadlinger 2018-07-10 01:18:51 +01:00
parent edc314524c
commit 768b970deb
3 changed files with 42 additions and 4 deletions

View File

@ -311,11 +311,9 @@ def is_collection(typ):
types.is_mono(typ, "list")
def is_allocated(typ):
if types.is_tuple(typ):
return any(is_allocated(e.find()) for e in typ.elts)
return not (is_none(typ) or is_bool(typ) or is_int(typ) or
is_float(typ) or is_range(typ) or
types._is_pointer(typ) or types.is_function(typ) or
types.is_c_function(typ) or types.is_rpc(typ) or
types.is_method(typ) or
types.is_method(typ) or types.is_tuple(typ) or
types.is_value(typ))

View File

@ -1384,7 +1384,7 @@ class LLVMIRGenerator:
self.llbuilder.position_at_end(lltail)
llret = self.llbuilder.load(llslot, name="rpc.ret")
if not builtins.is_allocated(fun_type.ret):
if not fun_type.ret.fold(False, lambda r, t: r or builtins.is_allocated(t)):
# We didn't allocate anything except the slot for the value itself.
# Don't waste stack space.
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])

View File

@ -58,6 +58,9 @@ class RoundtripTest(ExperimentCase):
def test_object_list(self):
self.assertRoundtrip([object(), object()])
def test_list_tuple(self):
self.assertRoundtrip(([1, 2], [3, 4]))
class _DefaultArg(EnvExperiment):
def build(self):
@ -296,3 +299,40 @@ class LargePayloadTest(ExperimentCase):
def test_1MB(self):
exp = self.create(_Payload1MB)
exp.run()
class _ListTuple(EnvExperiment):
def build(self):
self.setattr_device("core")
@kernel
def run(self):
# Make sure lifetime for the array data in tuples of lists is managed
# correctly. This is written in a somewhat convoluted fashion to provoke
# memory corruption even in the face of compiler optimizations.
for _ in range(self.get_num_iters()):
a, b = self.get_values(0, 1, 32)
c, d = self.get_values(2, 3, 64)
self.verify(a)
self.verify(c)
self.verify(b)
self.verify(d)
@kernel
def verify(self, data):
for i in range(len(data)):
if data[i] != data[0] + i:
raise ValueError
def get_num_iters(self) -> TInt32:
return 2
def get_values(self, base_a, base_b, n) -> TTuple([TList(TInt32), TList(TInt32)]):
return [numpy.int32(base_a + i) for i in range(n)], \
[numpy.int32(base_b + i) for i in range(n)]
class ListTupleTest(ExperimentCase):
def test_list_tuple(self):
exp = self.create(_ListTuple)
exp.run()