forked from M-Labs/artiq
pc_rpc: raise AttributeError immediately for nonexistent RPC methods. Closes #534
This commit is contained in:
parent
f010a74479
commit
c7c8ad126f
@ -10,6 +10,9 @@ Release notes
|
||||
ARTIQ_APPLET_EMBED. The GUI sets this enviroment variable itself and the
|
||||
user simply needs to remove the --embed argument.
|
||||
* EnvExperiment's prepare calls prepare for all its children.
|
||||
* Dynamic __getattr__'s returning RPC target methods are not supported anymore.
|
||||
Controller driver classes must define all their methods intended for RPC as
|
||||
members.
|
||||
|
||||
|
||||
2.0rc1
|
||||
|
@ -41,7 +41,7 @@ class Controller:
|
||||
await remote.connect_rpc(self.host, self.port, None)
|
||||
try:
|
||||
targets, _ = remote.get_rpc_id()
|
||||
remote.select_rpc_target(targets[0])
|
||||
await remote.select_rpc_target(targets[0])
|
||||
r = await getattr(remote, method)(*args, **kwargs)
|
||||
finally:
|
||||
remote.close_rpc()
|
||||
|
@ -1,5 +1,6 @@
|
||||
import threading
|
||||
import logging
|
||||
import inspect
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -20,6 +21,9 @@ class FFProxy:
|
||||
"""
|
||||
def __init__(self, target):
|
||||
self.target = target
|
||||
|
||||
valid_methods = inspect.getmembers(target, inspect.ismethod)
|
||||
self._valid_methods = {m[0] for m in valid_methods}
|
||||
self._thread = None
|
||||
|
||||
def ff_join(self):
|
||||
@ -28,6 +32,8 @@ class FFProxy:
|
||||
self._thread.join()
|
||||
|
||||
def __getattr__(self, k):
|
||||
if k not in self._valid_methods:
|
||||
raise AttributeError
|
||||
def run_in_thread(*args, **kwargs):
|
||||
if self._thread is not None and self._thread.is_alive():
|
||||
logger.warning("skipping fire-and-forget call to %r.%s as "
|
||||
|
@ -94,8 +94,6 @@ class Client:
|
||||
in the middle of a RPC can break subsequent RPCs (from the same
|
||||
client).
|
||||
"""
|
||||
kernel_invariants = set()
|
||||
|
||||
def __init__(self, host, port, target_name=AutoTarget, timeout=None):
|
||||
self.__socket = socket.create_connection((host, port), timeout)
|
||||
|
||||
@ -106,6 +104,7 @@ class Client:
|
||||
self.__target_names = server_identification["targets"]
|
||||
self.__description = server_identification["description"]
|
||||
self.__selected_target = None
|
||||
self.__valid_methods = set()
|
||||
if target_name is not None:
|
||||
self.select_rpc_target(target_name)
|
||||
except:
|
||||
@ -118,6 +117,7 @@ class Client:
|
||||
target_name = _validate_target_name(target_name, self.__target_names)
|
||||
self.__socket.sendall((target_name + "\n").encode())
|
||||
self.__selected_target = target_name
|
||||
self.__valid_methods = self.__recv()
|
||||
|
||||
def get_selected_target(self):
|
||||
"""Returns the selected target, or ``None`` if no target has been
|
||||
@ -173,6 +173,8 @@ class Client:
|
||||
return self.__do_action(obj)
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name not in self.__valid_methods:
|
||||
raise AttributeError
|
||||
def proxy(*args, **kwargs):
|
||||
return self.__do_rpc(name, args, kwargs)
|
||||
return proxy
|
||||
@ -187,8 +189,6 @@ class AsyncioClient:
|
||||
Concurrent access from different asyncio tasks is supported; all calls
|
||||
use a single lock.
|
||||
"""
|
||||
kernel_invariants = set()
|
||||
|
||||
def __init__(self):
|
||||
self.__lock = asyncio.Lock()
|
||||
self.__reader = None
|
||||
@ -208,19 +208,21 @@ class AsyncioClient:
|
||||
self.__target_names = server_identification["targets"]
|
||||
self.__description = server_identification["description"]
|
||||
self.__selected_target = None
|
||||
self.__valid_methods = set()
|
||||
if target_name is not None:
|
||||
self.select_rpc_target(target_name)
|
||||
await self.select_rpc_target(target_name)
|
||||
except:
|
||||
self.close_rpc()
|
||||
raise
|
||||
|
||||
def select_rpc_target(self, target_name):
|
||||
async def select_rpc_target(self, target_name):
|
||||
"""Selects a RPC target by name. This function should be called
|
||||
exactly once if the connection was created with ``target_name=None``.
|
||||
"""
|
||||
target_name = _validate_target_name(target_name, self.__target_names)
|
||||
self.__writer.write((target_name + "\n").encode())
|
||||
self.__selected_target = target_name
|
||||
self.__valid_methods = await self.__recv()
|
||||
|
||||
def get_selected_target(self):
|
||||
"""Returns the selected target, or ``None`` if no target has been
|
||||
@ -273,6 +275,8 @@ class AsyncioClient:
|
||||
self.__lock.release()
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name not in self.__valid_methods:
|
||||
raise AttributeError
|
||||
async def proxy(*args, **kwargs):
|
||||
res = await self.__do_rpc(name, args, kwargs)
|
||||
return res
|
||||
@ -292,8 +296,6 @@ class BestEffortClient:
|
||||
:param retry: Amount of time to wait between retries when reconnecting
|
||||
in the background.
|
||||
"""
|
||||
kernel_invariants = set()
|
||||
|
||||
def __init__(self, host, port, target_name,
|
||||
firstcon_timeout=1.0, retry=5.0):
|
||||
self.__host = host
|
||||
@ -303,6 +305,7 @@ class BestEffortClient:
|
||||
|
||||
self.__conretry_terminate = False
|
||||
self.__socket = None
|
||||
self.__valid_methods = set()
|
||||
try:
|
||||
self.__coninit(firstcon_timeout)
|
||||
except:
|
||||
@ -327,6 +330,7 @@ class BestEffortClient:
|
||||
target_name = _validate_target_name(self.__target_name,
|
||||
server_identification["targets"])
|
||||
self.__socket.sendall((target_name + "\n").encode())
|
||||
self.__valid_methods = self.__recv()
|
||||
|
||||
def __start_conretry(self):
|
||||
self.__conretry_thread = threading.Thread(target=self.__conretry)
|
||||
@ -401,6 +405,8 @@ class BestEffortClient:
|
||||
raise ValueError
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name not in self.__valid_methods:
|
||||
raise AttributeError
|
||||
def proxy(*args, **kwargs):
|
||||
return self.__do_rpc(name, args, kwargs)
|
||||
return proxy
|
||||
@ -558,6 +564,12 @@ class Server(_AsyncioServer):
|
||||
if callable(target):
|
||||
target = target()
|
||||
|
||||
valid_methods = inspect.getmembers(target, inspect.ismethod)
|
||||
valid_methods = {m[0] for m in valid_methods}
|
||||
if self.builtin_terminate:
|
||||
valid_methods.add("terminate")
|
||||
writer.write((pyon.encode(valid_methods) + "\n").encode())
|
||||
|
||||
while True:
|
||||
line = await reader.readline()
|
||||
if not line:
|
||||
|
@ -33,7 +33,7 @@ class ControllerCase(unittest.TestCase):
|
||||
remote = AsyncioClient()
|
||||
await remote.connect_rpc(host, port, None)
|
||||
targets, _ = remote.get_rpc_id()
|
||||
remote.select_rpc_target(targets[0])
|
||||
await remote.select_rpc_target(targets[0])
|
||||
self.addCleanup(remote.close_rpc)
|
||||
return remote
|
||||
|
||||
|
@ -46,7 +46,7 @@ class RPCCase(unittest.TestCase):
|
||||
test_object_back = remote.async_echo(test_object)
|
||||
self.assertEqual(test_object, test_object_back)
|
||||
with self.assertRaises(AttributeError):
|
||||
remote.non_existing_method()
|
||||
remote.non_existing_method
|
||||
remote.terminate()
|
||||
finally:
|
||||
remote.close_rpc()
|
||||
@ -73,7 +73,7 @@ class RPCCase(unittest.TestCase):
|
||||
test_object_back = await remote.async_echo(test_object)
|
||||
self.assertEqual(test_object, test_object_back)
|
||||
with self.assertRaises(AttributeError):
|
||||
await remote.non_existing_method()
|
||||
await remote.non_existing_method
|
||||
await remote.terminate()
|
||||
finally:
|
||||
remote.close_rpc()
|
||||
@ -101,6 +101,8 @@ class FireAndForgetCase(unittest.TestCase):
|
||||
self.ok = False
|
||||
p = fire_and_forget.FFProxy(self)
|
||||
p._set_ok()
|
||||
with self.assertRaises(AttributeError):
|
||||
p.non_existing_method
|
||||
p.ff_join()
|
||||
self.assertTrue(self.ok)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user