pc_rpc: fix handling of type annotations

This commit is contained in:
Drew 2019-01-09 23:13:22 -05:00 committed by Sébastien Bourdeauducq
parent 088530604e
commit 721c6f3bcc
2 changed files with 64 additions and 12 deletions

View File

@ -11,12 +11,12 @@ client passes a list as a parameter of an RPC method, and that method
client's list. client's list.
""" """
import socket
import asyncio import asyncio
import inspect
import logging
import socket
import threading import threading
import time import time
import logging
import inspect
from operator import itemgetter from operator import itemgetter
from artiq.monkey_patches import * from artiq.monkey_patches import *
@ -24,7 +24,6 @@ from artiq.protocols import pyon
from artiq.protocols.asyncio_server import AsyncioServer as _AsyncioServer from artiq.protocols.asyncio_server import AsyncioServer as _AsyncioServer
from artiq.protocols.packed_exceptions import * from artiq.protocols.packed_exceptions import *
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -487,6 +486,27 @@ class Server(_AsyncioServer):
else: else:
self._noparallel = asyncio.Lock() self._noparallel = asyncio.Lock()
@staticmethod
def _document_function(function):
"""
Turn a function into a tuple of its arguments and documentation.
Allows remote inspection of what methods are available on a local device.
Args:
function (Callable): a Python function to be documented.
Returns:
Tuple[dict, str]: tuple of (argument specifications,
function documentation).
Any type annotations are converted to strings (for PYON serialization).
"""
argspec_dict = dict(inspect.getfullargspec(function)._asdict())
# Fix issue #1186: PYON can't serialize type annotations.
if any(argspec_dict.get("annotations", {})):
argspec_dict["annotations"] = str(argspec_dict["annotations"])
return argspec_dict, inspect.getdoc(function)
async def _process_action(self, target, obj): async def _process_action(self, target, obj):
if self._noparallel is not None: if self._noparallel is not None:
await self._noparallel.acquire() await self._noparallel.acquire()
@ -501,9 +521,7 @@ class Server(_AsyncioServer):
if name.startswith("_"): if name.startswith("_"):
continue continue
method = getattr(target, name) method = getattr(target, name)
argspec = inspect.getfullargspec(method) doc["methods"][name] = self._document_function(method)
doc["methods"][name] = (dict(argspec._asdict()),
inspect.getdoc(method))
if self.builtin_terminate: if self.builtin_terminate:
doc["methods"]["terminate"] = ( doc["methods"]["terminate"] = (
{ {
@ -515,6 +533,7 @@ class Server(_AsyncioServer):
"kwonlydefaults": [], "kwonlydefaults": [],
}, },
"Terminate the server.") "Terminate the server.")
logger.debug("RPC docs for %s: %s", target, doc)
return {"status": "ok", "ret": doc} return {"status": "ok", "ret": doc}
elif obj["action"] == "call": elif obj["action"] == "call":
logger.debug("calling %s", _PrettyPrintCall(obj)) logger.debug("calling %s", _PrettyPrintCall(obj))

View File

@ -1,13 +1,13 @@
import unittest
import sys
import subprocess
import asyncio import asyncio
import inspect
import subprocess
import sys
import time import time
import unittest
import numpy as np import numpy as np
from artiq.protocols import pc_rpc, fire_and_forget from artiq.protocols import fire_and_forget, pc_rpc, pyon
test_address = "::1" test_address = "::1"
test_port = 7777 test_port = 7777
@ -92,6 +92,38 @@ class RPCCase(unittest.TestCase):
def test_asyncio_echo_autotarget(self): def test_asyncio_echo_autotarget(self):
self._run_server_and_test(self._loop_asyncio_echo, pc_rpc.AutoTarget) self._run_server_and_test(self._loop_asyncio_echo, pc_rpc.AutoTarget)
def test_rpc_encode_function(self):
"""Test that `pc_rpc` can encode a function properly.
Used in `get_rpc_method_list` part of
:meth:`artiq.protocols.pc_rpc.Server._process_action`
"""
def _annotated_function(
arg1: str, arg2: np.ndarray = np.array([1, 2])
) -> np.ndarray:
"""Sample docstring."""
return arg1
argspec_documented, docstring = pc_rpc.Server._document_function(
_annotated_function
)
print(argspec_documented)
self.assertEqual(docstring, "Sample docstring.")
# purposefully ignore how argspec["annotations"] is treated.
# allows option to change PYON later to encode annotations.
argspec_master = dict(inspect.getfullargspec(_annotated_function)._asdict())
argspec_without_annotation = argspec_master.copy()
del argspec_without_annotation["annotations"]
# check if all items (excluding annotations) are same in both dictionaries
self.assertLessEqual(
argspec_without_annotation.items(), argspec_documented.items()
)
self.assertDictEqual(
argspec_documented, pyon.decode(pyon.encode(argspec_documented))
)
class FireAndForgetCase(unittest.TestCase): class FireAndForgetCase(unittest.TestCase):
def _set_ok(self): def _set_ok(self):
@ -130,5 +162,6 @@ def run_server():
finally: finally:
loop.close() loop.close()
if __name__ == "__main__": if __name__ == "__main__":
run_server() run_server()