forked from M-Labs/artiq
compiler: improved rpc performance for list and array
1. Removed duplicated tags before each elements. 2. Use numpy functions to speedup parsing.
This commit is contained in:
parent
cfddc13294
commit
7181ff66a6
@ -43,9 +43,11 @@ class Reply(Enum):
|
|||||||
class UnsupportedDevice(Exception):
|
class UnsupportedDevice(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LoadError(Exception):
|
class LoadError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class RPCReturnValueError(ValueError):
|
class RPCReturnValueError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -53,6 +55,105 @@ class RPCReturnValueError(ValueError):
|
|||||||
RPCKeyword = namedtuple('RPCKeyword', ['name', 'value'])
|
RPCKeyword = namedtuple('RPCKeyword', ['name', 'value'])
|
||||||
|
|
||||||
|
|
||||||
|
def _receive_fraction(kernel, embedding_map):
|
||||||
|
numerator = kernel._read_int64()
|
||||||
|
denominator = kernel._read_int64()
|
||||||
|
return Fraction(numerator, denominator)
|
||||||
|
|
||||||
|
|
||||||
|
def _receive_list(kernel, embedding_map):
|
||||||
|
length = kernel._read_int32()
|
||||||
|
tag = chr(kernel._read_int8())
|
||||||
|
if tag == "b":
|
||||||
|
buffer = kernel._read(length)
|
||||||
|
return numpy.ndarray((length, ), 'B', buffer).tolist()
|
||||||
|
elif tag == "i":
|
||||||
|
buffer = kernel._read(4 * length)
|
||||||
|
return numpy.ndarray((length, ), '>i4', buffer).tolist()
|
||||||
|
elif tag == "I":
|
||||||
|
buffer = kernel._read(8 * length)
|
||||||
|
return numpy.ndarray((length, ), '>i8', buffer).tolist()
|
||||||
|
elif tag == "f":
|
||||||
|
buffer = kernel._read(8 * length)
|
||||||
|
return numpy.ndarray((length, ), '>d', buffer).tolist()
|
||||||
|
else:
|
||||||
|
fn = receivers[tag]
|
||||||
|
elems = []
|
||||||
|
for _ in range(length):
|
||||||
|
# discard tag, as our device would still send the tag for each
|
||||||
|
# non-primitive elements.
|
||||||
|
kernel._read_int8()
|
||||||
|
item = fn(kernel, embedding_map)
|
||||||
|
elems.append(item)
|
||||||
|
return elems
|
||||||
|
|
||||||
|
|
||||||
|
def _receive_array(kernel, embedding_map):
|
||||||
|
num_dims = kernel._read_int8()
|
||||||
|
shape = tuple(kernel._read_int32() for _ in range(num_dims))
|
||||||
|
tag = chr(kernel._read_int8())
|
||||||
|
fn = receivers[tag]
|
||||||
|
length = numpy.prod(shape)
|
||||||
|
if tag == "b":
|
||||||
|
buffer = kernel._read(length)
|
||||||
|
elems = numpy.ndarray((length, ), 'B', buffer)
|
||||||
|
elif tag == "i":
|
||||||
|
buffer = kernel._read(4 * length)
|
||||||
|
elems = numpy.ndarray((length, ), '>i4', buffer)
|
||||||
|
elif tag == "I":
|
||||||
|
buffer = kernel._read(8 * length)
|
||||||
|
elems = numpy.ndarray((length, ), '>i8', buffer)
|
||||||
|
elif tag == "f":
|
||||||
|
buffer = kernel._read(8 * length)
|
||||||
|
elems = numpy.ndarray((length, ), '>d', buffer)
|
||||||
|
else:
|
||||||
|
fn = receivers[tag]
|
||||||
|
elems = []
|
||||||
|
for _ in range(numpy.prod(shape)):
|
||||||
|
# discard the tag
|
||||||
|
kernel._read_int8()
|
||||||
|
item = fn(kernel, embedding_map)
|
||||||
|
elems.append(item)
|
||||||
|
elems = numpy.array(elems)
|
||||||
|
return elems.reshape(shape)
|
||||||
|
|
||||||
|
|
||||||
|
def _receive_range(kernel, embedding_map):
|
||||||
|
start = kernel._receive_rpc_value(embedding_map)
|
||||||
|
stop = kernel._receive_rpc_value(embedding_map)
|
||||||
|
step = kernel._receive_rpc_value(embedding_map)
|
||||||
|
return range(start, stop, step)
|
||||||
|
|
||||||
|
|
||||||
|
def _receive_keyword(kernel, embedding_map):
|
||||||
|
name = kernel._read_string()
|
||||||
|
value = kernel._receive_rpc_value(embedding_map)
|
||||||
|
return RPCKeyword(name, value)
|
||||||
|
|
||||||
|
|
||||||
|
receivers = {
|
||||||
|
"\x00": lambda kernel, embedding_map: kernel._rpc_sentinel,
|
||||||
|
"t": lambda kernel, embedding_map:
|
||||||
|
tuple(kernel._receive_rpc_value(embedding_map)
|
||||||
|
for _ in range(kernel._read_int8())),
|
||||||
|
"n": lambda kernel, embedding_map: None,
|
||||||
|
"b": lambda kernel, embedding_map: bool(kernel._read_int8()),
|
||||||
|
"i": lambda kernel, embedding_map: numpy.int32(kernel._read_int32()),
|
||||||
|
"I": lambda kernel, embedding_map: numpy.int32(kernel._read_int64()),
|
||||||
|
"f": lambda kernel, embedding_map: kernel._read_float64(),
|
||||||
|
"s": lambda kernel, embedding_map: kernel._read_string(),
|
||||||
|
"B": lambda kernel, embedding_map: kernel._read_bytes(),
|
||||||
|
"A": lambda kernel, embedding_map: kernel._read_bytes(),
|
||||||
|
"O": lambda kernel, embedding_map:
|
||||||
|
embedding_map.retrieve_object(kernel._read_int32()),
|
||||||
|
"F": _receive_fraction,
|
||||||
|
"l": _receive_list,
|
||||||
|
"a": _receive_array,
|
||||||
|
"r": _receive_range,
|
||||||
|
"k": _receive_keyword
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class CommKernelDummy:
|
class CommKernelDummy:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
@ -247,50 +348,8 @@ class CommKernel:
|
|||||||
# See rpc_proto.rs and compiler/ir.py:rpc_tag.
|
# See rpc_proto.rs and compiler/ir.py:rpc_tag.
|
||||||
def _receive_rpc_value(self, embedding_map):
|
def _receive_rpc_value(self, embedding_map):
|
||||||
tag = chr(self._read_int8())
|
tag = chr(self._read_int8())
|
||||||
if tag == "\x00":
|
if tag in receivers:
|
||||||
return self._rpc_sentinel
|
return receivers.get(tag)(self, embedding_map)
|
||||||
elif tag == "t":
|
|
||||||
length = self._read_int8()
|
|
||||||
return tuple(self._receive_rpc_value(embedding_map) for _ in range(length))
|
|
||||||
elif tag == "n":
|
|
||||||
return None
|
|
||||||
elif tag == "b":
|
|
||||||
return bool(self._read_int8())
|
|
||||||
elif tag == "i":
|
|
||||||
return numpy.int32(self._read_int32())
|
|
||||||
elif tag == "I":
|
|
||||||
return numpy.int64(self._read_int64())
|
|
||||||
elif tag == "f":
|
|
||||||
return self._read_float64()
|
|
||||||
elif tag == "F":
|
|
||||||
numerator = self._read_int64()
|
|
||||||
denominator = self._read_int64()
|
|
||||||
return Fraction(numerator, denominator)
|
|
||||||
elif tag == "s":
|
|
||||||
return self._read_string()
|
|
||||||
elif tag == "B":
|
|
||||||
return self._read_bytes()
|
|
||||||
elif tag == "A":
|
|
||||||
return self._read_bytes()
|
|
||||||
elif tag == "l":
|
|
||||||
length = self._read_int32()
|
|
||||||
return [self._receive_rpc_value(embedding_map) for _ in range(length)]
|
|
||||||
elif tag == "a":
|
|
||||||
num_dims = self._read_int8()
|
|
||||||
shape = tuple(self._read_int32() for _ in range(num_dims))
|
|
||||||
elems = [self._receive_rpc_value(embedding_map) for _ in range(numpy.prod(shape))]
|
|
||||||
return numpy.array(elems).reshape(shape)
|
|
||||||
elif tag == "r":
|
|
||||||
start = self._receive_rpc_value(embedding_map)
|
|
||||||
stop = self._receive_rpc_value(embedding_map)
|
|
||||||
step = self._receive_rpc_value(embedding_map)
|
|
||||||
return range(start, stop, step)
|
|
||||||
elif tag == "k":
|
|
||||||
name = self._read_string()
|
|
||||||
value = self._receive_rpc_value(embedding_map)
|
|
||||||
return RPCKeyword(name, value)
|
|
||||||
elif tag == "O":
|
|
||||||
return embedding_map.retrieve_object(self._read_int32())
|
|
||||||
else:
|
else:
|
||||||
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
|
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
|
||||||
|
|
||||||
@ -378,6 +437,19 @@ class CommKernel:
|
|||||||
check(isinstance(value, list),
|
check(isinstance(value, list),
|
||||||
lambda: "list")
|
lambda: "list")
|
||||||
self._write_int32(len(value))
|
self._write_int32(len(value))
|
||||||
|
tag_element = chr(tags[0])
|
||||||
|
if tag_element == "b":
|
||||||
|
self._write(bytes(value))
|
||||||
|
elif tag_element == "i":
|
||||||
|
array = numpy.array(value, '>i4')
|
||||||
|
self._write(array.tobytes())
|
||||||
|
elif tag_element == "I":
|
||||||
|
array = numpy.array(value, '>i8')
|
||||||
|
self._write(array.tobytes())
|
||||||
|
elif tag_element == "f":
|
||||||
|
array = numpy.array(value, '>d')
|
||||||
|
self._write(array.tobytes())
|
||||||
|
else:
|
||||||
for elt in value:
|
for elt in value:
|
||||||
tags_copy = bytearray(tags)
|
tags_copy = bytearray(tags)
|
||||||
self._send_rpc_value(tags_copy, elt, root, function)
|
self._send_rpc_value(tags_copy, elt, root, function)
|
||||||
@ -390,6 +462,19 @@ class CommKernel:
|
|||||||
lambda: "{}-dimensional numpy.ndarray".format(num_dims))
|
lambda: "{}-dimensional numpy.ndarray".format(num_dims))
|
||||||
for s in value.shape:
|
for s in value.shape:
|
||||||
self._write_int32(s)
|
self._write_int32(s)
|
||||||
|
tag_element = chr(tags[0])
|
||||||
|
if tag_element == "b":
|
||||||
|
self._write(value.reshape((-1,), order="C").tobytes())
|
||||||
|
elif tag_element == "i":
|
||||||
|
array = value.reshape((-1,), order="C").astype('>i4')
|
||||||
|
self._write(array.tobytes())
|
||||||
|
elif tag_element == "I":
|
||||||
|
array = value.reshape((-1,), order="C").astype('>i8')
|
||||||
|
self._write(array.tobytes())
|
||||||
|
elif tag_element == "f":
|
||||||
|
array = value.reshape((-1,), order="C").astype('>d')
|
||||||
|
self._write(array.tobytes())
|
||||||
|
else:
|
||||||
for elt in value.reshape((-1,), order="C"):
|
for elt in value.reshape((-1,), order="C"):
|
||||||
tags_copy = bytearray(tags)
|
tags_copy = bytearray(tags)
|
||||||
self._send_rpc_value(tags_copy, elt, root, function)
|
self._send_rpc_value(tags_copy, elt, root, function)
|
||||||
@ -420,7 +505,7 @@ class CommKernel:
|
|||||||
return_tags = self._read_bytes()
|
return_tags = self._read_bytes()
|
||||||
|
|
||||||
if service_id == 0:
|
if service_id == 0:
|
||||||
service = lambda obj, attr, value: setattr(obj, attr, value)
|
def service(obj, attr, value): return setattr(obj, attr, value)
|
||||||
else:
|
else:
|
||||||
service = embedding_map.retrieve_object(service_id)
|
service = embedding_map.retrieve_object(service_id)
|
||||||
logger.debug("rpc service: [%d]%r%s %r %r -> %s", service_id, service,
|
logger.debug("rpc service: [%d]%r%s %r %r -> %s", service_id, service,
|
||||||
@ -432,15 +517,18 @@ class CommKernel:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
result = service(*args, **kwargs)
|
result = service(*args, **kwargs)
|
||||||
logger.debug("rpc service: %d %r %r = %r", service_id, args, kwargs, result)
|
logger.debug("rpc service: %d %r %r = %r",
|
||||||
|
service_id, args, kwargs, result)
|
||||||
|
|
||||||
self._write_header(Request.RPCReply)
|
self._write_header(Request.RPCReply)
|
||||||
self._write_bytes(return_tags)
|
self._write_bytes(return_tags)
|
||||||
self._send_rpc_value(bytearray(return_tags), result, result, service)
|
self._send_rpc_value(bytearray(return_tags),
|
||||||
|
result, result, service)
|
||||||
except RPCReturnValueError as exn:
|
except RPCReturnValueError as exn:
|
||||||
raise
|
raise
|
||||||
except Exception as exn:
|
except Exception as exn:
|
||||||
logger.debug("rpc service: %d %r %r ! %r", service_id, args, kwargs, exn)
|
logger.debug("rpc service: %d %r %r ! %r",
|
||||||
|
service_id, args, kwargs, exn)
|
||||||
|
|
||||||
self._write_header(Request.RPCException)
|
self._write_header(Request.RPCException)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user