diff --git a/artiq/test/sync_struct.py b/artiq/test/sync_struct.py index 00e4af878..ece90360b 100644 --- a/artiq/test/sync_struct.py +++ b/artiq/test/sync_struct.py @@ -29,18 +29,9 @@ def write_test_data(test_dict): test_dict["finished"] = True -async def start_server(publisher_future, test_dict_future): - test_dict = sync_struct.Notifier(dict()) - publisher = sync_struct.Publisher( - {"test": test_dict}) - await publisher.start(test_address, test_port) - publisher_future.set_result(publisher) - test_dict_future.set_result(test_dict) - - class SyncStructCase(unittest.TestCase): def init_test_dict(self, init): - self.test_dict = init + self.received_dict = init return init def notify(self, mod): @@ -52,29 +43,27 @@ class SyncStructCase(unittest.TestCase): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) - def test_recv(self): - loop = self.loop + async def _do_test_recv(self): self.receiving_done = asyncio.Event() - publisher = asyncio.Future() - test_dict = asyncio.Future() - asyncio.ensure_future(start_server(publisher, test_dict)) - loop.run_until_complete(publisher) - loop.run_until_complete(test_dict) - self.publisher = publisher.result() - test_dict = test_dict.result() - test_vector = dict() - write_test_data(test_vector) + test_dict = sync_struct.Notifier(dict()) + publisher = sync_struct.Publisher({"test": test_dict}) + await publisher.start(test_address, test_port) + + subscriber = sync_struct.Subscriber("test", self.init_test_dict, + self.notify) + await subscriber.connect(test_address, test_port) write_test_data(test_dict) - self.subscriber = sync_struct.Subscriber("test", self.init_test_dict, - self.notify) - loop.run_until_complete(self.subscriber.connect(test_address, - test_port)) - loop.run_until_complete(self.receiving_done.wait()) - self.assertEqual(self.test_dict, test_vector) - self.loop.run_until_complete(self.subscriber.close()) - self.loop.run_until_complete(self.publisher.stop()) + await self.receiving_done.wait() + + await subscriber.close() + await publisher.stop() + + self.assertEqual(self.received_dict, test_dict.read) + + def test_recv(self): + self.loop.run_until_complete(self._do_test_recv()) def tearDown(self): self.loop.close()