forked from M-Labs/artiq
pc_rpc: asyncio client
This commit is contained in:
parent
650baa9fc1
commit
f9dd5682ee
|
@ -136,6 +136,102 @@ class Client:
|
||||||
return proxy
|
return proxy
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncioClient:
|
||||||
|
"""This class is similar to :class:`artiq.management.pc_rpc.Client`, but
|
||||||
|
uses ``asyncio`` instead of blocking calls.
|
||||||
|
|
||||||
|
All RPC methods are coroutines.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.__lock = asyncio.Lock()
|
||||||
|
self.__reader = None
|
||||||
|
self.__writer = None
|
||||||
|
self.__target_names = None
|
||||||
|
self.__id_parameters = None
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def connect_rpc(self, host, port, target_name):
|
||||||
|
"""Connects to the server. This cannot be done in __init__ because
|
||||||
|
this method is a coroutine. See ``Client`` for a description of the
|
||||||
|
parameters.
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.__reader, self.__writer = \
|
||||||
|
yield from asyncio.open_connection(host, port)
|
||||||
|
try:
|
||||||
|
self.__writer.write(_init_string)
|
||||||
|
server_identification = yield from self.__recv()
|
||||||
|
self.__target_names = server_identification["targets"]
|
||||||
|
self.__id_parameters = server_identification["parameters"]
|
||||||
|
if target_name is not None:
|
||||||
|
self.select_rpc_target(target_name)
|
||||||
|
except:
|
||||||
|
self.close_rpc()
|
||||||
|
raise
|
||||||
|
|
||||||
|
def select_rpc_target(self, target_name):
|
||||||
|
"""Selects a RPC target by name. This function should be called
|
||||||
|
exactly once if the connection was created with ``target_name=None``.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if target_name not in self.__target_names:
|
||||||
|
raise IncompatibleServer
|
||||||
|
self.__writer.write((target_name + "\n").encode())
|
||||||
|
|
||||||
|
def get_rpc_id(self):
|
||||||
|
"""Returns a tuple (target_names, id_parameters) containing the
|
||||||
|
identification information of the server.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return (self.__target_names, self.__id_parameters)
|
||||||
|
|
||||||
|
def close_rpc(self):
|
||||||
|
"""Closes the connection to the RPC server.
|
||||||
|
|
||||||
|
No further method calls should be done after this method is called.
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.__writer.close()
|
||||||
|
self.__reader = None
|
||||||
|
self.__writer = None
|
||||||
|
self.__target_names = None
|
||||||
|
self.__id_parameters = None
|
||||||
|
|
||||||
|
def __send(self, obj):
|
||||||
|
line = pyon.encode(obj) + "\n"
|
||||||
|
self.__writer.write(line.encode())
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def __recv(self):
|
||||||
|
line = yield from self.__reader.readline()
|
||||||
|
return pyon.decode(line.decode())
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def __do_rpc(self, name, args, kwargs):
|
||||||
|
yield from self.__lock.acquire()
|
||||||
|
try:
|
||||||
|
obj = {"action": "call", "name": name,
|
||||||
|
"args": args, "kwargs": kwargs}
|
||||||
|
self.__send(obj)
|
||||||
|
|
||||||
|
obj = yield from self.__recv()
|
||||||
|
if obj["status"] == "ok":
|
||||||
|
return obj["ret"]
|
||||||
|
elif obj["status"] == "failed":
|
||||||
|
raise RemoteError(obj["message"])
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
@asyncio.coroutine
|
||||||
|
def proxy(*args, **kwargs):
|
||||||
|
return self.__do_rpc(name, args, kwargs)
|
||||||
|
return proxy
|
||||||
|
|
||||||
|
|
||||||
class Server(AsyncioServer):
|
class Server(AsyncioServer):
|
||||||
"""This class creates a TCP server that handles requests coming from
|
"""This class creates a TCP server that handles requests coming from
|
||||||
``Client`` objects.
|
``Client`` objects.
|
||||||
|
|
|
@ -11,17 +11,26 @@ from artiq.management import pc_rpc
|
||||||
|
|
||||||
test_address = "::1"
|
test_address = "::1"
|
||||||
test_port = 7777
|
test_port = 7777
|
||||||
|
test_object = [5, 2.1, None, True, False,
|
||||||
|
{"a": 5, 2: np.linspace(0, 10, 1)},
|
||||||
|
(4, 5), (10,), "ab\nx\"'"]
|
||||||
|
|
||||||
|
|
||||||
class RPCCase(unittest.TestCase):
|
class RPCCase(unittest.TestCase):
|
||||||
def test_echo(self):
|
def _run_server_and_test(self, test):
|
||||||
# running this file outside of unittest starts the echo server
|
# running this file outside of unittest starts the echo server
|
||||||
with subprocess.Popen([sys.executable,
|
with subprocess.Popen([sys.executable,
|
||||||
sys.modules[__name__].__file__]) as proc:
|
sys.modules[__name__].__file__]) as proc:
|
||||||
try:
|
try:
|
||||||
test_object = [5, 2.1, None, True, False,
|
test()
|
||||||
{"a": 5, 2: np.linspace(0, 10, 1)},
|
finally:
|
||||||
(4, 5), (10,), "ab\nx\"'"]
|
try:
|
||||||
|
proc.wait(timeout=1)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
proc.kill()
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _blocking_echo(self):
|
||||||
for attempt in range(100):
|
for attempt in range(100):
|
||||||
time.sleep(.2)
|
time.sleep(.2)
|
||||||
try:
|
try:
|
||||||
|
@ -33,18 +42,45 @@ class RPCCase(unittest.TestCase):
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
test_object_back = remote.echo(test_object)
|
test_object_back = remote.echo(test_object)
|
||||||
|
self.assertEqual(test_object, test_object_back)
|
||||||
with self.assertRaises(pc_rpc.RemoteError):
|
with self.assertRaises(pc_rpc.RemoteError):
|
||||||
remote.non_existing_method()
|
remote.non_existing_method()
|
||||||
remote.quit()
|
remote.quit()
|
||||||
finally:
|
finally:
|
||||||
remote.close_rpc()
|
remote.close_rpc()
|
||||||
finally:
|
|
||||||
|
def test_blocking_echo(self):
|
||||||
|
self._run_server_and_test(self._blocking_echo)
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def _asyncio_echo(self):
|
||||||
|
remote = pc_rpc.AsyncioClient()
|
||||||
|
for attempt in range(100):
|
||||||
|
yield from asyncio.sleep(.2)
|
||||||
try:
|
try:
|
||||||
proc.wait(timeout=1)
|
yield from remote.connect_rpc(test_address, test_port, "test")
|
||||||
except subprocess.TimeoutExpired:
|
except ConnectionRefusedError:
|
||||||
proc.kill()
|
pass
|
||||||
raise
|
else:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
test_object_back = yield from remote.echo(test_object)
|
||||||
self.assertEqual(test_object, test_object_back)
|
self.assertEqual(test_object, test_object_back)
|
||||||
|
with self.assertRaises(pc_rpc.RemoteError):
|
||||||
|
yield from remote.non_existing_method()
|
||||||
|
yield from remote.quit()
|
||||||
|
finally:
|
||||||
|
remote.close_rpc()
|
||||||
|
|
||||||
|
def _loop_asyncio_echo(self):
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(self._asyncio_echo())
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
def test_asyncio_echo(self):
|
||||||
|
self._run_server_and_test(self._loop_asyncio_echo)
|
||||||
|
|
||||||
|
|
||||||
class Echo:
|
class Echo:
|
||||||
|
|
Loading…
Reference in New Issue