protocols/pc_rpc: document coroutine methods, support locking

This commit is contained in:
Sebastien Bourdeauducq 2016-03-22 21:54:29 +08:00
parent 1c9b8a1d52
commit 2cbe47e26f
2 changed files with 70 additions and 49 deletions

View File

@ -95,7 +95,7 @@ def main():
"master_dataset_db": dataset_db, "master_dataset_db": dataset_db,
"master_schedule": scheduler, "master_schedule": scheduler,
"master_experiment_db": experiment_db "master_experiment_db": experiment_db
}) }, allow_parallel=True)
loop.run_until_complete(server_control.start( loop.run_until_complete(server_control.start(
bind, args.port_control)) bind, args.port_control))
atexit_register_coroutine(server_control.stop) atexit_register_coroutine(server_control.stop)

View File

@ -433,6 +433,12 @@ class Server(_AsyncioServer):
simple cases: it allows new connections to be be accepted even when the simple cases: it allows new connections to be be accepted even when the
previous client failed to properly shut down its connection. previous client failed to properly shut down its connection.
If a target method is a coroutine, it is awaited and its return value
is sent to the RPC client. If ``allow_parallel`` is true, multiple
target coroutines may be executed in parallel (one per RPC client),
otherwise a lock ensures that the calls from several clients are executed
sequentially.
:param targets: A dictionary of objects providing the RPC methods to be :param targets: A dictionary of objects providing the RPC methods to be
exposed to the client. Keys are names identifying each object. exposed to the client. Keys are names identifying each object.
Clients select one of these objects using its name upon connection. Clients select one of these objects using its name upon connection.
@ -442,14 +448,74 @@ class Server(_AsyncioServer):
``terminate`` method that unblocks any tasks waiting on ``terminate`` method that unblocks any tasks waiting on
``wait_terminate``. This is useful to handle server termination ``wait_terminate``. This is useful to handle server termination
requests from clients. requests from clients.
:param allow_parallel: Allow concurrent asyncio calls to the target's
methods.
""" """
def __init__(self, targets, description=None, builtin_terminate=False): def __init__(self, targets, description=None, builtin_terminate=False,
allow_parallel=False):
_AsyncioServer.__init__(self) _AsyncioServer.__init__(self)
self.targets = targets self.targets = targets
self.description = description self.description = description
self.builtin_terminate = builtin_terminate self.builtin_terminate = builtin_terminate
if builtin_terminate: if builtin_terminate:
self._terminate_request = asyncio.Event() self._terminate_request = asyncio.Event()
if allow_parallel:
self._noparallel = None
else:
self._noparallel = asyncio.Lock()
async def _process_action(self, target, obj):
if self._noparallel is not None:
await self._noparallel.acquire()
try:
if obj["action"] == "get_rpc_method_list":
members = inspect.getmembers(target, inspect.ismethod)
doc = {
"docstring": inspect.getdoc(target),
"methods": {}
}
for name, method in members:
if name.startswith("_"):
continue
method = getattr(target, name)
argspec = inspect.getfullargspec(method)
doc["methods"][name] = (dict(argspec._asdict()),
inspect.getdoc(method))
if self.builtin_terminate:
doc["methods"]["terminate"] = (
{
"args": ["self"],
"defaults": None,
"varargs": None,
"varkw": None,
"kwonlyargs": [],
"kwonlydefaults": [],
},
"Terminate the server.")
return {"status": "ok", "ret": doc}
elif obj["action"] == "call":
logger.debug("calling %s", _PrettyPrintCall(obj))
if (self.builtin_terminate and obj["name"] ==
"terminate"):
self._terminate_request.set()
return {"status": "ok", "ret": None}
else:
method = getattr(target, obj["name"])
ret = method(*obj["args"], **obj["kwargs"])
if inspect.iscoroutine(ret):
ret = await ret
return {"status": "ok", "ret": ret}
else:
raise ValueError("Unknown action: {}"
.format(obj["action"]))
except asyncio.CancelledError:
raise
except:
return {"status": "failed",
"message": traceback.format_exc()}
finally:
if self._noparallel is not None:
self._noparallel.release()
async def _handle_connection_cr(self, reader, writer): async def _handle_connection_cr(self, reader, writer):
try: try:
@ -476,53 +542,8 @@ class Server(_AsyncioServer):
line = await reader.readline() line = await reader.readline()
if not line: if not line:
break break
obj = pyon.decode(line.decode()) reply = await self._process_action(target, pyon.decode(line.decode()))
try: writer.write((pyon.encode(reply) + "\n").encode())
if obj["action"] == "get_rpc_method_list":
members = inspect.getmembers(target, inspect.ismethod)
doc = {
"docstring": inspect.getdoc(target),
"methods": {}
}
for name, method in members:
if name.startswith("_"):
continue
method = getattr(target, name)
argspec = inspect.getfullargspec(method)
doc["methods"][name] = (dict(argspec._asdict()),
inspect.getdoc(method))
if self.builtin_terminate:
doc["methods"]["terminate"] = (
{
"args": ["self"],
"defaults": None,
"varargs": None,
"varkw": None,
"kwonlyargs": [],
"kwonlydefaults": [],
},
"Terminate the server.")
obj = {"status": "ok", "ret": doc}
elif obj["action"] == "call":
logger.debug("calling %s", _PrettyPrintCall(obj))
if (self.builtin_terminate and obj["name"] ==
"terminate"):
self._terminate_request.set()
obj = {"status": "ok", "ret": None}
else:
method = getattr(target, obj["name"])
ret = method(*obj["args"], **obj["kwargs"])
if inspect.iscoroutine(ret):
ret = await ret
obj = {"status": "ok", "ret": ret}
else:
raise ValueError("Unknown action: {}"
.format(obj["action"]))
except Exception:
obj = {"status": "failed",
"message": traceback.format_exc()}
line = pyon.encode(obj) + "\n"
writer.write(line.encode())
except (ConnectionResetError, ConnectionAbortedError, BrokenPipeError): except (ConnectionResetError, ConnectionAbortedError, BrokenPipeError):
# May happens on Windows when client disconnects # May happens on Windows when client disconnects
pass pass