forked from M-Labs/artiq
pc_rpc: asyncio client
This commit is contained in:
parent
650baa9fc1
commit
f9dd5682ee
|
@ -136,6 +136,102 @@ class Client:
|
|||
return proxy
|
||||
|
||||
|
||||
class AsyncioClient:
|
||||
"""This class is similar to :class:`artiq.management.pc_rpc.Client`, but
|
||||
uses ``asyncio`` instead of blocking calls.
|
||||
|
||||
All RPC methods are coroutines.
|
||||
|
||||
"""
|
||||
def __init__(self):
|
||||
self.__lock = asyncio.Lock()
|
||||
self.__reader = None
|
||||
self.__writer = None
|
||||
self.__target_names = None
|
||||
self.__id_parameters = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def connect_rpc(self, host, port, target_name):
|
||||
"""Connects to the server. This cannot be done in __init__ because
|
||||
this method is a coroutine. See ``Client`` for a description of the
|
||||
parameters.
|
||||
|
||||
"""
|
||||
self.__reader, self.__writer = \
|
||||
yield from asyncio.open_connection(host, port)
|
||||
try:
|
||||
self.__writer.write(_init_string)
|
||||
server_identification = yield from self.__recv()
|
||||
self.__target_names = server_identification["targets"]
|
||||
self.__id_parameters = server_identification["parameters"]
|
||||
if target_name is not None:
|
||||
self.select_rpc_target(target_name)
|
||||
except:
|
||||
self.close_rpc()
|
||||
raise
|
||||
|
||||
def select_rpc_target(self, target_name):
|
||||
"""Selects a RPC target by name. This function should be called
|
||||
exactly once if the connection was created with ``target_name=None``.
|
||||
|
||||
"""
|
||||
if target_name not in self.__target_names:
|
||||
raise IncompatibleServer
|
||||
self.__writer.write((target_name + "\n").encode())
|
||||
|
||||
def get_rpc_id(self):
|
||||
"""Returns a tuple (target_names, id_parameters) containing the
|
||||
identification information of the server.
|
||||
|
||||
"""
|
||||
return (self.__target_names, self.__id_parameters)
|
||||
|
||||
def close_rpc(self):
|
||||
"""Closes the connection to the RPC server.
|
||||
|
||||
No further method calls should be done after this method is called.
|
||||
|
||||
"""
|
||||
self.__writer.close()
|
||||
self.__reader = None
|
||||
self.__writer = None
|
||||
self.__target_names = None
|
||||
self.__id_parameters = None
|
||||
|
||||
def __send(self, obj):
|
||||
line = pyon.encode(obj) + "\n"
|
||||
self.__writer.write(line.encode())
|
||||
|
||||
@asyncio.coroutine
|
||||
def __recv(self):
|
||||
line = yield from self.__reader.readline()
|
||||
return pyon.decode(line.decode())
|
||||
|
||||
@asyncio.coroutine
|
||||
def __do_rpc(self, name, args, kwargs):
|
||||
yield from self.__lock.acquire()
|
||||
try:
|
||||
obj = {"action": "call", "name": name,
|
||||
"args": args, "kwargs": kwargs}
|
||||
self.__send(obj)
|
||||
|
||||
obj = yield from self.__recv()
|
||||
if obj["status"] == "ok":
|
||||
return obj["ret"]
|
||||
elif obj["status"] == "failed":
|
||||
raise RemoteError(obj["message"])
|
||||
else:
|
||||
raise ValueError
|
||||
finally:
|
||||
self.__lock.release()
|
||||
|
||||
def __getattr__(self, name):
|
||||
@asyncio.coroutine
|
||||
def proxy(*args, **kwargs):
|
||||
return self.__do_rpc(name, args, kwargs)
|
||||
return proxy
|
||||
|
||||
|
||||
class Server(AsyncioServer):
|
||||
"""This class creates a TCP server that handles requests coming from
|
||||
``Client`` objects.
|
||||
|
|
|
@ -11,17 +11,26 @@ from artiq.management import pc_rpc
|
|||
|
||||
test_address = "::1"
|
||||
test_port = 7777
|
||||
test_object = [5, 2.1, None, True, False,
|
||||
{"a": 5, 2: np.linspace(0, 10, 1)},
|
||||
(4, 5), (10,), "ab\nx\"'"]
|
||||
|
||||
|
||||
class RPCCase(unittest.TestCase):
|
||||
def test_echo(self):
|
||||
def _run_server_and_test(self, test):
|
||||
# running this file outside of unittest starts the echo server
|
||||
with subprocess.Popen([sys.executable,
|
||||
sys.modules[__name__].__file__]) as proc:
|
||||
try:
|
||||
test_object = [5, 2.1, None, True, False,
|
||||
{"a": 5, 2: np.linspace(0, 10, 1)},
|
||||
(4, 5), (10,), "ab\nx\"'"]
|
||||
test()
|
||||
finally:
|
||||
try:
|
||||
proc.wait(timeout=1)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
raise
|
||||
|
||||
def _blocking_echo(self):
|
||||
for attempt in range(100):
|
||||
time.sleep(.2)
|
||||
try:
|
||||
|
@ -33,18 +42,45 @@ class RPCCase(unittest.TestCase):
|
|||
break
|
||||
try:
|
||||
test_object_back = remote.echo(test_object)
|
||||
self.assertEqual(test_object, test_object_back)
|
||||
with self.assertRaises(pc_rpc.RemoteError):
|
||||
remote.non_existing_method()
|
||||
remote.quit()
|
||||
finally:
|
||||
remote.close_rpc()
|
||||
finally:
|
||||
|
||||
def test_blocking_echo(self):
|
||||
self._run_server_and_test(self._blocking_echo)
|
||||
|
||||
@asyncio.coroutine
|
||||
def _asyncio_echo(self):
|
||||
remote = pc_rpc.AsyncioClient()
|
||||
for attempt in range(100):
|
||||
yield from asyncio.sleep(.2)
|
||||
try:
|
||||
proc.wait(timeout=1)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
raise
|
||||
yield from remote.connect_rpc(test_address, test_port, "test")
|
||||
except ConnectionRefusedError:
|
||||
pass
|
||||
else:
|
||||
break
|
||||
try:
|
||||
test_object_back = yield from remote.echo(test_object)
|
||||
self.assertEqual(test_object, test_object_back)
|
||||
with self.assertRaises(pc_rpc.RemoteError):
|
||||
yield from remote.non_existing_method()
|
||||
yield from remote.quit()
|
||||
finally:
|
||||
remote.close_rpc()
|
||||
|
||||
def _loop_asyncio_echo(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
loop.run_until_complete(self._asyncio_echo())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
def test_asyncio_echo(self):
|
||||
self._run_server_and_test(self._loop_asyncio_echo)
|
||||
|
||||
|
||||
class Echo:
|
||||
|
|
Loading…
Reference in New Issue