diff --git a/artiq/devices/ctlmgr.py b/artiq/devices/ctlmgr.py new file mode 100644 index 000000000..0a09c7cb8 --- /dev/null +++ b/artiq/devices/ctlmgr.py @@ -0,0 +1,255 @@ +import asyncio +import logging +import subprocess +import shlex +import socket + +from artiq.protocols.sync_struct import Subscriber +from artiq.protocols.pc_rpc import AsyncioClient +from artiq.protocols.logging import parse_log_message, log_with_name +from artiq.tools import Condition, TaskObject + + +logger = logging.getLogger(__name__) + + +class Controller: + def __init__(self, name, ddb_entry): + self.name = name + self.command = ddb_entry["command"] + self.retry_timer = ddb_entry.get("retry_timer", 5) + self.retry_timer_backoff = ddb_entry.get("retry_timer_backoff", 1.1) + + self.host = ddb_entry["host"] + self.port = ddb_entry["port"] + self.ping_timer = ddb_entry.get("ping_timer", 30) + self.ping_timeout = ddb_entry.get("ping_timeout", 30) + self.term_timeout = ddb_entry.get("term_timeout", 30) + + self.retry_timer_cur = self.retry_timer + self.retry_now = Condition() + self.process = None + self.launch_task = asyncio.Task(self.launcher()) + + async def end(self): + self.launch_task.cancel() + await asyncio.wait_for(self.launch_task, None) + + async def call(self, method, *args, **kwargs): + remote = AsyncioClient() + await remote.connect_rpc(self.host, self.port, None) + try: + targets, _ = remote.get_rpc_id() + remote.select_rpc_target(targets[0]) + r = await getattr(remote, method)(*args, **kwargs) + finally: + remote.close_rpc() + return r + + async def _ping(self): + try: + ok = await asyncio.wait_for(self.call("ping"), + self.ping_timeout) + if ok: + self.retry_timer_cur = self.retry_timer + return ok + except: + return False + + async def _wait_and_ping(self): + while True: + try: + await asyncio.wait_for(self.process.wait(), + self.ping_timer) + except asyncio.TimeoutError: + logger.debug("pinging controller %s", self.name) + ok = await self._ping() + if not ok: + logger.warning("Controller %s ping failed", self.name) + await self._terminate() + return + else: + break + + async def forward_logs(self, stream): + source = "controller({})".format(self.name) + while True: + try: + entry = (await stream.readline()) + if not entry: + break + entry = entry[:-1] + level, name, message = parse_log_message(entry.decode()) + log_with_name(name, level, message, extra={"source": source}) + except: + logger.debug("exception in log forwarding", exc_info=True) + break + logger.debug("stopped log forwarding of stream %s of %s", + stream, self.name) + + async def launcher(self): + try: + while True: + logger.info("Starting controller %s with command: %s", + self.name, self.command) + try: + self.process = await asyncio.create_subprocess_exec( + *shlex.split(self.command), + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + asyncio.ensure_future(self.forward_logs( + self.process.stdout)) + asyncio.ensure_future(self.forward_logs( + self.process.stderr)) + await self._wait_and_ping() + except FileNotFoundError: + logger.warning("Controller %s failed to start", self.name) + else: + logger.warning("Controller %s exited", self.name) + logger.warning("Restarting in %.1f seconds", + self.retry_timer_cur) + try: + await asyncio.wait_for(self.retry_now.wait(), + self.retry_timer_cur) + except asyncio.TimeoutError: + pass + self.retry_timer_cur *= self.retry_timer_backoff + except asyncio.CancelledError: + await self._terminate() + + async def _terminate(self): + logger.info("Terminating controller %s", self.name) + if self.process is not None and self.process.returncode is None: + try: + await asyncio.wait_for(self.call("terminate"), + self.term_timeout) + except: + logger.warning("Controller %s did not respond to terminate " + "command, killing", self.name) + try: + self.process.kill() + except ProcessLookupError: + pass + try: + await asyncio.wait_for(self.process.wait(), + self.term_timeout) + except: + logger.warning("Controller %s failed to exit, killing", + self.name) + try: + self.process.kill() + except ProcessLookupError: + pass + await self.process.wait() + logger.debug("Controller %s terminated", self.name) + + +def get_ip_addresses(host): + try: + addrinfo = socket.getaddrinfo(host, None) + except: + return set() + return {info[4][0] for info in addrinfo} + + +class Controllers: + def __init__(self): + self.host_filter = None + self.active_or_queued = set() + self.queue = asyncio.Queue() + self.active = dict() + self.process_task = asyncio.Task(self._process()) + + async def _process(self): + while True: + action, param = await self.queue.get() + if action == "set": + k, ddb_entry = param + if k in self.active: + await self.active[k].end() + self.active[k] = Controller(k, ddb_entry) + elif action == "del": + await self.active[param].end() + del self.active[param] + self.queue.task_done() + if action not in ("set", "del"): + raise ValueError + + def __setitem__(self, k, v): + if (isinstance(v, dict) and v["type"] == "controller" and + self.host_filter in get_ip_addresses(v["host"])): + v["command"] = v["command"].format(name=k, + bind=self.host_filter, + port=v["port"]) + self.queue.put_nowait(("set", (k, v))) + self.active_or_queued.add(k) + + def __delitem__(self, k): + if k in self.active_or_queued: + self.queue.put_nowait(("del", k)) + self.active_or_queued.remove(k) + + def delete_all(self): + for name in set(self.active_or_queued): + del self[name] + + async def shutdown(self): + self.process_task.cancel() + for c in self.active.values(): + await c.end() + + +class ControllerDB: + def __init__(self): + self.current_controllers = Controllers() + + def set_host_filter(self, host_filter): + self.current_controllers.host_filter = host_filter + + def sync_struct_init(self, init): + if self.current_controllers is not None: + self.current_controllers.delete_all() + for k, v in init.items(): + self.current_controllers[k] = v + return self.current_controllers + + +class ControllerManager(TaskObject): + def __init__(self, server, port, retry_master): + self.server = server + self.port = port + self.retry_master = retry_master + self.controller_db = ControllerDB() + + async def _do(self): + try: + subscriber = Subscriber("devices", + self.controller_db.sync_struct_init) + while True: + try: + def set_host_filter(): + s = subscriber.writer.get_extra_info("socket") + localhost = s.getsockname()[0] + self.controller_db.set_host_filter(localhost) + await subscriber.connect(self.server, self.port, + set_host_filter) + try: + await asyncio.wait_for(subscriber.receive_task, None) + finally: + await subscriber.close() + except (ConnectionAbortedError, ConnectionError, + ConnectionRefusedError, ConnectionResetError) as e: + logger.warning("Connection to master failed (%s: %s)", + e.__class__.__name__, str(e)) + else: + logger.warning("Connection to master lost") + logger.warning("Retrying in %.1f seconds", self.retry_master) + await asyncio.sleep(self.retry_master) + except asyncio.CancelledError: + pass + finally: + await self.controller_db.current_controllers.shutdown() + + def retry_now(self, k): + """If a controller is disabled and pending retry, perform that retry + now.""" + self.controller_db.current_controllers.active[k].retry_now.notify() diff --git a/artiq/frontend/artiq_ctlmgr.py b/artiq/frontend/artiq_ctlmgr.py index d6e7a299c..d3091dfdd 100755 --- a/artiq/frontend/artiq_ctlmgr.py +++ b/artiq/frontend/artiq_ctlmgr.py @@ -5,249 +5,13 @@ import atexit import argparse import os import logging -import subprocess -import shlex -import socket import platform -from artiq.protocols.sync_struct import Subscriber -from artiq.protocols.pc_rpc import AsyncioClient, Server -from artiq.protocols.logging import (LogForwarder, LogParser, - SourceFilter) -from artiq.tools import * - - -logger = logging.getLogger(__name__) - - -class Controller: - def __init__(self, name, ddb_entry): - self.name = name - self.command = ddb_entry["command"] - self.retry_timer = ddb_entry.get("retry_timer", 5) - self.retry_timer_backoff = ddb_entry.get("retry_timer_backoff", 1.1) - - self.host = ddb_entry["host"] - self.port = ddb_entry["port"] - self.ping_timer = ddb_entry.get("ping_timer", 30) - self.ping_timeout = ddb_entry.get("ping_timeout", 30) - self.term_timeout = ddb_entry.get("term_timeout", 30) - - self.retry_timer_cur = self.retry_timer - self.retry_now = Condition() - self.process = None - self.launch_task = asyncio.Task(self.launcher()) - - async def end(self): - self.launch_task.cancel() - await asyncio.wait_for(self.launch_task, None) - - async def _call_controller(self, method): - remote = AsyncioClient() - await remote.connect_rpc(self.host, self.port, None) - try: - targets, _ = remote.get_rpc_id() - remote.select_rpc_target(targets[0]) - r = await getattr(remote, method)() - finally: - remote.close_rpc() - return r - - async def _ping(self): - try: - ok = await asyncio.wait_for(self._call_controller("ping"), - self.ping_timeout) - if ok: - self.retry_timer_cur = self.retry_timer - return ok - except: - return False - - async def _wait_and_ping(self): - while True: - try: - await asyncio.wait_for(self.process.wait(), - self.ping_timer) - except asyncio.TimeoutError: - logger.debug("pinging controller %s", self.name) - ok = await self._ping() - if not ok: - logger.warning("Controller %s ping failed", self.name) - await self._terminate() - return - else: - break - - def _get_log_source(self): - return "controller({})".format(self.name) - - async def launcher(self): - try: - while True: - logger.info("Starting controller %s with command: %s", - self.name, self.command) - try: - self.process = await asyncio.create_subprocess_exec( - *shlex.split(self.command), - stdout=subprocess.PIPE, stderr=subprocess.PIPE) - asyncio.ensure_future( - LogParser(self._get_log_source).stream_task( - self.process.stdout)) - asyncio.ensure_future( - LogParser(self._get_log_source).stream_task( - self.process.stderr)) - await self._wait_and_ping() - except FileNotFoundError: - logger.warning("Controller %s failed to start", self.name) - else: - logger.warning("Controller %s exited", self.name) - logger.warning("Restarting in %.1f seconds", - self.retry_timer_cur) - try: - await asyncio.wait_for(self.retry_now.wait(), - self.retry_timer_cur) - except asyncio.TimeoutError: - pass - self.retry_timer_cur *= self.retry_timer_backoff - except asyncio.CancelledError: - await self._terminate() - - async def _terminate(self): - logger.info("Terminating controller %s", self.name) - if self.process is not None and self.process.returncode is None: - try: - await asyncio.wait_for(self._call_controller("terminate"), - self.term_timeout) - except: - logger.warning("Controller %s did not respond to terminate " - "command, killing", self.name) - try: - self.process.kill() - except ProcessLookupError: - pass - try: - await asyncio.wait_for(self.process.wait(), - self.term_timeout) - except: - logger.warning("Controller %s failed to exit, killing", - self.name) - try: - self.process.kill() - except ProcessLookupError: - pass - await self.process.wait() - logger.debug("Controller %s terminated", self.name) - - -def get_ip_addresses(host): - try: - addrinfo = socket.getaddrinfo(host, None) - except: - return set() - return {info[4][0] for info in addrinfo} - - -class Controllers: - def __init__(self): - self.host_filter = None - self.active_or_queued = set() - self.queue = asyncio.Queue() - self.active = dict() - self.process_task = asyncio.Task(self._process()) - - async def _process(self): - while True: - action, param = await self.queue.get() - if action == "set": - k, ddb_entry = param - if k in self.active: - await self.active[k].end() - self.active[k] = Controller(k, ddb_entry) - elif action == "del": - await self.active[param].end() - del self.active[param] - else: - raise ValueError - - def __setitem__(self, k, v): - if (isinstance(v, dict) and v["type"] == "controller" - and self.host_filter in get_ip_addresses(v["host"])): - v["command"] = v["command"].format(name=k, - bind=self.host_filter, - port=v["port"]) - self.queue.put_nowait(("set", (k, v))) - self.active_or_queued.add(k) - - def __delitem__(self, k): - if k in self.active_or_queued: - self.queue.put_nowait(("del", k)) - self.active_or_queued.remove(k) - - def delete_all(self): - for name in set(self.active_or_queued): - del self[name] - - async def shutdown(self): - self.process_task.cancel() - for c in self.active.values(): - await c.end() - - -class ControllerDB: - def __init__(self): - self.current_controllers = Controllers() - - def set_host_filter(self, host_filter): - self.current_controllers.host_filter = host_filter - - def sync_struct_init(self, init): - if self.current_controllers is not None: - self.current_controllers.delete_all() - for k, v in init.items(): - self.current_controllers[k] = v - return self.current_controllers - - -class ControllerManager(TaskObject): - def __init__(self, server, port, retry_master): - self.server = server - self.port = port - self.retry_master = retry_master - self.controller_db = ControllerDB() - - async def _do(self): - try: - subscriber = Subscriber("devices", - self.controller_db.sync_struct_init) - while True: - try: - def set_host_filter(): - s = subscriber.writer.get_extra_info("socket") - localhost = s.getsockname()[0] - self.controller_db.set_host_filter(localhost) - await subscriber.connect(self.server, self.port, - set_host_filter) - try: - await asyncio.wait_for(subscriber.receive_task, None) - finally: - await subscriber.close() - except (ConnectionAbortedError, ConnectionError, - ConnectionRefusedError, ConnectionResetError) as e: - logger.warning("Connection to master failed (%s: %s)", - e.__class__.__name__, str(e)) - else: - logger.warning("Connection to master lost") - logger.warning("Retrying in %.1f seconds", self.retry_master) - await asyncio.sleep(self.retry_master) - except asyncio.CancelledError: - pass - finally: - await self.controller_db.current_controllers.shutdown() - - def retry_now(self, k): - """If a controller is disabled and pending retry, perform that retry - now.""" - self.controller_db.current_controllers.active[k].retry_now.notify() +from artiq.protocols.pc_rpc import Server +from artiq.protocols.logging import LogForwarder, SourceFilter +from artiq.tools import (simple_network_args, atexit_register_coroutine, + bind_address_from_args) +from artiq.devices.ctlmgr import ControllerManager def get_argparser(): @@ -280,7 +44,8 @@ def main(): root_logger = logging.getLogger() root_logger.setLevel(logging.NOTSET) - source_adder = SourceFilter(logging.WARNING + args.quiet*10 - args.verbose*10, + source_adder = SourceFilter(logging.WARNING + + args.quiet*10 - args.verbose*10, "ctlmgr({})".format(platform.node())) console_handler = logging.StreamHandler() console_handler.setFormatter(logging.Formatter( diff --git a/artiq/test/ctlmgr.py b/artiq/test/ctlmgr.py new file mode 100644 index 000000000..f45f8ace7 --- /dev/null +++ b/artiq/test/ctlmgr.py @@ -0,0 +1,70 @@ +import os +import unittest +import logging +import asyncio + +from artiq.devices.ctlmgr import Controllers +from artiq.protocols.pc_rpc import AsyncioClient + +logger = logging.getLogger(__name__) + + +class ControllerCase(unittest.TestCase): + def setUp(self): + if os.name == "nt": + self.loop = asyncio.ProactorEventLoop() + else: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.addCleanup(self.loop.close) + + self.controllers = Controllers() + self.controllers.host_filter = "::1" + self.addCleanup( + lambda: self.loop.run_until_complete(self.controllers.shutdown())) + + async def start(self, name, entry): + self.controllers[name] = entry + await self.controllers.queue.join() + await self.wait_for_ping(entry["host"], entry["port"]) + + async def get_client(self, host, port): + remote = AsyncioClient() + await remote.connect_rpc(host, port, None) + targets, _ = remote.get_rpc_id() + remote.select_rpc_target(targets[0]) + self.addCleanup(remote.close_rpc) + return remote + + async def wait_for_ping(self, host, port, retries=5, timeout=2): + dt = timeout/retries + while timeout > 0: + try: + remote = await self.get_client(host, port) + ok = await asyncio.wait_for(remote.ping(), dt) + if not ok: + raise ValueError("unexcepted ping() response from " + "controller: `{}`".format(ok)) + return ok + except asyncio.TimeoutError: + timeout -= dt + except (ConnectionAbortedError, ConnectionError, + ConnectionRefusedError, ConnectionResetError): + await asyncio.sleep(dt) + timeout -= dt + raise asyncio.TimeoutError + + def test_start_ping_stop_controller(self): + entry = { + "type": "controller", + "host": "::1", + "port": 3253, + "command": "lda_controller -p {port} --bind {bind} " + "--no-localhost-bind --simulation", + } + async def test(): + await self.start("lda_sim", entry) + remote = await self.get_client(entry["host"], entry["port"]) + await remote.close() + + self.loop.run_until_complete(test()) diff --git a/artiq/test/hardware_testbench.py b/artiq/test/hardware_testbench.py index aa6c9d39e..d4a244c96 100644 --- a/artiq/test/hardware_testbench.py +++ b/artiq/test/hardware_testbench.py @@ -5,12 +5,13 @@ import os import sys import unittest import logging +import subprocess +import shlex +import time -from artiq.experiment import * from artiq.master.databases import DeviceDB, DatasetDB from artiq.master.worker_db import DeviceManager, DatasetManager from artiq.coredevice.core import CompileError -from artiq.protocols import pyon from artiq.frontend.artiq_run import DummyScheduler @@ -18,28 +19,49 @@ artiq_root = os.getenv("ARTIQ_ROOT") logger = logging.getLogger(__name__) -def get_from_ddb(*path, default="skip"): - if not artiq_root: - raise unittest.SkipTest("no ARTIQ_ROOT") - v = pyon.load_file(os.path.join(artiq_root, "device_db.pyon")) - try: - for p in path: - v = v[p] - return v.read - except KeyError: - if default == "skip": - raise unittest.SkipTest("device db path {} not found".format(path)) - else: - return default +@unittest.skipUnless(artiq_root, "no ARTIQ_ROOT") +class ControllerCase(unittest.TestCase): + def setUp(self): + self.device_db = DeviceDB(os.path.join(artiq_root, "device_db.pyon")) + self.device_mgr = DeviceManager(self.device_db) + self.controllers = {} + + def tearDown(self): + self.device_mgr.close_devices() + for name in list(self.controllers): + self.stop_controller(name) + + def start_controller(self, name, sleep=1): + try: + entry = self.device_db.get(name) + except KeyError: + raise unittest.SkipTest( + "controller `{}` not found".format(name)) + entry["command"] = entry["command"].format( + name=name, bind=entry["host"], port=entry["port"]) + proc = subprocess.Popen(shlex.split(entry["command"])) + self.controllers[name] = entry, proc + time.sleep(sleep) + + def stop_controller(self, name, default_timeout=1): + entry, proc = self.controllers[name] + t = entry.get("term_timeout", default_timeout) + try: + proc.wait(t) + except subprocess.TimeoutExpired: + proc.kill() + proc.wait(t) + del self.controllers[name] @unittest.skipUnless(artiq_root, "no ARTIQ_ROOT") class ExperimentCase(unittest.TestCase): def setUp(self): self.device_db = DeviceDB(os.path.join(artiq_root, "device_db.pyon")) - self.dataset_db = DatasetDB(os.path.join(artiq_root, "dataset_db.pyon")) - self.device_mgr = DeviceManager(self.device_db, - virtual_devices={"scheduler": DummyScheduler()}) + self.dataset_db = DatasetDB( + os.path.join(artiq_root, "dataset_db.pyon")) + self.device_mgr = DeviceManager( + self.device_db, virtual_devices={"scheduler": DummyScheduler()}) self.dataset_mgr = DatasetManager(self.dataset_db) def create(self, cls, **kwargs): diff --git a/artiq/test/lda.py b/artiq/test/lda.py index 5231b7964..2307f4dd2 100644 --- a/artiq/test/lda.py +++ b/artiq/test/lda.py @@ -1,8 +1,8 @@ import unittest -from artiq.devices.lda.driver import Lda, Ldasim +from artiq.devices.lda.driver import Ldasim from artiq.language.units import dB -from artiq.test.hardware_testbench import get_from_ddb +from artiq.test.hardware_testbench import ControllerCase class GenericLdaTest: @@ -13,14 +13,15 @@ class GenericLdaTest: for i in test_vector: with self.subTest(i=i): self.cont.set_attenuation(i) - self.assertEqual(i, self.cont.get_attenuation()) + j = self.cont.get_attenuation() + self.assertEqual(i, j) -class TestLda(GenericLdaTest, unittest.TestCase): +class TestLda(ControllerCase, GenericLdaTest): def setUp(self): - lda_serial = get_from_ddb("lda", "device") - lda_product = get_from_ddb("lda", "product") - self.cont = Lda(serial=lda_serial, product=lda_product) + ControllerCase.setUp(self) + self.start_controller("lda") + self.cont = self.device_mgr.get("lda") class TestLdaSim(GenericLdaTest, unittest.TestCase): diff --git a/artiq/test/novatech409b.py b/artiq/test/novatech409b.py index 314b1cb19..8c6327ce4 100644 --- a/artiq/test/novatech409b.py +++ b/artiq/test/novatech409b.py @@ -1,7 +1,7 @@ import unittest from artiq.devices.novatech409b.driver import Novatech409B -from artiq.test.hardware_testbench import get_from_ddb +from artiq.test.hardware_testbench import ControllerCase class GenericNovatech409BTest: @@ -20,10 +20,11 @@ class GenericNovatech409BTest: self.assertEqual(r[0:23], "00989680 2000 01F5 0000") -class TestNovatech409B(GenericNovatech409BTest, unittest.TestCase): +class TestNovatech409B(GenericNovatech409BTest, ControllerCase): def setUp(self): - novatech409b_device = get_from_ddb("novatech409b", "device") - self.driver = Novatech409B(novatech409b_device) + ControllerCase.setUp(self) + self.start_controller("novatech409b") + self.driver = self.device_mgr.get("novatech409b") class TestNovatech409BSim(GenericNovatech409BTest, unittest.TestCase): diff --git a/artiq/test/thorlabs_tcube.py b/artiq/test/thorlabs_tcube.py index 0ecb362e8..733a0e602 100644 --- a/artiq/test/thorlabs_tcube.py +++ b/artiq/test/thorlabs_tcube.py @@ -3,7 +3,7 @@ import time from artiq.devices.thorlabs_tcube.driver import Tdc, Tpz, TdcSim, TpzSim from artiq.language.units import V -from artiq.test.hardware_testbench import get_from_ddb +from artiq.test.hardware_testbench import ControllerCase class GenericTdcTest: @@ -131,10 +131,11 @@ class GenericTpzTest: self.assertEqual(test_vector, self.cont.get_tpz_io_settings()) -class TestTdc(unittest.TestCase, GenericTdcTest): +class TestTdc(ControllerCase, GenericTdcTest): def setUp(self): - tdc_serial = get_from_ddb("tdc", "device") - self.cont = Tdc(serial_dev=tdc_serial) + ControllerCase.setUp(self) + self.start_controller("tdc") + self.cont = self.device_mgr.get("tdc") class TestTdcSim(unittest.TestCase, GenericTdcTest): @@ -142,10 +143,11 @@ class TestTdcSim(unittest.TestCase, GenericTdcTest): self.cont = TdcSim() -class TestTpz(unittest.TestCase, GenericTpzTest): +class TestTpz(ControllerCase, GenericTpzTest): def setUp(self): - tpz_serial = get_from_ddb("tpz", "device") - self.cont = Tpz(serial_dev=tpz_serial) + ControllerCase.setUp(self) + self.start_controller("tpz") + self.cont = self.device_mgr.get("tpz") class TestTpzSim(unittest.TestCase, GenericTpzTest):