mirror of https://github.com/m-labs/artiq.git
Merge branch 'testbench-controllers'
closes #69 * testbench-controllers: test.ctlmgr: add basic test tooling hardware_testbench: fix timeout handling hardware_testbench: use plain subprocess to start controllers hardware_testbench: run Crontrollers loop in thread, not the test hardware_testbench: run controllers lda: test tweaks artiq_ctlmgr: refactor into artiq.devices.ctlmgr
This commit is contained in:
commit
627221a5cd
|
@ -0,0 +1,255 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
import shlex
|
||||||
|
import socket
|
||||||
|
|
||||||
|
from artiq.protocols.sync_struct import Subscriber
|
||||||
|
from artiq.protocols.pc_rpc import AsyncioClient
|
||||||
|
from artiq.protocols.logging import parse_log_message, log_with_name
|
||||||
|
from artiq.tools import Condition, TaskObject
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Controller:
|
||||||
|
def __init__(self, name, ddb_entry):
|
||||||
|
self.name = name
|
||||||
|
self.command = ddb_entry["command"]
|
||||||
|
self.retry_timer = ddb_entry.get("retry_timer", 5)
|
||||||
|
self.retry_timer_backoff = ddb_entry.get("retry_timer_backoff", 1.1)
|
||||||
|
|
||||||
|
self.host = ddb_entry["host"]
|
||||||
|
self.port = ddb_entry["port"]
|
||||||
|
self.ping_timer = ddb_entry.get("ping_timer", 30)
|
||||||
|
self.ping_timeout = ddb_entry.get("ping_timeout", 30)
|
||||||
|
self.term_timeout = ddb_entry.get("term_timeout", 30)
|
||||||
|
|
||||||
|
self.retry_timer_cur = self.retry_timer
|
||||||
|
self.retry_now = Condition()
|
||||||
|
self.process = None
|
||||||
|
self.launch_task = asyncio.Task(self.launcher())
|
||||||
|
|
||||||
|
async def end(self):
|
||||||
|
self.launch_task.cancel()
|
||||||
|
await asyncio.wait_for(self.launch_task, None)
|
||||||
|
|
||||||
|
async def call(self, method, *args, **kwargs):
|
||||||
|
remote = AsyncioClient()
|
||||||
|
await remote.connect_rpc(self.host, self.port, None)
|
||||||
|
try:
|
||||||
|
targets, _ = remote.get_rpc_id()
|
||||||
|
remote.select_rpc_target(targets[0])
|
||||||
|
r = await getattr(remote, method)(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
remote.close_rpc()
|
||||||
|
return r
|
||||||
|
|
||||||
|
async def _ping(self):
|
||||||
|
try:
|
||||||
|
ok = await asyncio.wait_for(self.call("ping"),
|
||||||
|
self.ping_timeout)
|
||||||
|
if ok:
|
||||||
|
self.retry_timer_cur = self.retry_timer
|
||||||
|
return ok
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _wait_and_ping(self):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self.process.wait(),
|
||||||
|
self.ping_timer)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.debug("pinging controller %s", self.name)
|
||||||
|
ok = await self._ping()
|
||||||
|
if not ok:
|
||||||
|
logger.warning("Controller %s ping failed", self.name)
|
||||||
|
await self._terminate()
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
async def forward_logs(self, stream):
|
||||||
|
source = "controller({})".format(self.name)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
entry = (await stream.readline())
|
||||||
|
if not entry:
|
||||||
|
break
|
||||||
|
entry = entry[:-1]
|
||||||
|
level, name, message = parse_log_message(entry.decode())
|
||||||
|
log_with_name(name, level, message, extra={"source": source})
|
||||||
|
except:
|
||||||
|
logger.debug("exception in log forwarding", exc_info=True)
|
||||||
|
break
|
||||||
|
logger.debug("stopped log forwarding of stream %s of %s",
|
||||||
|
stream, self.name)
|
||||||
|
|
||||||
|
async def launcher(self):
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
logger.info("Starting controller %s with command: %s",
|
||||||
|
self.name, self.command)
|
||||||
|
try:
|
||||||
|
self.process = await asyncio.create_subprocess_exec(
|
||||||
|
*shlex.split(self.command),
|
||||||
|
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||||
|
asyncio.ensure_future(self.forward_logs(
|
||||||
|
self.process.stdout))
|
||||||
|
asyncio.ensure_future(self.forward_logs(
|
||||||
|
self.process.stderr))
|
||||||
|
await self._wait_and_ping()
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.warning("Controller %s failed to start", self.name)
|
||||||
|
else:
|
||||||
|
logger.warning("Controller %s exited", self.name)
|
||||||
|
logger.warning("Restarting in %.1f seconds",
|
||||||
|
self.retry_timer_cur)
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self.retry_now.wait(),
|
||||||
|
self.retry_timer_cur)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
self.retry_timer_cur *= self.retry_timer_backoff
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
await self._terminate()
|
||||||
|
|
||||||
|
async def _terminate(self):
|
||||||
|
logger.info("Terminating controller %s", self.name)
|
||||||
|
if self.process is not None and self.process.returncode is None:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self.call("terminate"),
|
||||||
|
self.term_timeout)
|
||||||
|
except:
|
||||||
|
logger.warning("Controller %s did not respond to terminate "
|
||||||
|
"command, killing", self.name)
|
||||||
|
try:
|
||||||
|
self.process.kill()
|
||||||
|
except ProcessLookupError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self.process.wait(),
|
||||||
|
self.term_timeout)
|
||||||
|
except:
|
||||||
|
logger.warning("Controller %s failed to exit, killing",
|
||||||
|
self.name)
|
||||||
|
try:
|
||||||
|
self.process.kill()
|
||||||
|
except ProcessLookupError:
|
||||||
|
pass
|
||||||
|
await self.process.wait()
|
||||||
|
logger.debug("Controller %s terminated", self.name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_ip_addresses(host):
|
||||||
|
try:
|
||||||
|
addrinfo = socket.getaddrinfo(host, None)
|
||||||
|
except:
|
||||||
|
return set()
|
||||||
|
return {info[4][0] for info in addrinfo}
|
||||||
|
|
||||||
|
|
||||||
|
class Controllers:
|
||||||
|
def __init__(self):
|
||||||
|
self.host_filter = None
|
||||||
|
self.active_or_queued = set()
|
||||||
|
self.queue = asyncio.Queue()
|
||||||
|
self.active = dict()
|
||||||
|
self.process_task = asyncio.Task(self._process())
|
||||||
|
|
||||||
|
async def _process(self):
|
||||||
|
while True:
|
||||||
|
action, param = await self.queue.get()
|
||||||
|
if action == "set":
|
||||||
|
k, ddb_entry = param
|
||||||
|
if k in self.active:
|
||||||
|
await self.active[k].end()
|
||||||
|
self.active[k] = Controller(k, ddb_entry)
|
||||||
|
elif action == "del":
|
||||||
|
await self.active[param].end()
|
||||||
|
del self.active[param]
|
||||||
|
self.queue.task_done()
|
||||||
|
if action not in ("set", "del"):
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
def __setitem__(self, k, v):
|
||||||
|
if (isinstance(v, dict) and v["type"] == "controller" and
|
||||||
|
self.host_filter in get_ip_addresses(v["host"])):
|
||||||
|
v["command"] = v["command"].format(name=k,
|
||||||
|
bind=self.host_filter,
|
||||||
|
port=v["port"])
|
||||||
|
self.queue.put_nowait(("set", (k, v)))
|
||||||
|
self.active_or_queued.add(k)
|
||||||
|
|
||||||
|
def __delitem__(self, k):
|
||||||
|
if k in self.active_or_queued:
|
||||||
|
self.queue.put_nowait(("del", k))
|
||||||
|
self.active_or_queued.remove(k)
|
||||||
|
|
||||||
|
def delete_all(self):
|
||||||
|
for name in set(self.active_or_queued):
|
||||||
|
del self[name]
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
self.process_task.cancel()
|
||||||
|
for c in self.active.values():
|
||||||
|
await c.end()
|
||||||
|
|
||||||
|
|
||||||
|
class ControllerDB:
|
||||||
|
def __init__(self):
|
||||||
|
self.current_controllers = Controllers()
|
||||||
|
|
||||||
|
def set_host_filter(self, host_filter):
|
||||||
|
self.current_controllers.host_filter = host_filter
|
||||||
|
|
||||||
|
def sync_struct_init(self, init):
|
||||||
|
if self.current_controllers is not None:
|
||||||
|
self.current_controllers.delete_all()
|
||||||
|
for k, v in init.items():
|
||||||
|
self.current_controllers[k] = v
|
||||||
|
return self.current_controllers
|
||||||
|
|
||||||
|
|
||||||
|
class ControllerManager(TaskObject):
|
||||||
|
def __init__(self, server, port, retry_master):
|
||||||
|
self.server = server
|
||||||
|
self.port = port
|
||||||
|
self.retry_master = retry_master
|
||||||
|
self.controller_db = ControllerDB()
|
||||||
|
|
||||||
|
async def _do(self):
|
||||||
|
try:
|
||||||
|
subscriber = Subscriber("devices",
|
||||||
|
self.controller_db.sync_struct_init)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
def set_host_filter():
|
||||||
|
s = subscriber.writer.get_extra_info("socket")
|
||||||
|
localhost = s.getsockname()[0]
|
||||||
|
self.controller_db.set_host_filter(localhost)
|
||||||
|
await subscriber.connect(self.server, self.port,
|
||||||
|
set_host_filter)
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(subscriber.receive_task, None)
|
||||||
|
finally:
|
||||||
|
await subscriber.close()
|
||||||
|
except (ConnectionAbortedError, ConnectionError,
|
||||||
|
ConnectionRefusedError, ConnectionResetError) as e:
|
||||||
|
logger.warning("Connection to master failed (%s: %s)",
|
||||||
|
e.__class__.__name__, str(e))
|
||||||
|
else:
|
||||||
|
logger.warning("Connection to master lost")
|
||||||
|
logger.warning("Retrying in %.1f seconds", self.retry_master)
|
||||||
|
await asyncio.sleep(self.retry_master)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
await self.controller_db.current_controllers.shutdown()
|
||||||
|
|
||||||
|
def retry_now(self, k):
|
||||||
|
"""If a controller is disabled and pending retry, perform that retry
|
||||||
|
now."""
|
||||||
|
self.controller_db.current_controllers.active[k].retry_now.notify()
|
|
@ -5,249 +5,13 @@ import atexit
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import subprocess
|
|
||||||
import shlex
|
|
||||||
import socket
|
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
from artiq.protocols.sync_struct import Subscriber
|
from artiq.protocols.pc_rpc import Server
|
||||||
from artiq.protocols.pc_rpc import AsyncioClient, Server
|
from artiq.protocols.logging import LogForwarder, SourceFilter
|
||||||
from artiq.protocols.logging import (LogForwarder, LogParser,
|
from artiq.tools import (simple_network_args, atexit_register_coroutine,
|
||||||
SourceFilter)
|
bind_address_from_args)
|
||||||
from artiq.tools import *
|
from artiq.devices.ctlmgr import ControllerManager
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Controller:
|
|
||||||
def __init__(self, name, ddb_entry):
|
|
||||||
self.name = name
|
|
||||||
self.command = ddb_entry["command"]
|
|
||||||
self.retry_timer = ddb_entry.get("retry_timer", 5)
|
|
||||||
self.retry_timer_backoff = ddb_entry.get("retry_timer_backoff", 1.1)
|
|
||||||
|
|
||||||
self.host = ddb_entry["host"]
|
|
||||||
self.port = ddb_entry["port"]
|
|
||||||
self.ping_timer = ddb_entry.get("ping_timer", 30)
|
|
||||||
self.ping_timeout = ddb_entry.get("ping_timeout", 30)
|
|
||||||
self.term_timeout = ddb_entry.get("term_timeout", 30)
|
|
||||||
|
|
||||||
self.retry_timer_cur = self.retry_timer
|
|
||||||
self.retry_now = Condition()
|
|
||||||
self.process = None
|
|
||||||
self.launch_task = asyncio.Task(self.launcher())
|
|
||||||
|
|
||||||
async def end(self):
|
|
||||||
self.launch_task.cancel()
|
|
||||||
await asyncio.wait_for(self.launch_task, None)
|
|
||||||
|
|
||||||
async def _call_controller(self, method):
|
|
||||||
remote = AsyncioClient()
|
|
||||||
await remote.connect_rpc(self.host, self.port, None)
|
|
||||||
try:
|
|
||||||
targets, _ = remote.get_rpc_id()
|
|
||||||
remote.select_rpc_target(targets[0])
|
|
||||||
r = await getattr(remote, method)()
|
|
||||||
finally:
|
|
||||||
remote.close_rpc()
|
|
||||||
return r
|
|
||||||
|
|
||||||
async def _ping(self):
|
|
||||||
try:
|
|
||||||
ok = await asyncio.wait_for(self._call_controller("ping"),
|
|
||||||
self.ping_timeout)
|
|
||||||
if ok:
|
|
||||||
self.retry_timer_cur = self.retry_timer
|
|
||||||
return ok
|
|
||||||
except:
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _wait_and_ping(self):
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self.process.wait(),
|
|
||||||
self.ping_timer)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.debug("pinging controller %s", self.name)
|
|
||||||
ok = await self._ping()
|
|
||||||
if not ok:
|
|
||||||
logger.warning("Controller %s ping failed", self.name)
|
|
||||||
await self._terminate()
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
def _get_log_source(self):
|
|
||||||
return "controller({})".format(self.name)
|
|
||||||
|
|
||||||
async def launcher(self):
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
logger.info("Starting controller %s with command: %s",
|
|
||||||
self.name, self.command)
|
|
||||||
try:
|
|
||||||
self.process = await asyncio.create_subprocess_exec(
|
|
||||||
*shlex.split(self.command),
|
|
||||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
|
||||||
asyncio.ensure_future(
|
|
||||||
LogParser(self._get_log_source).stream_task(
|
|
||||||
self.process.stdout))
|
|
||||||
asyncio.ensure_future(
|
|
||||||
LogParser(self._get_log_source).stream_task(
|
|
||||||
self.process.stderr))
|
|
||||||
await self._wait_and_ping()
|
|
||||||
except FileNotFoundError:
|
|
||||||
logger.warning("Controller %s failed to start", self.name)
|
|
||||||
else:
|
|
||||||
logger.warning("Controller %s exited", self.name)
|
|
||||||
logger.warning("Restarting in %.1f seconds",
|
|
||||||
self.retry_timer_cur)
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self.retry_now.wait(),
|
|
||||||
self.retry_timer_cur)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
pass
|
|
||||||
self.retry_timer_cur *= self.retry_timer_backoff
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
await self._terminate()
|
|
||||||
|
|
||||||
async def _terminate(self):
|
|
||||||
logger.info("Terminating controller %s", self.name)
|
|
||||||
if self.process is not None and self.process.returncode is None:
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._call_controller("terminate"),
|
|
||||||
self.term_timeout)
|
|
||||||
except:
|
|
||||||
logger.warning("Controller %s did not respond to terminate "
|
|
||||||
"command, killing", self.name)
|
|
||||||
try:
|
|
||||||
self.process.kill()
|
|
||||||
except ProcessLookupError:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self.process.wait(),
|
|
||||||
self.term_timeout)
|
|
||||||
except:
|
|
||||||
logger.warning("Controller %s failed to exit, killing",
|
|
||||||
self.name)
|
|
||||||
try:
|
|
||||||
self.process.kill()
|
|
||||||
except ProcessLookupError:
|
|
||||||
pass
|
|
||||||
await self.process.wait()
|
|
||||||
logger.debug("Controller %s terminated", self.name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_ip_addresses(host):
|
|
||||||
try:
|
|
||||||
addrinfo = socket.getaddrinfo(host, None)
|
|
||||||
except:
|
|
||||||
return set()
|
|
||||||
return {info[4][0] for info in addrinfo}
|
|
||||||
|
|
||||||
|
|
||||||
class Controllers:
|
|
||||||
def __init__(self):
|
|
||||||
self.host_filter = None
|
|
||||||
self.active_or_queued = set()
|
|
||||||
self.queue = asyncio.Queue()
|
|
||||||
self.active = dict()
|
|
||||||
self.process_task = asyncio.Task(self._process())
|
|
||||||
|
|
||||||
async def _process(self):
|
|
||||||
while True:
|
|
||||||
action, param = await self.queue.get()
|
|
||||||
if action == "set":
|
|
||||||
k, ddb_entry = param
|
|
||||||
if k in self.active:
|
|
||||||
await self.active[k].end()
|
|
||||||
self.active[k] = Controller(k, ddb_entry)
|
|
||||||
elif action == "del":
|
|
||||||
await self.active[param].end()
|
|
||||||
del self.active[param]
|
|
||||||
else:
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
def __setitem__(self, k, v):
|
|
||||||
if (isinstance(v, dict) and v["type"] == "controller"
|
|
||||||
and self.host_filter in get_ip_addresses(v["host"])):
|
|
||||||
v["command"] = v["command"].format(name=k,
|
|
||||||
bind=self.host_filter,
|
|
||||||
port=v["port"])
|
|
||||||
self.queue.put_nowait(("set", (k, v)))
|
|
||||||
self.active_or_queued.add(k)
|
|
||||||
|
|
||||||
def __delitem__(self, k):
|
|
||||||
if k in self.active_or_queued:
|
|
||||||
self.queue.put_nowait(("del", k))
|
|
||||||
self.active_or_queued.remove(k)
|
|
||||||
|
|
||||||
def delete_all(self):
|
|
||||||
for name in set(self.active_or_queued):
|
|
||||||
del self[name]
|
|
||||||
|
|
||||||
async def shutdown(self):
|
|
||||||
self.process_task.cancel()
|
|
||||||
for c in self.active.values():
|
|
||||||
await c.end()
|
|
||||||
|
|
||||||
|
|
||||||
class ControllerDB:
|
|
||||||
def __init__(self):
|
|
||||||
self.current_controllers = Controllers()
|
|
||||||
|
|
||||||
def set_host_filter(self, host_filter):
|
|
||||||
self.current_controllers.host_filter = host_filter
|
|
||||||
|
|
||||||
def sync_struct_init(self, init):
|
|
||||||
if self.current_controllers is not None:
|
|
||||||
self.current_controllers.delete_all()
|
|
||||||
for k, v in init.items():
|
|
||||||
self.current_controllers[k] = v
|
|
||||||
return self.current_controllers
|
|
||||||
|
|
||||||
|
|
||||||
class ControllerManager(TaskObject):
|
|
||||||
def __init__(self, server, port, retry_master):
|
|
||||||
self.server = server
|
|
||||||
self.port = port
|
|
||||||
self.retry_master = retry_master
|
|
||||||
self.controller_db = ControllerDB()
|
|
||||||
|
|
||||||
async def _do(self):
|
|
||||||
try:
|
|
||||||
subscriber = Subscriber("devices",
|
|
||||||
self.controller_db.sync_struct_init)
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
def set_host_filter():
|
|
||||||
s = subscriber.writer.get_extra_info("socket")
|
|
||||||
localhost = s.getsockname()[0]
|
|
||||||
self.controller_db.set_host_filter(localhost)
|
|
||||||
await subscriber.connect(self.server, self.port,
|
|
||||||
set_host_filter)
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(subscriber.receive_task, None)
|
|
||||||
finally:
|
|
||||||
await subscriber.close()
|
|
||||||
except (ConnectionAbortedError, ConnectionError,
|
|
||||||
ConnectionRefusedError, ConnectionResetError) as e:
|
|
||||||
logger.warning("Connection to master failed (%s: %s)",
|
|
||||||
e.__class__.__name__, str(e))
|
|
||||||
else:
|
|
||||||
logger.warning("Connection to master lost")
|
|
||||||
logger.warning("Retrying in %.1f seconds", self.retry_master)
|
|
||||||
await asyncio.sleep(self.retry_master)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
await self.controller_db.current_controllers.shutdown()
|
|
||||||
|
|
||||||
def retry_now(self, k):
|
|
||||||
"""If a controller is disabled and pending retry, perform that retry
|
|
||||||
now."""
|
|
||||||
self.controller_db.current_controllers.active[k].retry_now.notify()
|
|
||||||
|
|
||||||
|
|
||||||
def get_argparser():
|
def get_argparser():
|
||||||
|
@ -280,7 +44,8 @@ def main():
|
||||||
|
|
||||||
root_logger = logging.getLogger()
|
root_logger = logging.getLogger()
|
||||||
root_logger.setLevel(logging.NOTSET)
|
root_logger.setLevel(logging.NOTSET)
|
||||||
source_adder = SourceFilter(logging.WARNING + args.quiet*10 - args.verbose*10,
|
source_adder = SourceFilter(logging.WARNING +
|
||||||
|
args.quiet*10 - args.verbose*10,
|
||||||
"ctlmgr({})".format(platform.node()))
|
"ctlmgr({})".format(platform.node()))
|
||||||
console_handler = logging.StreamHandler()
|
console_handler = logging.StreamHandler()
|
||||||
console_handler.setFormatter(logging.Formatter(
|
console_handler.setFormatter(logging.Formatter(
|
||||||
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from artiq.devices.ctlmgr import Controllers
|
||||||
|
from artiq.protocols.pc_rpc import AsyncioClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ControllerCase(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
if os.name == "nt":
|
||||||
|
self.loop = asyncio.ProactorEventLoop()
|
||||||
|
else:
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(self.loop)
|
||||||
|
self.addCleanup(self.loop.close)
|
||||||
|
|
||||||
|
self.controllers = Controllers()
|
||||||
|
self.controllers.host_filter = "::1"
|
||||||
|
self.addCleanup(
|
||||||
|
lambda: self.loop.run_until_complete(self.controllers.shutdown()))
|
||||||
|
|
||||||
|
async def start(self, name, entry):
|
||||||
|
self.controllers[name] = entry
|
||||||
|
await self.controllers.queue.join()
|
||||||
|
await self.wait_for_ping(entry["host"], entry["port"])
|
||||||
|
|
||||||
|
async def get_client(self, host, port):
|
||||||
|
remote = AsyncioClient()
|
||||||
|
await remote.connect_rpc(host, port, None)
|
||||||
|
targets, _ = remote.get_rpc_id()
|
||||||
|
remote.select_rpc_target(targets[0])
|
||||||
|
self.addCleanup(remote.close_rpc)
|
||||||
|
return remote
|
||||||
|
|
||||||
|
async def wait_for_ping(self, host, port, retries=5, timeout=2):
|
||||||
|
dt = timeout/retries
|
||||||
|
while timeout > 0:
|
||||||
|
try:
|
||||||
|
remote = await self.get_client(host, port)
|
||||||
|
ok = await asyncio.wait_for(remote.ping(), dt)
|
||||||
|
if not ok:
|
||||||
|
raise ValueError("unexcepted ping() response from "
|
||||||
|
"controller: `{}`".format(ok))
|
||||||
|
return ok
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
timeout -= dt
|
||||||
|
except (ConnectionAbortedError, ConnectionError,
|
||||||
|
ConnectionRefusedError, ConnectionResetError):
|
||||||
|
await asyncio.sleep(dt)
|
||||||
|
timeout -= dt
|
||||||
|
raise asyncio.TimeoutError
|
||||||
|
|
||||||
|
def test_start_ping_stop_controller(self):
|
||||||
|
entry = {
|
||||||
|
"type": "controller",
|
||||||
|
"host": "::1",
|
||||||
|
"port": 3253,
|
||||||
|
"command": "lda_controller -p {port} --bind {bind} "
|
||||||
|
"--no-localhost-bind --simulation",
|
||||||
|
}
|
||||||
|
async def test():
|
||||||
|
await self.start("lda_sim", entry)
|
||||||
|
remote = await self.get_client(entry["host"], entry["port"])
|
||||||
|
await remote.close()
|
||||||
|
|
||||||
|
self.loop.run_until_complete(test())
|
|
@ -5,12 +5,13 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
import logging
|
import logging
|
||||||
|
import subprocess
|
||||||
|
import shlex
|
||||||
|
import time
|
||||||
|
|
||||||
from artiq.experiment import *
|
|
||||||
from artiq.master.databases import DeviceDB, DatasetDB
|
from artiq.master.databases import DeviceDB, DatasetDB
|
||||||
from artiq.master.worker_db import DeviceManager, DatasetManager
|
from artiq.master.worker_db import DeviceManager, DatasetManager
|
||||||
from artiq.coredevice.core import CompileError
|
from artiq.coredevice.core import CompileError
|
||||||
from artiq.protocols import pyon
|
|
||||||
from artiq.frontend.artiq_run import DummyScheduler
|
from artiq.frontend.artiq_run import DummyScheduler
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,28 +19,49 @@ artiq_root = os.getenv("ARTIQ_ROOT")
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_from_ddb(*path, default="skip"):
|
@unittest.skipUnless(artiq_root, "no ARTIQ_ROOT")
|
||||||
if not artiq_root:
|
class ControllerCase(unittest.TestCase):
|
||||||
raise unittest.SkipTest("no ARTIQ_ROOT")
|
def setUp(self):
|
||||||
v = pyon.load_file(os.path.join(artiq_root, "device_db.pyon"))
|
self.device_db = DeviceDB(os.path.join(artiq_root, "device_db.pyon"))
|
||||||
|
self.device_mgr = DeviceManager(self.device_db)
|
||||||
|
self.controllers = {}
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.device_mgr.close_devices()
|
||||||
|
for name in list(self.controllers):
|
||||||
|
self.stop_controller(name)
|
||||||
|
|
||||||
|
def start_controller(self, name, sleep=1):
|
||||||
try:
|
try:
|
||||||
for p in path:
|
entry = self.device_db.get(name)
|
||||||
v = v[p]
|
|
||||||
return v.read
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if default == "skip":
|
raise unittest.SkipTest(
|
||||||
raise unittest.SkipTest("device db path {} not found".format(path))
|
"controller `{}` not found".format(name))
|
||||||
else:
|
entry["command"] = entry["command"].format(
|
||||||
return default
|
name=name, bind=entry["host"], port=entry["port"])
|
||||||
|
proc = subprocess.Popen(shlex.split(entry["command"]))
|
||||||
|
self.controllers[name] = entry, proc
|
||||||
|
time.sleep(sleep)
|
||||||
|
|
||||||
|
def stop_controller(self, name, default_timeout=1):
|
||||||
|
entry, proc = self.controllers[name]
|
||||||
|
t = entry.get("term_timeout", default_timeout)
|
||||||
|
try:
|
||||||
|
proc.wait(t)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
proc.kill()
|
||||||
|
proc.wait(t)
|
||||||
|
del self.controllers[name]
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipUnless(artiq_root, "no ARTIQ_ROOT")
|
@unittest.skipUnless(artiq_root, "no ARTIQ_ROOT")
|
||||||
class ExperimentCase(unittest.TestCase):
|
class ExperimentCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.device_db = DeviceDB(os.path.join(artiq_root, "device_db.pyon"))
|
self.device_db = DeviceDB(os.path.join(artiq_root, "device_db.pyon"))
|
||||||
self.dataset_db = DatasetDB(os.path.join(artiq_root, "dataset_db.pyon"))
|
self.dataset_db = DatasetDB(
|
||||||
self.device_mgr = DeviceManager(self.device_db,
|
os.path.join(artiq_root, "dataset_db.pyon"))
|
||||||
virtual_devices={"scheduler": DummyScheduler()})
|
self.device_mgr = DeviceManager(
|
||||||
|
self.device_db, virtual_devices={"scheduler": DummyScheduler()})
|
||||||
self.dataset_mgr = DatasetManager(self.dataset_db)
|
self.dataset_mgr = DatasetManager(self.dataset_db)
|
||||||
|
|
||||||
def create(self, cls, **kwargs):
|
def create(self, cls, **kwargs):
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from artiq.devices.lda.driver import Lda, Ldasim
|
from artiq.devices.lda.driver import Ldasim
|
||||||
from artiq.language.units import dB
|
from artiq.language.units import dB
|
||||||
from artiq.test.hardware_testbench import get_from_ddb
|
from artiq.test.hardware_testbench import ControllerCase
|
||||||
|
|
||||||
|
|
||||||
class GenericLdaTest:
|
class GenericLdaTest:
|
||||||
|
@ -13,14 +13,15 @@ class GenericLdaTest:
|
||||||
for i in test_vector:
|
for i in test_vector:
|
||||||
with self.subTest(i=i):
|
with self.subTest(i=i):
|
||||||
self.cont.set_attenuation(i)
|
self.cont.set_attenuation(i)
|
||||||
self.assertEqual(i, self.cont.get_attenuation())
|
j = self.cont.get_attenuation()
|
||||||
|
self.assertEqual(i, j)
|
||||||
|
|
||||||
|
|
||||||
class TestLda(GenericLdaTest, unittest.TestCase):
|
class TestLda(ControllerCase, GenericLdaTest):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
lda_serial = get_from_ddb("lda", "device")
|
ControllerCase.setUp(self)
|
||||||
lda_product = get_from_ddb("lda", "product")
|
self.start_controller("lda")
|
||||||
self.cont = Lda(serial=lda_serial, product=lda_product)
|
self.cont = self.device_mgr.get("lda")
|
||||||
|
|
||||||
|
|
||||||
class TestLdaSim(GenericLdaTest, unittest.TestCase):
|
class TestLdaSim(GenericLdaTest, unittest.TestCase):
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from artiq.devices.novatech409b.driver import Novatech409B
|
from artiq.devices.novatech409b.driver import Novatech409B
|
||||||
from artiq.test.hardware_testbench import get_from_ddb
|
from artiq.test.hardware_testbench import ControllerCase
|
||||||
|
|
||||||
|
|
||||||
class GenericNovatech409BTest:
|
class GenericNovatech409BTest:
|
||||||
|
@ -20,10 +20,11 @@ class GenericNovatech409BTest:
|
||||||
self.assertEqual(r[0:23], "00989680 2000 01F5 0000")
|
self.assertEqual(r[0:23], "00989680 2000 01F5 0000")
|
||||||
|
|
||||||
|
|
||||||
class TestNovatech409B(GenericNovatech409BTest, unittest.TestCase):
|
class TestNovatech409B(GenericNovatech409BTest, ControllerCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
novatech409b_device = get_from_ddb("novatech409b", "device")
|
ControllerCase.setUp(self)
|
||||||
self.driver = Novatech409B(novatech409b_device)
|
self.start_controller("novatech409b")
|
||||||
|
self.driver = self.device_mgr.get("novatech409b")
|
||||||
|
|
||||||
|
|
||||||
class TestNovatech409BSim(GenericNovatech409BTest, unittest.TestCase):
|
class TestNovatech409BSim(GenericNovatech409BTest, unittest.TestCase):
|
||||||
|
|
|
@ -3,7 +3,7 @@ import time
|
||||||
|
|
||||||
from artiq.devices.thorlabs_tcube.driver import Tdc, Tpz, TdcSim, TpzSim
|
from artiq.devices.thorlabs_tcube.driver import Tdc, Tpz, TdcSim, TpzSim
|
||||||
from artiq.language.units import V
|
from artiq.language.units import V
|
||||||
from artiq.test.hardware_testbench import get_from_ddb
|
from artiq.test.hardware_testbench import ControllerCase
|
||||||
|
|
||||||
|
|
||||||
class GenericTdcTest:
|
class GenericTdcTest:
|
||||||
|
@ -131,10 +131,11 @@ class GenericTpzTest:
|
||||||
self.assertEqual(test_vector, self.cont.get_tpz_io_settings())
|
self.assertEqual(test_vector, self.cont.get_tpz_io_settings())
|
||||||
|
|
||||||
|
|
||||||
class TestTdc(unittest.TestCase, GenericTdcTest):
|
class TestTdc(ControllerCase, GenericTdcTest):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
tdc_serial = get_from_ddb("tdc", "device")
|
ControllerCase.setUp(self)
|
||||||
self.cont = Tdc(serial_dev=tdc_serial)
|
self.start_controller("tdc")
|
||||||
|
self.cont = self.device_mgr.get("tdc")
|
||||||
|
|
||||||
|
|
||||||
class TestTdcSim(unittest.TestCase, GenericTdcTest):
|
class TestTdcSim(unittest.TestCase, GenericTdcTest):
|
||||||
|
@ -142,10 +143,11 @@ class TestTdcSim(unittest.TestCase, GenericTdcTest):
|
||||||
self.cont = TdcSim()
|
self.cont = TdcSim()
|
||||||
|
|
||||||
|
|
||||||
class TestTpz(unittest.TestCase, GenericTpzTest):
|
class TestTpz(ControllerCase, GenericTpzTest):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
tpz_serial = get_from_ddb("tpz", "device")
|
ControllerCase.setUp(self)
|
||||||
self.cont = Tpz(serial_dev=tpz_serial)
|
self.start_controller("tpz")
|
||||||
|
self.cont = self.device_mgr.get("tpz")
|
||||||
|
|
||||||
|
|
||||||
class TestTpzSim(unittest.TestCase, GenericTpzTest):
|
class TestTpzSim(unittest.TestCase, GenericTpzTest):
|
||||||
|
|
Loading…
Reference in New Issue