protocols/pc_rpc: document coroutine methods, support locking

This commit is contained in:
Sebastien Bourdeauducq 2016-03-22 21:54:29 +08:00
parent 1c9b8a1d52
commit 2cbe47e26f
2 changed files with 70 additions and 49 deletions

View File

@ -95,7 +95,7 @@ def main():
"master_dataset_db": dataset_db,
"master_schedule": scheduler,
"master_experiment_db": experiment_db
})
}, allow_parallel=True)
loop.run_until_complete(server_control.start(
bind, args.port_control))
atexit_register_coroutine(server_control.stop)

View File

@ -433,6 +433,12 @@ class Server(_AsyncioServer):
simple cases: it allows new connections to be be accepted even when the
previous client failed to properly shut down its connection.
If a target method is a coroutine, it is awaited and its return value
is sent to the RPC client. If ``allow_parallel`` is true, multiple
target coroutines may be executed in parallel (one per RPC client),
otherwise a lock ensures that the calls from several clients are executed
sequentially.
:param targets: A dictionary of objects providing the RPC methods to be
exposed to the client. Keys are names identifying each object.
Clients select one of these objects using its name upon connection.
@ -442,14 +448,74 @@ class Server(_AsyncioServer):
``terminate`` method that unblocks any tasks waiting on
``wait_terminate``. This is useful to handle server termination
requests from clients.
:param allow_parallel: Allow concurrent asyncio calls to the target's
methods.
"""
def __init__(self, targets, description=None, builtin_terminate=False):
def __init__(self, targets, description=None, builtin_terminate=False,
allow_parallel=False):
_AsyncioServer.__init__(self)
self.targets = targets
self.description = description
self.builtin_terminate = builtin_terminate
if builtin_terminate:
self._terminate_request = asyncio.Event()
if allow_parallel:
self._noparallel = None
else:
self._noparallel = asyncio.Lock()
async def _process_action(self, target, obj):
if self._noparallel is not None:
await self._noparallel.acquire()
try:
if obj["action"] == "get_rpc_method_list":
members = inspect.getmembers(target, inspect.ismethod)
doc = {
"docstring": inspect.getdoc(target),
"methods": {}
}
for name, method in members:
if name.startswith("_"):
continue
method = getattr(target, name)
argspec = inspect.getfullargspec(method)
doc["methods"][name] = (dict(argspec._asdict()),
inspect.getdoc(method))
if self.builtin_terminate:
doc["methods"]["terminate"] = (
{
"args": ["self"],
"defaults": None,
"varargs": None,
"varkw": None,
"kwonlyargs": [],
"kwonlydefaults": [],
},
"Terminate the server.")
return {"status": "ok", "ret": doc}
elif obj["action"] == "call":
logger.debug("calling %s", _PrettyPrintCall(obj))
if (self.builtin_terminate and obj["name"] ==
"terminate"):
self._terminate_request.set()
return {"status": "ok", "ret": None}
else:
method = getattr(target, obj["name"])
ret = method(*obj["args"], **obj["kwargs"])
if inspect.iscoroutine(ret):
ret = await ret
return {"status": "ok", "ret": ret}
else:
raise ValueError("Unknown action: {}"
.format(obj["action"]))
except asyncio.CancelledError:
raise
except:
return {"status": "failed",
"message": traceback.format_exc()}
finally:
if self._noparallel is not None:
self._noparallel.release()
async def _handle_connection_cr(self, reader, writer):
try:
@ -476,53 +542,8 @@ class Server(_AsyncioServer):
line = await reader.readline()
if not line:
break
obj = pyon.decode(line.decode())
try:
if obj["action"] == "get_rpc_method_list":
members = inspect.getmembers(target, inspect.ismethod)
doc = {
"docstring": inspect.getdoc(target),
"methods": {}
}
for name, method in members:
if name.startswith("_"):
continue
method = getattr(target, name)
argspec = inspect.getfullargspec(method)
doc["methods"][name] = (dict(argspec._asdict()),
inspect.getdoc(method))
if self.builtin_terminate:
doc["methods"]["terminate"] = (
{
"args": ["self"],
"defaults": None,
"varargs": None,
"varkw": None,
"kwonlyargs": [],
"kwonlydefaults": [],
},
"Terminate the server.")
obj = {"status": "ok", "ret": doc}
elif obj["action"] == "call":
logger.debug("calling %s", _PrettyPrintCall(obj))
if (self.builtin_terminate and obj["name"] ==
"terminate"):
self._terminate_request.set()
obj = {"status": "ok", "ret": None}
else:
method = getattr(target, obj["name"])
ret = method(*obj["args"], **obj["kwargs"])
if inspect.iscoroutine(ret):
ret = await ret
obj = {"status": "ok", "ret": ret}
else:
raise ValueError("Unknown action: {}"
.format(obj["action"]))
except Exception:
obj = {"status": "failed",
"message": traceback.format_exc()}
line = pyon.encode(obj) + "\n"
writer.write(line.encode())
reply = await self._process_action(target, pyon.decode(line.decode()))
writer.write((pyon.encode(reply) + "\n").encode())
except (ConnectionResetError, ConnectionAbortedError, BrokenPipeError):
# May happens on Windows when client disconnects
pass