pc_rpc: raise AttributeError immediately for nonexistent RPC methods. Closes #534

This commit is contained in:
Sebastien Bourdeauducq 2016-09-14 11:22:07 +08:00
parent f010a74479
commit c7c8ad126f
6 changed files with 35 additions and 12 deletions

View File

@ -10,6 +10,9 @@ Release notes
ARTIQ_APPLET_EMBED. The GUI sets this enviroment variable itself and the
user simply needs to remove the --embed argument.
* EnvExperiment's prepare calls prepare for all its children.
* Dynamic __getattr__'s returning RPC target methods are not supported anymore.
Controller driver classes must define all their methods intended for RPC as
members.
2.0rc1

View File

@ -41,7 +41,7 @@ class Controller:
await remote.connect_rpc(self.host, self.port, None)
try:
targets, _ = remote.get_rpc_id()
remote.select_rpc_target(targets[0])
await remote.select_rpc_target(targets[0])
r = await getattr(remote, method)(*args, **kwargs)
finally:
remote.close_rpc()

View File

@ -1,5 +1,6 @@
import threading
import logging
import inspect
logger = logging.getLogger(__name__)
@ -20,6 +21,9 @@ class FFProxy:
"""
def __init__(self, target):
self.target = target
valid_methods = inspect.getmembers(target, inspect.ismethod)
self._valid_methods = {m[0] for m in valid_methods}
self._thread = None
def ff_join(self):
@ -28,6 +32,8 @@ class FFProxy:
self._thread.join()
def __getattr__(self, k):
if k not in self._valid_methods:
raise AttributeError
def run_in_thread(*args, **kwargs):
if self._thread is not None and self._thread.is_alive():
logger.warning("skipping fire-and-forget call to %r.%s as "

View File

@ -94,8 +94,6 @@ class Client:
in the middle of a RPC can break subsequent RPCs (from the same
client).
"""
kernel_invariants = set()
def __init__(self, host, port, target_name=AutoTarget, timeout=None):
self.__socket = socket.create_connection((host, port), timeout)
@ -106,6 +104,7 @@ class Client:
self.__target_names = server_identification["targets"]
self.__description = server_identification["description"]
self.__selected_target = None
self.__valid_methods = set()
if target_name is not None:
self.select_rpc_target(target_name)
except:
@ -118,6 +117,7 @@ class Client:
target_name = _validate_target_name(target_name, self.__target_names)
self.__socket.sendall((target_name + "\n").encode())
self.__selected_target = target_name
self.__valid_methods = self.__recv()
def get_selected_target(self):
"""Returns the selected target, or ``None`` if no target has been
@ -173,6 +173,8 @@ class Client:
return self.__do_action(obj)
def __getattr__(self, name):
if name not in self.__valid_methods:
raise AttributeError
def proxy(*args, **kwargs):
return self.__do_rpc(name, args, kwargs)
return proxy
@ -187,8 +189,6 @@ class AsyncioClient:
Concurrent access from different asyncio tasks is supported; all calls
use a single lock.
"""
kernel_invariants = set()
def __init__(self):
self.__lock = asyncio.Lock()
self.__reader = None
@ -208,19 +208,21 @@ class AsyncioClient:
self.__target_names = server_identification["targets"]
self.__description = server_identification["description"]
self.__selected_target = None
self.__valid_methods = set()
if target_name is not None:
self.select_rpc_target(target_name)
await self.select_rpc_target(target_name)
except:
self.close_rpc()
raise
def select_rpc_target(self, target_name):
async def select_rpc_target(self, target_name):
"""Selects a RPC target by name. This function should be called
exactly once if the connection was created with ``target_name=None``.
"""
target_name = _validate_target_name(target_name, self.__target_names)
self.__writer.write((target_name + "\n").encode())
self.__selected_target = target_name
self.__valid_methods = await self.__recv()
def get_selected_target(self):
"""Returns the selected target, or ``None`` if no target has been
@ -273,6 +275,8 @@ class AsyncioClient:
self.__lock.release()
def __getattr__(self, name):
if name not in self.__valid_methods:
raise AttributeError
async def proxy(*args, **kwargs):
res = await self.__do_rpc(name, args, kwargs)
return res
@ -292,8 +296,6 @@ class BestEffortClient:
:param retry: Amount of time to wait between retries when reconnecting
in the background.
"""
kernel_invariants = set()
def __init__(self, host, port, target_name,
firstcon_timeout=1.0, retry=5.0):
self.__host = host
@ -303,6 +305,7 @@ class BestEffortClient:
self.__conretry_terminate = False
self.__socket = None
self.__valid_methods = set()
try:
self.__coninit(firstcon_timeout)
except:
@ -327,6 +330,7 @@ class BestEffortClient:
target_name = _validate_target_name(self.__target_name,
server_identification["targets"])
self.__socket.sendall((target_name + "\n").encode())
self.__valid_methods = self.__recv()
def __start_conretry(self):
self.__conretry_thread = threading.Thread(target=self.__conretry)
@ -401,6 +405,8 @@ class BestEffortClient:
raise ValueError
def __getattr__(self, name):
if name not in self.__valid_methods:
raise AttributeError
def proxy(*args, **kwargs):
return self.__do_rpc(name, args, kwargs)
return proxy
@ -558,6 +564,12 @@ class Server(_AsyncioServer):
if callable(target):
target = target()
valid_methods = inspect.getmembers(target, inspect.ismethod)
valid_methods = {m[0] for m in valid_methods}
if self.builtin_terminate:
valid_methods.add("terminate")
writer.write((pyon.encode(valid_methods) + "\n").encode())
while True:
line = await reader.readline()
if not line:

View File

@ -33,7 +33,7 @@ class ControllerCase(unittest.TestCase):
remote = AsyncioClient()
await remote.connect_rpc(host, port, None)
targets, _ = remote.get_rpc_id()
remote.select_rpc_target(targets[0])
await remote.select_rpc_target(targets[0])
self.addCleanup(remote.close_rpc)
return remote

View File

@ -46,7 +46,7 @@ class RPCCase(unittest.TestCase):
test_object_back = remote.async_echo(test_object)
self.assertEqual(test_object, test_object_back)
with self.assertRaises(AttributeError):
remote.non_existing_method()
remote.non_existing_method
remote.terminate()
finally:
remote.close_rpc()
@ -73,7 +73,7 @@ class RPCCase(unittest.TestCase):
test_object_back = await remote.async_echo(test_object)
self.assertEqual(test_object, test_object_back)
with self.assertRaises(AttributeError):
await remote.non_existing_method()
await remote.non_existing_method
await remote.terminate()
finally:
remote.close_rpc()
@ -101,6 +101,8 @@ class FireAndForgetCase(unittest.TestCase):
self.ok = False
p = fire_and_forget.FFProxy(self)
p._set_ok()
with self.assertRaises(AttributeError):
p.non_existing_method
p.ff_join()
self.assertTrue(self.ok)