diff --git a/artiq/protocols/broadcast.py b/artiq/protocols/broadcast.py index 9cffc6b78..a517ebb86 100644 --- a/artiq/protocols/broadcast.py +++ b/artiq/protocols/broadcast.py @@ -9,11 +9,12 @@ _init_string = b"ARTIQ broadcast\n" class Receiver: - def __init__(self, name, notify_cb): + def __init__(self, name, notify_cb, disconnect_cb=None): self.name = name if not isinstance(notify_cb, list): notify_cb = [notify_cb] self.notify_cbs = notify_cb + self.disconnect_cb = disconnect_cb async def connect(self, host, port): self.reader, self.writer = \ @@ -29,6 +30,7 @@ class Receiver: raise async def close(self): + self.disconnect_cb = None try: self.receive_task.cancel() try: @@ -41,15 +43,19 @@ class Receiver: del self.writer async def _receive_cr(self): - target = None - while True: - line = await self.reader.readline() - if not line: - return - obj = pyon.decode(line.decode()) + try: + target = None + while True: + line = await self.reader.readline() + if not line: + return + obj = pyon.decode(line.decode()) - for notify_cb in self.notify_cbs: - notify_cb(obj) + for notify_cb in self.notify_cbs: + notify_cb(obj) + finally: + if self.disconnect_cb is not None: + self.disconnect_cb() class Broadcaster(AsyncioServer): diff --git a/artiq/protocols/sync_struct.py b/artiq/protocols/sync_struct.py index 59bced790..b6dfdac92 100644 --- a/artiq/protocols/sync_struct.py +++ b/artiq/protocols/sync_struct.py @@ -52,8 +52,10 @@ class Subscriber: from the publisher. The mod is passed as parameter. The function is called after the mod has been processed. A list of functions may also be used, and they will be called in turn. + :param disconnect_cb: An optional function called when disconnection happens + from external causes (i.e. not when ``close`` is called). """ - def __init__(self, notifier_name, target_builder, notify_cb=None): + def __init__(self, notifier_name, target_builder, notify_cb=None, disconnect_cb=None): self.notifier_name = notifier_name self.target_builder = target_builder if notify_cb is None: @@ -61,6 +63,7 @@ class Subscriber: if not isinstance(notify_cb, list): notify_cb = [notify_cb] self.notify_cbs = notify_cb + self.disconnect_cb = disconnect_cb async def connect(self, host, port, before_receive_cb=None): self.reader, self.writer = \ @@ -78,6 +81,7 @@ class Subscriber: raise async def close(self): + self.disconnect_cb = None try: self.receive_task.cancel() try: @@ -90,20 +94,24 @@ class Subscriber: del self.writer async def _receive_cr(self): - target = None - while True: - line = await self.reader.readline() - if not line: - return - mod = pyon.decode(line.decode()) + try: + target = None + while True: + line = await self.reader.readline() + if not line: + return + mod = pyon.decode(line.decode()) - if mod["action"] == "init": - target = self.target_builder(mod["struct"]) - else: - process_mod(target, mod) + if mod["action"] == "init": + target = self.target_builder(mod["struct"]) + else: + process_mod(target, mod) - for notify_cb in self.notify_cbs: - notify_cb(mod) + for notify_cb in self.notify_cbs: + notify_cb(mod) + finally: + if self.disconnect_cb is not None: + self.disconnect_cb() class Notifier: