diff --git a/artiq/devices/ctlmgr.py b/artiq/devices/ctlmgr.py new file mode 100644 index 000000000..2c50c25df --- /dev/null +++ b/artiq/devices/ctlmgr.py @@ -0,0 +1,254 @@ +import asyncio +import logging +import subprocess +import shlex +import socket + +from artiq.protocols.sync_struct import Subscriber +from artiq.protocols.pc_rpc import AsyncioClient +from artiq.protocols.logging import parse_log_message, log_with_name +from artiq.tools import Condition, TaskObject + + +logger = logging.getLogger(__name__) + + +class Controller: + def __init__(self, name, ddb_entry): + self.name = name + self.command = ddb_entry["command"] + self.retry_timer = ddb_entry.get("retry_timer", 5) + self.retry_timer_backoff = ddb_entry.get("retry_timer_backoff", 1.1) + + self.host = ddb_entry["host"] + self.port = ddb_entry["port"] + self.ping_timer = ddb_entry.get("ping_timer", 30) + self.ping_timeout = ddb_entry.get("ping_timeout", 30) + self.term_timeout = ddb_entry.get("term_timeout", 30) + + self.retry_timer_cur = self.retry_timer + self.retry_now = Condition() + self.process = None + self.launch_task = asyncio.Task(self.launcher()) + + async def end(self): + self.launch_task.cancel() + await asyncio.wait_for(self.launch_task, None) + + async def _call_controller(self, method): + remote = AsyncioClient() + await remote.connect_rpc(self.host, self.port, None) + try: + targets, _ = remote.get_rpc_id() + remote.select_rpc_target(targets[0]) + r = await getattr(remote, method)() + finally: + remote.close_rpc() + return r + + async def _ping(self): + try: + ok = await asyncio.wait_for(self._call_controller("ping"), + self.ping_timeout) + if ok: + self.retry_timer_cur = self.retry_timer + return ok + except: + return False + + async def _wait_and_ping(self): + while True: + try: + await asyncio.wait_for(self.process.wait(), + self.ping_timer) + except asyncio.TimeoutError: + logger.debug("pinging controller %s", self.name) + ok = await self._ping() + if not ok: + logger.warning("Controller %s ping failed", self.name) + await self._terminate() + return + else: + break + + async def forward_logs(self, stream): + source = "controller({})".format(self.name) + while True: + try: + entry = (await stream.readline()) + if not entry: + break + entry = entry[:-1] + level, name, message = parse_log_message(entry.decode()) + log_with_name(name, level, message, extra={"source": source}) + except: + logger.debug("exception in log forwarding", exc_info=True) + break + logger.debug("stopped log forwarding of stream %s of %s", + stream, self.name) + + async def launcher(self): + try: + while True: + logger.info("Starting controller %s with command: %s", + self.name, self.command) + try: + self.process = await asyncio.create_subprocess_exec( + *shlex.split(self.command), + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + asyncio.ensure_future(self.forward_logs( + self.process.stdout)) + asyncio.ensure_future(self.forward_logs( + self.process.stderr)) + await self._wait_and_ping() + except FileNotFoundError: + logger.warning("Controller %s failed to start", self.name) + else: + logger.warning("Controller %s exited", self.name) + logger.warning("Restarting in %.1f seconds", + self.retry_timer_cur) + try: + await asyncio.wait_for(self.retry_now.wait(), + self.retry_timer_cur) + except asyncio.TimeoutError: + pass + self.retry_timer_cur *= self.retry_timer_backoff + except asyncio.CancelledError: + await self._terminate() + + async def _terminate(self): + logger.info("Terminating controller %s", self.name) + if self.process is not None and self.process.returncode is None: + try: + await asyncio.wait_for(self._call_controller("terminate"), + self.term_timeout) + except: + logger.warning("Controller %s did not respond to terminate " + "command, killing", self.name) + try: + self.process.kill() + except ProcessLookupError: + pass + try: + await asyncio.wait_for(self.process.wait(), + self.term_timeout) + except: + logger.warning("Controller %s failed to exit, killing", + self.name) + try: + self.process.kill() + except ProcessLookupError: + pass + await self.process.wait() + logger.debug("Controller %s terminated", self.name) + + +def get_ip_addresses(host): + try: + addrinfo = socket.getaddrinfo(host, None) + except: + return set() + return {info[4][0] for info in addrinfo} + + +class Controllers: + def __init__(self): + self.host_filter = None + self.active_or_queued = set() + self.queue = asyncio.Queue() + self.active = dict() + self.process_task = asyncio.Task(self._process()) + + async def _process(self): + while True: + action, param = await self.queue.get() + if action == "set": + k, ddb_entry = param + if k in self.active: + await self.active[k].end() + self.active[k] = Controller(k, ddb_entry) + elif action == "del": + await self.active[param].end() + del self.active[param] + else: + raise ValueError + + def __setitem__(self, k, v): + if (isinstance(v, dict) and v["type"] == "controller" and + self.host_filter in get_ip_addresses(v["host"])): + v["command"] = v["command"].format(name=k, + bind=self.host_filter, + port=v["port"]) + self.queue.put_nowait(("set", (k, v))) + self.active_or_queued.add(k) + + def __delitem__(self, k): + if k in self.active_or_queued: + self.queue.put_nowait(("del", k)) + self.active_or_queued.remove(k) + + def delete_all(self): + for name in set(self.active_or_queued): + del self[name] + + async def shutdown(self): + self.process_task.cancel() + for c in self.active.values(): + await c.end() + + +class ControllerDB: + def __init__(self): + self.current_controllers = Controllers() + + def set_host_filter(self, host_filter): + self.current_controllers.host_filter = host_filter + + def sync_struct_init(self, init): + if self.current_controllers is not None: + self.current_controllers.delete_all() + for k, v in init.items(): + self.current_controllers[k] = v + return self.current_controllers + + +class ControllerManager(TaskObject): + def __init__(self, server, port, retry_master): + self.server = server + self.port = port + self.retry_master = retry_master + self.controller_db = ControllerDB() + + async def _do(self): + try: + subscriber = Subscriber("devices", + self.controller_db.sync_struct_init) + while True: + try: + def set_host_filter(): + s = subscriber.writer.get_extra_info("socket") + localhost = s.getsockname()[0] + self.controller_db.set_host_filter(localhost) + await subscriber.connect(self.server, self.port, + set_host_filter) + try: + await asyncio.wait_for(subscriber.receive_task, None) + finally: + await subscriber.close() + except (ConnectionAbortedError, ConnectionError, + ConnectionRefusedError, ConnectionResetError) as e: + logger.warning("Connection to master failed (%s: %s)", + e.__class__.__name__, str(e)) + else: + logger.warning("Connection to master lost") + logger.warning("Retrying in %.1f seconds", self.retry_master) + await asyncio.sleep(self.retry_master) + except asyncio.CancelledError: + pass + finally: + await self.controller_db.current_controllers.shutdown() + + def retry_now(self, k): + """If a controller is disabled and pending retry, perform that retry + now.""" + self.controller_db.current_controllers.active[k].retry_now.notify() diff --git a/artiq/frontend/artiq_ctlmgr.py b/artiq/frontend/artiq_ctlmgr.py index d6e7a299c..d3091dfdd 100755 --- a/artiq/frontend/artiq_ctlmgr.py +++ b/artiq/frontend/artiq_ctlmgr.py @@ -5,249 +5,13 @@ import atexit import argparse import os import logging -import subprocess -import shlex -import socket import platform -from artiq.protocols.sync_struct import Subscriber -from artiq.protocols.pc_rpc import AsyncioClient, Server -from artiq.protocols.logging import (LogForwarder, LogParser, - SourceFilter) -from artiq.tools import * - - -logger = logging.getLogger(__name__) - - -class Controller: - def __init__(self, name, ddb_entry): - self.name = name - self.command = ddb_entry["command"] - self.retry_timer = ddb_entry.get("retry_timer", 5) - self.retry_timer_backoff = ddb_entry.get("retry_timer_backoff", 1.1) - - self.host = ddb_entry["host"] - self.port = ddb_entry["port"] - self.ping_timer = ddb_entry.get("ping_timer", 30) - self.ping_timeout = ddb_entry.get("ping_timeout", 30) - self.term_timeout = ddb_entry.get("term_timeout", 30) - - self.retry_timer_cur = self.retry_timer - self.retry_now = Condition() - self.process = None - self.launch_task = asyncio.Task(self.launcher()) - - async def end(self): - self.launch_task.cancel() - await asyncio.wait_for(self.launch_task, None) - - async def _call_controller(self, method): - remote = AsyncioClient() - await remote.connect_rpc(self.host, self.port, None) - try: - targets, _ = remote.get_rpc_id() - remote.select_rpc_target(targets[0]) - r = await getattr(remote, method)() - finally: - remote.close_rpc() - return r - - async def _ping(self): - try: - ok = await asyncio.wait_for(self._call_controller("ping"), - self.ping_timeout) - if ok: - self.retry_timer_cur = self.retry_timer - return ok - except: - return False - - async def _wait_and_ping(self): - while True: - try: - await asyncio.wait_for(self.process.wait(), - self.ping_timer) - except asyncio.TimeoutError: - logger.debug("pinging controller %s", self.name) - ok = await self._ping() - if not ok: - logger.warning("Controller %s ping failed", self.name) - await self._terminate() - return - else: - break - - def _get_log_source(self): - return "controller({})".format(self.name) - - async def launcher(self): - try: - while True: - logger.info("Starting controller %s with command: %s", - self.name, self.command) - try: - self.process = await asyncio.create_subprocess_exec( - *shlex.split(self.command), - stdout=subprocess.PIPE, stderr=subprocess.PIPE) - asyncio.ensure_future( - LogParser(self._get_log_source).stream_task( - self.process.stdout)) - asyncio.ensure_future( - LogParser(self._get_log_source).stream_task( - self.process.stderr)) - await self._wait_and_ping() - except FileNotFoundError: - logger.warning("Controller %s failed to start", self.name) - else: - logger.warning("Controller %s exited", self.name) - logger.warning("Restarting in %.1f seconds", - self.retry_timer_cur) - try: - await asyncio.wait_for(self.retry_now.wait(), - self.retry_timer_cur) - except asyncio.TimeoutError: - pass - self.retry_timer_cur *= self.retry_timer_backoff - except asyncio.CancelledError: - await self._terminate() - - async def _terminate(self): - logger.info("Terminating controller %s", self.name) - if self.process is not None and self.process.returncode is None: - try: - await asyncio.wait_for(self._call_controller("terminate"), - self.term_timeout) - except: - logger.warning("Controller %s did not respond to terminate " - "command, killing", self.name) - try: - self.process.kill() - except ProcessLookupError: - pass - try: - await asyncio.wait_for(self.process.wait(), - self.term_timeout) - except: - logger.warning("Controller %s failed to exit, killing", - self.name) - try: - self.process.kill() - except ProcessLookupError: - pass - await self.process.wait() - logger.debug("Controller %s terminated", self.name) - - -def get_ip_addresses(host): - try: - addrinfo = socket.getaddrinfo(host, None) - except: - return set() - return {info[4][0] for info in addrinfo} - - -class Controllers: - def __init__(self): - self.host_filter = None - self.active_or_queued = set() - self.queue = asyncio.Queue() - self.active = dict() - self.process_task = asyncio.Task(self._process()) - - async def _process(self): - while True: - action, param = await self.queue.get() - if action == "set": - k, ddb_entry = param - if k in self.active: - await self.active[k].end() - self.active[k] = Controller(k, ddb_entry) - elif action == "del": - await self.active[param].end() - del self.active[param] - else: - raise ValueError - - def __setitem__(self, k, v): - if (isinstance(v, dict) and v["type"] == "controller" - and self.host_filter in get_ip_addresses(v["host"])): - v["command"] = v["command"].format(name=k, - bind=self.host_filter, - port=v["port"]) - self.queue.put_nowait(("set", (k, v))) - self.active_or_queued.add(k) - - def __delitem__(self, k): - if k in self.active_or_queued: - self.queue.put_nowait(("del", k)) - self.active_or_queued.remove(k) - - def delete_all(self): - for name in set(self.active_or_queued): - del self[name] - - async def shutdown(self): - self.process_task.cancel() - for c in self.active.values(): - await c.end() - - -class ControllerDB: - def __init__(self): - self.current_controllers = Controllers() - - def set_host_filter(self, host_filter): - self.current_controllers.host_filter = host_filter - - def sync_struct_init(self, init): - if self.current_controllers is not None: - self.current_controllers.delete_all() - for k, v in init.items(): - self.current_controllers[k] = v - return self.current_controllers - - -class ControllerManager(TaskObject): - def __init__(self, server, port, retry_master): - self.server = server - self.port = port - self.retry_master = retry_master - self.controller_db = ControllerDB() - - async def _do(self): - try: - subscriber = Subscriber("devices", - self.controller_db.sync_struct_init) - while True: - try: - def set_host_filter(): - s = subscriber.writer.get_extra_info("socket") - localhost = s.getsockname()[0] - self.controller_db.set_host_filter(localhost) - await subscriber.connect(self.server, self.port, - set_host_filter) - try: - await asyncio.wait_for(subscriber.receive_task, None) - finally: - await subscriber.close() - except (ConnectionAbortedError, ConnectionError, - ConnectionRefusedError, ConnectionResetError) as e: - logger.warning("Connection to master failed (%s: %s)", - e.__class__.__name__, str(e)) - else: - logger.warning("Connection to master lost") - logger.warning("Retrying in %.1f seconds", self.retry_master) - await asyncio.sleep(self.retry_master) - except asyncio.CancelledError: - pass - finally: - await self.controller_db.current_controllers.shutdown() - - def retry_now(self, k): - """If a controller is disabled and pending retry, perform that retry - now.""" - self.controller_db.current_controllers.active[k].retry_now.notify() +from artiq.protocols.pc_rpc import Server +from artiq.protocols.logging import LogForwarder, SourceFilter +from artiq.tools import (simple_network_args, atexit_register_coroutine, + bind_address_from_args) +from artiq.devices.ctlmgr import ControllerManager def get_argparser(): @@ -280,7 +44,8 @@ def main(): root_logger = logging.getLogger() root_logger.setLevel(logging.NOTSET) - source_adder = SourceFilter(logging.WARNING + args.quiet*10 - args.verbose*10, + source_adder = SourceFilter(logging.WARNING + + args.quiet*10 - args.verbose*10, "ctlmgr({})".format(platform.node())) console_handler = logging.StreamHandler() console_handler.setFormatter(logging.Formatter(