diff --git a/artiq/frontend/artiq_master.py b/artiq/frontend/artiq_master.py index 1d3a8617e..94c8903c7 100755 --- a/artiq/frontend/artiq_master.py +++ b/artiq/frontend/artiq_master.py @@ -95,7 +95,7 @@ def main(): "master_dataset_db": dataset_db, "master_schedule": scheduler, "master_experiment_db": experiment_db - }) + }, allow_parallel=True) loop.run_until_complete(server_control.start( bind, args.port_control)) atexit_register_coroutine(server_control.stop) diff --git a/artiq/protocols/pc_rpc.py b/artiq/protocols/pc_rpc.py index c9f53b286..512237d60 100644 --- a/artiq/protocols/pc_rpc.py +++ b/artiq/protocols/pc_rpc.py @@ -433,6 +433,12 @@ class Server(_AsyncioServer): simple cases: it allows new connections to be be accepted even when the previous client failed to properly shut down its connection. + If a target method is a coroutine, it is awaited and its return value + is sent to the RPC client. If ``allow_parallel`` is true, multiple + target coroutines may be executed in parallel (one per RPC client), + otherwise a lock ensures that the calls from several clients are executed + sequentially. + :param targets: A dictionary of objects providing the RPC methods to be exposed to the client. Keys are names identifying each object. Clients select one of these objects using its name upon connection. @@ -442,14 +448,74 @@ class Server(_AsyncioServer): ``terminate`` method that unblocks any tasks waiting on ``wait_terminate``. This is useful to handle server termination requests from clients. + :param allow_parallel: Allow concurrent asyncio calls to the target's + methods. """ - def __init__(self, targets, description=None, builtin_terminate=False): + def __init__(self, targets, description=None, builtin_terminate=False, + allow_parallel=False): _AsyncioServer.__init__(self) self.targets = targets self.description = description self.builtin_terminate = builtin_terminate if builtin_terminate: self._terminate_request = asyncio.Event() + if allow_parallel: + self._noparallel = None + else: + self._noparallel = asyncio.Lock() + + async def _process_action(self, target, obj): + if self._noparallel is not None: + await self._noparallel.acquire() + try: + if obj["action"] == "get_rpc_method_list": + members = inspect.getmembers(target, inspect.ismethod) + doc = { + "docstring": inspect.getdoc(target), + "methods": {} + } + for name, method in members: + if name.startswith("_"): + continue + method = getattr(target, name) + argspec = inspect.getfullargspec(method) + doc["methods"][name] = (dict(argspec._asdict()), + inspect.getdoc(method)) + if self.builtin_terminate: + doc["methods"]["terminate"] = ( + { + "args": ["self"], + "defaults": None, + "varargs": None, + "varkw": None, + "kwonlyargs": [], + "kwonlydefaults": [], + }, + "Terminate the server.") + return {"status": "ok", "ret": doc} + elif obj["action"] == "call": + logger.debug("calling %s", _PrettyPrintCall(obj)) + if (self.builtin_terminate and obj["name"] == + "terminate"): + self._terminate_request.set() + return {"status": "ok", "ret": None} + else: + method = getattr(target, obj["name"]) + ret = method(*obj["args"], **obj["kwargs"]) + if inspect.iscoroutine(ret): + ret = await ret + return {"status": "ok", "ret": ret} + else: + raise ValueError("Unknown action: {}" + .format(obj["action"])) + except asyncio.CancelledError: + raise + except: + return {"status": "failed", + "message": traceback.format_exc()} + finally: + if self._noparallel is not None: + self._noparallel.release() async def _handle_connection_cr(self, reader, writer): try: @@ -476,53 +542,8 @@ class Server(_AsyncioServer): line = await reader.readline() if not line: break - obj = pyon.decode(line.decode()) - try: - if obj["action"] == "get_rpc_method_list": - members = inspect.getmembers(target, inspect.ismethod) - doc = { - "docstring": inspect.getdoc(target), - "methods": {} - } - for name, method in members: - if name.startswith("_"): - continue - method = getattr(target, name) - argspec = inspect.getfullargspec(method) - doc["methods"][name] = (dict(argspec._asdict()), - inspect.getdoc(method)) - if self.builtin_terminate: - doc["methods"]["terminate"] = ( - { - "args": ["self"], - "defaults": None, - "varargs": None, - "varkw": None, - "kwonlyargs": [], - "kwonlydefaults": [], - }, - "Terminate the server.") - obj = {"status": "ok", "ret": doc} - elif obj["action"] == "call": - logger.debug("calling %s", _PrettyPrintCall(obj)) - if (self.builtin_terminate and obj["name"] == - "terminate"): - self._terminate_request.set() - obj = {"status": "ok", "ret": None} - else: - method = getattr(target, obj["name"]) - ret = method(*obj["args"], **obj["kwargs"]) - if inspect.iscoroutine(ret): - ret = await ret - obj = {"status": "ok", "ret": ret} - else: - raise ValueError("Unknown action: {}" - .format(obj["action"])) - except Exception: - obj = {"status": "failed", - "message": traceback.format_exc()} - line = pyon.encode(obj) + "\n" - writer.write(line.encode()) + reply = await self._process_action(target, pyon.decode(line.decode())) + writer.write((pyon.encode(reply) + "\n").encode()) except (ConnectionResetError, ConnectionAbortedError, BrokenPipeError): # May happens on Windows when client disconnects pass