diff --git a/artiq/devices/ctlmgr.py b/artiq/devices/ctlmgr.py index 2c50c25df..416b9f7d1 100644 --- a/artiq/devices/ctlmgr.py +++ b/artiq/devices/ctlmgr.py @@ -35,20 +35,20 @@ class Controller: self.launch_task.cancel() await asyncio.wait_for(self.launch_task, None) - async def _call_controller(self, method): + async def call(self, method, *args, **kwargs): remote = AsyncioClient() await remote.connect_rpc(self.host, self.port, None) try: targets, _ = remote.get_rpc_id() remote.select_rpc_target(targets[0]) - r = await getattr(remote, method)() + r = await getattr(remote, method)(*args, **kwargs) finally: remote.close_rpc() return r async def _ping(self): try: - ok = await asyncio.wait_for(self._call_controller("ping"), + ok = await asyncio.wait_for(self.call("ping"), self.ping_timeout) if ok: self.retry_timer_cur = self.retry_timer @@ -120,7 +120,7 @@ class Controller: logger.info("Terminating controller %s", self.name) if self.process is not None and self.process.returncode is None: try: - await asyncio.wait_for(self._call_controller("terminate"), + await asyncio.wait_for(self.call("terminate"), self.term_timeout) except: logger.warning("Controller %s did not respond to terminate " @@ -172,6 +172,7 @@ class Controllers: del self.active[param] else: raise ValueError + self.queue.task_done() def __setitem__(self, k, v): if (isinstance(v, dict) and v["type"] == "controller" and diff --git a/artiq/test/hardware_testbench.py b/artiq/test/hardware_testbench.py index 721802b87..1633fbc38 100644 --- a/artiq/test/hardware_testbench.py +++ b/artiq/test/hardware_testbench.py @@ -5,41 +5,100 @@ import os import sys import unittest import logging +import asyncio +import atexit -from artiq.experiment import * from artiq.master.databases import DeviceDB, DatasetDB from artiq.master.worker_db import DeviceManager, DatasetManager from artiq.coredevice.core import CompileError -from artiq.protocols import pyon from artiq.frontend.artiq_run import DummyScheduler +from artiq.devices.ctlmgr import Controllers artiq_root = os.getenv("ARTIQ_ROOT") logger = logging.getLogger(__name__) -def get_from_ddb(*path, default="skip"): - if not artiq_root: - raise unittest.SkipTest("no ARTIQ_ROOT") - v = pyon.load_file(os.path.join(artiq_root, "device_db.pyon")) - try: - for p in path: - v = v[p] - return v - except KeyError: - if default == "skip": - raise unittest.SkipTest("device db path {} not found".format(path)) +@unittest.skipUnless(artiq_root, "no ARTIQ_ROOT") +class ControllerCase(unittest.TestCase): + host_filter = "::1" + timeout = 2 + + def setUp(self): + if os.name == "nt": + self.loop = asyncio.ProactorEventLoop() else: - return default + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + atexit.register(self.loop.close) + + self.controllers = Controllers() + self.controllers.host_filter = self.host_filter + + self.device_db = DeviceDB(os.path.join(artiq_root, "device_db.pyon")) + self.device_mgr = DeviceManager(self.device_db) + + async def start(self, *names): + for name in names: + try: + self.controllers[name] = self.device_db.get(name) + except KeyError: + raise unittest.SkipTest( + "controller `{}` not found".format(name)) + await self.controllers.queue.join() + await asyncio.wait([asyncio.ensure_future(self.wait_for_ping(name)) + for name in names]) + + async def wait_for_ping(self, name, retries=5): + t = self.timeout + dt = t/retries + while t > 0: + try: + ok = await asyncio.wait_for( + self.controllers.active[name].call("ping"), dt) + if not ok: + raise ValueError("unexcepted ping() response from " + "controller `{}`: `{}`".format(name, ok)) + return ok + except asyncio.TimeoutError: + t -= dt + except (ConnectionAbortedError, ConnectionError, + ConnectionRefusedError, ConnectionResetError): + await asyncio.sleep(dt) + t -= dt + raise asyncio.TimeoutError + + +def with_controllers(*names): + def wrapper(func): + def inner(self): + try: + for name in names: + setattr(self, name, self.device_mgr.get(name)) + func(self) + finally: + self.device_mgr.close_devices() + + def wrapped_test(self): + try: + self.loop.run_until_complete(self.start(*names)) + self.loop.run_until_complete(self.loop.run_in_executor( + None, inner, self)) + finally: + self.loop.run_until_complete(self.controllers.shutdown()) + + return wrapped_test + return wrapper @unittest.skipUnless(artiq_root, "no ARTIQ_ROOT") class ExperimentCase(unittest.TestCase): def setUp(self): self.device_db = DeviceDB(os.path.join(artiq_root, "device_db.pyon")) - self.dataset_db = DatasetDB(os.path.join(artiq_root, "dataset_db.pyon")) - self.device_mgr = DeviceManager(self.device_db, - virtual_devices={"scheduler": DummyScheduler()}) + self.dataset_db = DatasetDB( + os.path.join(artiq_root, "dataset_db.pyon")) + self.device_mgr = DeviceManager( + self.device_db, virtual_devices={"scheduler": DummyScheduler()}) self.dataset_mgr = DatasetManager(self.dataset_db) def create(self, cls, **kwargs): diff --git a/artiq/test/lda.py b/artiq/test/lda.py index a477ea535..bf35ce8b9 100644 --- a/artiq/test/lda.py +++ b/artiq/test/lda.py @@ -1,8 +1,8 @@ import unittest -from artiq.devices.lda.driver import Lda, Ldasim +from artiq.devices.lda.driver import Ldasim from artiq.language.units import dB -from artiq.test.hardware_testbench import get_from_ddb +from artiq.test.hardware_testbench import ControllerCase, with_controllers class GenericLdaTest: @@ -17,11 +17,11 @@ class GenericLdaTest: self.assertEqual(i, j) -class TestLda(GenericLdaTest, unittest.TestCase): - def setUp(self): - lda_serial = get_from_ddb("lda", "device") - lda_product = get_from_ddb("lda", "product") - self.cont = Lda(serial=lda_serial, product=lda_product) +class TestLda(ControllerCase, GenericLdaTest): + @with_controllers("lda") + def test_attenuation(self): + self.cont = self.device_mgr.get("lda") + GenericLdaTest.test_attenuation(self) class TestLdaSim(GenericLdaTest, unittest.TestCase): diff --git a/artiq/test/novatech409b.py b/artiq/test/novatech409b.py index 314b1cb19..3f3eb12d7 100644 --- a/artiq/test/novatech409b.py +++ b/artiq/test/novatech409b.py @@ -1,7 +1,7 @@ import unittest from artiq.devices.novatech409b.driver import Novatech409B -from artiq.test.hardware_testbench import get_from_ddb +from artiq.test.hardware_testbench import ControllerCase, with_controllers class GenericNovatech409BTest: @@ -20,10 +20,11 @@ class GenericNovatech409BTest: self.assertEqual(r[0:23], "00989680 2000 01F5 0000") -class TestNovatech409B(GenericNovatech409BTest, unittest.TestCase): - def setUp(self): - novatech409b_device = get_from_ddb("novatech409b", "device") - self.driver = Novatech409B(novatech409b_device) +class TestNovatech409B(GenericNovatech409BTest, ControllerCase): + @with_controllers("novatech409b") + def test_parameters_readback(self): + self.driver = self.device_mgr.get("novatech409b") + GenericNovatech409BTest.test_parameters_readback(self) class TestNovatech409BSim(GenericNovatech409BTest, unittest.TestCase):