diff --git a/artiq/management/pc_rpc.py b/artiq/management/pc_rpc.py index cb7bff0ed..7e2f109c9 100644 --- a/artiq/management/pc_rpc.py +++ b/artiq/management/pc_rpc.py @@ -136,6 +136,102 @@ class Client: return proxy +class AsyncioClient: + """This class is similar to :class:`artiq.management.pc_rpc.Client`, but + uses ``asyncio`` instead of blocking calls. + + All RPC methods are coroutines. + + """ + def __init__(self): + self.__lock = asyncio.Lock() + self.__reader = None + self.__writer = None + self.__target_names = None + self.__id_parameters = None + + @asyncio.coroutine + def connect_rpc(self, host, port, target_name): + """Connects to the server. This cannot be done in __init__ because + this method is a coroutine. See ``Client`` for a description of the + parameters. + + """ + self.__reader, self.__writer = \ + yield from asyncio.open_connection(host, port) + try: + self.__writer.write(_init_string) + server_identification = yield from self.__recv() + self.__target_names = server_identification["targets"] + self.__id_parameters = server_identification["parameters"] + if target_name is not None: + self.select_rpc_target(target_name) + except: + self.close_rpc() + raise + + def select_rpc_target(self, target_name): + """Selects a RPC target by name. This function should be called + exactly once if the connection was created with ``target_name=None``. + + """ + if target_name not in self.__target_names: + raise IncompatibleServer + self.__writer.write((target_name + "\n").encode()) + + def get_rpc_id(self): + """Returns a tuple (target_names, id_parameters) containing the + identification information of the server. + + """ + return (self.__target_names, self.__id_parameters) + + def close_rpc(self): + """Closes the connection to the RPC server. + + No further method calls should be done after this method is called. + + """ + self.__writer.close() + self.__reader = None + self.__writer = None + self.__target_names = None + self.__id_parameters = None + + def __send(self, obj): + line = pyon.encode(obj) + "\n" + self.__writer.write(line.encode()) + + @asyncio.coroutine + def __recv(self): + line = yield from self.__reader.readline() + return pyon.decode(line.decode()) + + @asyncio.coroutine + def __do_rpc(self, name, args, kwargs): + yield from self.__lock.acquire() + try: + obj = {"action": "call", "name": name, + "args": args, "kwargs": kwargs} + self.__send(obj) + + obj = yield from self.__recv() + if obj["status"] == "ok": + return obj["ret"] + elif obj["status"] == "failed": + raise RemoteError(obj["message"]) + else: + raise ValueError + finally: + self.__lock.release() + + def __getattr__(self, name): + @asyncio.coroutine + def proxy(*args, **kwargs): + return self.__do_rpc(name, args, kwargs) + return proxy + + class Server(AsyncioServer): """This class creates a TCP server that handles requests coming from ``Client`` objects. diff --git a/artiq/test/pc_rpc.py b/artiq/test/pc_rpc.py index 63f7b8263..60a07ed95 100644 --- a/artiq/test/pc_rpc.py +++ b/artiq/test/pc_rpc.py @@ -11,40 +11,76 @@ from artiq.management import pc_rpc test_address = "::1" test_port = 7777 +test_object = [5, 2.1, None, True, False, + {"a": 5, 2: np.linspace(0, 10, 1)}, + (4, 5), (10,), "ab\nx\"'"] class RPCCase(unittest.TestCase): - def test_echo(self): + def _run_server_and_test(self, test): # running this file outside of unittest starts the echo server with subprocess.Popen([sys.executable, sys.modules[__name__].__file__]) as proc: try: - test_object = [5, 2.1, None, True, False, - {"a": 5, 2: np.linspace(0, 10, 1)}, - (4, 5), (10,), "ab\nx\"'"] - for attempt in range(100): - time.sleep(.2) - try: - remote = pc_rpc.Client(test_address, test_port, - "test") - except ConnectionRefusedError: - pass - else: - break - try: - test_object_back = remote.echo(test_object) - with self.assertRaises(pc_rpc.RemoteError): - remote.non_existing_method() - remote.quit() - finally: - remote.close_rpc() + test() finally: try: proc.wait(timeout=1) except subprocess.TimeoutExpired: proc.kill() raise + + def _blocking_echo(self): + for attempt in range(100): + time.sleep(.2) + try: + remote = pc_rpc.Client(test_address, test_port, + "test") + except ConnectionRefusedError: + pass + else: + break + try: + test_object_back = remote.echo(test_object) self.assertEqual(test_object, test_object_back) + with self.assertRaises(pc_rpc.RemoteError): + remote.non_existing_method() + remote.quit() + finally: + remote.close_rpc() + + def test_blocking_echo(self): + self._run_server_and_test(self._blocking_echo) + + @asyncio.coroutine + def _asyncio_echo(self): + remote = pc_rpc.AsyncioClient() + for attempt in range(100): + yield from asyncio.sleep(.2) + try: + yield from remote.connect_rpc(test_address, test_port, "test") + except ConnectionRefusedError: + pass + else: + break + try: + test_object_back = yield from remote.echo(test_object) + self.assertEqual(test_object, test_object_back) + with self.assertRaises(pc_rpc.RemoteError): + yield from remote.non_existing_method() + yield from remote.quit() + finally: + remote.close_rpc() + + def _loop_asyncio_echo(self): + loop = asyncio.get_event_loop() + try: + loop.run_until_complete(self._asyncio_echo()) + finally: + loop.close() + + def test_asyncio_echo(self): + self._run_server_and_test(self._loop_asyncio_echo) class Echo: