pc_rpc: autotarget support

This commit is contained in:
Sebastien Bourdeauducq 2015-10-18 14:34:30 +08:00
parent 661b9bfbfa
commit 5947f54855
2 changed files with 32 additions and 21 deletions

View File

@ -6,7 +6,7 @@ import sys
import numpy as np # Needed to use numpy in RPC call arguments on cmd line import numpy as np # Needed to use numpy in RPC call arguments on cmd line
import pprint import pprint
from artiq.protocols.pc_rpc import Client from artiq.protocols.pc_rpc import AutoTarget, Client
def get_argparser(): def get_argparser():
@ -85,19 +85,9 @@ def main():
args = get_argparser().parse_args() args = get_argparser().parse_args()
remote = Client(args.server, args.port, None) remote = Client(args.server, args.port, None)
targets, description = remote.get_rpc_id() targets, description = remote.get_rpc_id()
if args.action != "list-targets": if args.action != "list-targets":
# If no target specified and remote has only one, then use this one. remote.select_rpc_target(AutoTarget)
# Exit otherwise.
if len(targets) > 1 and args.target is None:
print("Remote server has several targets, please supply one with "
"-t")
sys.exit(1)
elif args.target is None:
args.target = targets[0]
remote.select_rpc_target(args.target)
if args.action == "list-targets": if args.action == "list-targets":
list_targets(targets, description) list_targets(targets, description)

View File

@ -27,6 +27,12 @@ from artiq.protocols.asyncio_server import AsyncioServer as _AsyncioServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AutoTarget:
"""Use this as target value in clients for them to automatically connect
to the target exposed by the server. Servers must have only one target."""
pass
class RemoteError(Exception): class RemoteError(Exception):
"""Raised when a RPC failed or raised an exception on the remote (server) """Raised when a RPC failed or raised an exception on the remote (server)
side.""" side."""
@ -42,6 +48,20 @@ class IncompatibleServer(Exception):
_init_string = b"ARTIQ pc_rpc\n" _init_string = b"ARTIQ pc_rpc\n"
def _validate_target_name(target_name, target_names):
if target_name is AutoTarget:
if len(target_names) > 1:
raise ValueError("Server has multiple targets: " +
" ".join(sorted(target_names)))
else:
target_name = target_names[0]
elif target_name not in target_names:
raise IncompatibleServer(
"valid target name(s): " +
" ".join(sorted(target_names)))
return target_name
class Client: class Client:
"""This class proxies the methods available on the server so that they """This class proxies the methods available on the server so that they
can be used as if they were local methods. can be used as if they were local methods.
@ -67,11 +87,13 @@ class Client:
:param port: TCP port to use. :param port: TCP port to use.
:param target_name: Target name to select. ``IncompatibleServer`` is :param target_name: Target name to select. ``IncompatibleServer`` is
raised if the target does not exist. raised if the target does not exist.
Use ``AutoTarget`` for automatic selection if the server has only one
target.
Use ``None`` to skip selecting a target. The list of targets can then Use ``None`` to skip selecting a target. The list of targets can then
be retrieved using ``get_rpc_id`` and then one can be selected later be retrieved using ``get_rpc_id`` and then one can be selected later
using ``select_rpc_target``. using ``select_rpc_target``.
""" """
def __init__(self, host, port, target_name): def __init__(self, host, port, target_name=AutoTarget):
self.__socket = socket.create_connection((host, port)) self.__socket = socket.create_connection((host, port))
try: try:
@ -89,8 +111,7 @@ class Client:
def select_rpc_target(self, target_name): def select_rpc_target(self, target_name):
"""Selects a RPC target by name. This function should be called """Selects a RPC target by name. This function should be called
exactly once if the object was created with ``target_name=None``.""" exactly once if the object was created with ``target_name=None``."""
if target_name not in self.__target_names: target_name = _validate_target_name(target_name, self.__target_names)
raise IncompatibleServer
self.__socket.sendall((target_name + "\n").encode()) self.__socket.sendall((target_name + "\n").encode())
def get_rpc_id(self): def get_rpc_id(self):
@ -180,8 +201,7 @@ class AsyncioClient:
"""Selects a RPC target by name. This function should be called """Selects a RPC target by name. This function should be called
exactly once if the connection was created with ``target_name=None``. exactly once if the connection was created with ``target_name=None``.
""" """
if target_name not in self.__target_names: target_name = _validate_target_name(target_name, self.__target_names)
raise IncompatibleServer
self.__writer.write((target_name + "\n").encode()) self.__writer.write((target_name + "\n").encode())
def get_rpc_id(self): def get_rpc_id(self):
@ -259,7 +279,8 @@ class BestEffortClient:
except: except:
logger.warning("first connection attempt to %s:%d[%s] failed, " logger.warning("first connection attempt to %s:%d[%s] failed, "
"retrying in the background", "retrying in the background",
self.__host, self.__port, self.__target_name) self.__host, self.__port, self.__target_name,
exc_info=True)
self.__start_conretry() self.__start_conretry()
else: else:
self.__conretry_thread = None self.__conretry_thread = None
@ -273,9 +294,9 @@ class BestEffortClient:
(self.__host, self.__port), timeout) (self.__host, self.__port), timeout)
self.__socket.sendall(_init_string) self.__socket.sendall(_init_string)
server_identification = self.__recv() server_identification = self.__recv()
if self.__target_name not in server_identification["targets"]: target_name = _validate_target_name(self.__target_name,
raise IncompatibleServer server_identification["targets"])
self.__socket.sendall((self.__target_name + "\n").encode()) self.__socket.sendall((target_name + "\n").encode())
def __start_conretry(self): def __start_conretry(self):
self.__conretry_thread = threading.Thread(target=self.__conretry) self.__conretry_thread = threading.Thread(target=self.__conretry)