mirror of https://github.com/m-labs/artiq.git
pc_rpc: numpy-compatible serialization
This commit is contained in:
parent
74856c151b
commit
16170c9013
|
@ -1,8 +1,9 @@
|
|||
import socket
|
||||
import json
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from artiq.management import pyon
|
||||
|
||||
|
||||
class RemoteError(Exception):
|
||||
pass
|
||||
|
@ -17,7 +18,7 @@ class Client:
|
|||
|
||||
def do_rpc(self, name, args, kwargs):
|
||||
obj = {"action": "call", "name": name, "args": args, "kwargs": kwargs}
|
||||
line = json.dumps(obj) + "\n"
|
||||
line = pyon.encode(obj) + "\n"
|
||||
self.socket.sendall(line.encode())
|
||||
|
||||
buf = self.socket.recv(4096).decode()
|
||||
|
@ -26,7 +27,7 @@ class Client:
|
|||
if not more:
|
||||
break
|
||||
buf += more.decode()
|
||||
obj = json.loads(buf)
|
||||
obj = pyon.decode(buf)
|
||||
if obj["result"] == "ok":
|
||||
return obj["ret"]
|
||||
elif obj["result"] == "error":
|
||||
|
@ -73,7 +74,7 @@ class Server:
|
|||
line = yield from reader.readline()
|
||||
if not line:
|
||||
break
|
||||
obj = json.loads(line.decode())
|
||||
obj = pyon.decode(line.decode())
|
||||
action = obj["action"]
|
||||
if action == "call":
|
||||
method = getattr(self.target, obj["name"])
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
import base64
|
||||
|
||||
import numpy
|
||||
|
||||
|
||||
def _encode_none(x):
|
||||
return "None"
|
||||
|
||||
|
||||
def _encode_number(x):
|
||||
return str(x)
|
||||
|
||||
|
||||
def _encode_str(x):
|
||||
return repr(x)
|
||||
|
||||
|
||||
def _encode_tuple(x):
|
||||
if len(x) == 1:
|
||||
return "(" + encode(x[0]) + ", )"
|
||||
else:
|
||||
r = "("
|
||||
r += ", ".join([encode(item) for item in x])
|
||||
r += ")"
|
||||
return r
|
||||
|
||||
|
||||
def _encode_list(x):
|
||||
r = "["
|
||||
r += ", ".join([encode(item) for item in x])
|
||||
r += "]"
|
||||
return r
|
||||
|
||||
|
||||
def _encode_dict(x):
|
||||
r = "{"
|
||||
r += ", ".join([encode(k) + ": " + encode(v) for k, v in x.items()])
|
||||
r += "}"
|
||||
return r
|
||||
|
||||
|
||||
def _encode_nparray(x):
|
||||
r = "nparray("
|
||||
r += encode(x.shape) + ", "
|
||||
r += encode(str(x.dtype)) + ", "
|
||||
r += encode(base64.b64encode(x).decode())
|
||||
r += ")"
|
||||
return r
|
||||
|
||||
|
||||
_encode_map = {
|
||||
type(None): _encode_none,
|
||||
int: _encode_number,
|
||||
float: _encode_number,
|
||||
str: _encode_str,
|
||||
tuple: _encode_tuple,
|
||||
list: _encode_list,
|
||||
dict: _encode_dict,
|
||||
numpy.ndarray: _encode_nparray
|
||||
}
|
||||
|
||||
|
||||
def encode(x):
|
||||
return _encode_map[type(x)](x)
|
||||
|
||||
|
||||
def _nparray(shape, dtype, data):
|
||||
a = numpy.frombuffer(base64.b64decode(data), dtype=dtype)
|
||||
return a.reshape(shape)
|
||||
|
||||
|
||||
def decode(s):
|
||||
return eval(s, {"__builtins__": None, "nparray": _nparray}, {})
|
Loading…
Reference in New Issue