pc_rpc: autotarget support

This commit is contained in:
Sebastien Bourdeauducq 2015-10-18 14:34:30 +08:00
parent 661b9bfbfa
commit 5947f54855
2 changed files with 32 additions and 21 deletions

View File

@ -6,7 +6,7 @@ import sys
import numpy as np # Needed to use numpy in RPC call arguments on cmd line
import pprint
from artiq.protocols.pc_rpc import Client
from artiq.protocols.pc_rpc import AutoTarget, Client
def get_argparser():
@ -85,19 +85,9 @@ def main():
args = get_argparser().parse_args()
remote = Client(args.server, args.port, None)
targets, description = remote.get_rpc_id()
if args.action != "list-targets":
# If no target specified and remote has only one, then use this one.
# Exit otherwise.
if len(targets) > 1 and args.target is None:
print("Remote server has several targets, please supply one with "
"-t")
sys.exit(1)
elif args.target is None:
args.target = targets[0]
remote.select_rpc_target(args.target)
remote.select_rpc_target(AutoTarget)
if args.action == "list-targets":
list_targets(targets, description)

View File

@ -27,6 +27,12 @@ from artiq.protocols.asyncio_server import AsyncioServer as _AsyncioServer
logger = logging.getLogger(__name__)
class AutoTarget:
"""Use this as target value in clients for them to automatically connect
to the target exposed by the server. Servers must have only one target."""
pass
class RemoteError(Exception):
"""Raised when a RPC failed or raised an exception on the remote (server)
side."""
@ -42,6 +48,20 @@ class IncompatibleServer(Exception):
_init_string = b"ARTIQ pc_rpc\n"
def _validate_target_name(target_name, target_names):
if target_name is AutoTarget:
if len(target_names) > 1:
raise ValueError("Server has multiple targets: " +
" ".join(sorted(target_names)))
else:
target_name = target_names[0]
elif target_name not in target_names:
raise IncompatibleServer(
"valid target name(s): " +
" ".join(sorted(target_names)))
return target_name
class Client:
"""This class proxies the methods available on the server so that they
can be used as if they were local methods.
@ -67,11 +87,13 @@ class Client:
:param port: TCP port to use.
:param target_name: Target name to select. ``IncompatibleServer`` is
raised if the target does not exist.
Use ``AutoTarget`` for automatic selection if the server has only one
target.
Use ``None`` to skip selecting a target. The list of targets can then
be retrieved using ``get_rpc_id`` and then one can be selected later
using ``select_rpc_target``.
"""
def __init__(self, host, port, target_name):
def __init__(self, host, port, target_name=AutoTarget):
self.__socket = socket.create_connection((host, port))
try:
@ -89,8 +111,7 @@ class Client:
def select_rpc_target(self, target_name):
"""Selects a RPC target by name. This function should be called
exactly once if the object was created with ``target_name=None``."""
if target_name not in self.__target_names:
raise IncompatibleServer
target_name = _validate_target_name(target_name, self.__target_names)
self.__socket.sendall((target_name + "\n").encode())
def get_rpc_id(self):
@ -180,8 +201,7 @@ class AsyncioClient:
"""Selects a RPC target by name. This function should be called
exactly once if the connection was created with ``target_name=None``.
"""
if target_name not in self.__target_names:
raise IncompatibleServer
target_name = _validate_target_name(target_name, self.__target_names)
self.__writer.write((target_name + "\n").encode())
def get_rpc_id(self):
@ -259,7 +279,8 @@ class BestEffortClient:
except:
logger.warning("first connection attempt to %s:%d[%s] failed, "
"retrying in the background",
self.__host, self.__port, self.__target_name)
self.__host, self.__port, self.__target_name,
exc_info=True)
self.__start_conretry()
else:
self.__conretry_thread = None
@ -273,9 +294,9 @@ class BestEffortClient:
(self.__host, self.__port), timeout)
self.__socket.sendall(_init_string)
server_identification = self.__recv()
if self.__target_name not in server_identification["targets"]:
raise IncompatibleServer
self.__socket.sendall((self.__target_name + "\n").encode())
target_name = _validate_target_name(self.__target_name,
server_identification["targets"])
self.__socket.sendall((target_name + "\n").encode())
def __start_conretry(self):
self.__conretry_thread = threading.Thread(target=self.__conretry)