forked from M-Labs/artiq
compiler: improved rpc performance for list and array
1. Removed duplicated tags before each elements. 2. Use numpy functions to speedup parsing.
This commit is contained in:
parent
cfddc13294
commit
7181ff66a6
|
@ -43,9 +43,11 @@ class Reply(Enum):
|
|||
class UnsupportedDevice(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class LoadError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RPCReturnValueError(ValueError):
|
||||
pass
|
||||
|
||||
|
@ -53,6 +55,105 @@ class RPCReturnValueError(ValueError):
|
|||
RPCKeyword = namedtuple('RPCKeyword', ['name', 'value'])
|
||||
|
||||
|
||||
def _receive_fraction(kernel, embedding_map):
|
||||
numerator = kernel._read_int64()
|
||||
denominator = kernel._read_int64()
|
||||
return Fraction(numerator, denominator)
|
||||
|
||||
|
||||
def _receive_list(kernel, embedding_map):
|
||||
length = kernel._read_int32()
|
||||
tag = chr(kernel._read_int8())
|
||||
if tag == "b":
|
||||
buffer = kernel._read(length)
|
||||
return numpy.ndarray((length, ), 'B', buffer).tolist()
|
||||
elif tag == "i":
|
||||
buffer = kernel._read(4 * length)
|
||||
return numpy.ndarray((length, ), '>i4', buffer).tolist()
|
||||
elif tag == "I":
|
||||
buffer = kernel._read(8 * length)
|
||||
return numpy.ndarray((length, ), '>i8', buffer).tolist()
|
||||
elif tag == "f":
|
||||
buffer = kernel._read(8 * length)
|
||||
return numpy.ndarray((length, ), '>d', buffer).tolist()
|
||||
else:
|
||||
fn = receivers[tag]
|
||||
elems = []
|
||||
for _ in range(length):
|
||||
# discard tag, as our device would still send the tag for each
|
||||
# non-primitive elements.
|
||||
kernel._read_int8()
|
||||
item = fn(kernel, embedding_map)
|
||||
elems.append(item)
|
||||
return elems
|
||||
|
||||
|
||||
def _receive_array(kernel, embedding_map):
|
||||
num_dims = kernel._read_int8()
|
||||
shape = tuple(kernel._read_int32() for _ in range(num_dims))
|
||||
tag = chr(kernel._read_int8())
|
||||
fn = receivers[tag]
|
||||
length = numpy.prod(shape)
|
||||
if tag == "b":
|
||||
buffer = kernel._read(length)
|
||||
elems = numpy.ndarray((length, ), 'B', buffer)
|
||||
elif tag == "i":
|
||||
buffer = kernel._read(4 * length)
|
||||
elems = numpy.ndarray((length, ), '>i4', buffer)
|
||||
elif tag == "I":
|
||||
buffer = kernel._read(8 * length)
|
||||
elems = numpy.ndarray((length, ), '>i8', buffer)
|
||||
elif tag == "f":
|
||||
buffer = kernel._read(8 * length)
|
||||
elems = numpy.ndarray((length, ), '>d', buffer)
|
||||
else:
|
||||
fn = receivers[tag]
|
||||
elems = []
|
||||
for _ in range(numpy.prod(shape)):
|
||||
# discard the tag
|
||||
kernel._read_int8()
|
||||
item = fn(kernel, embedding_map)
|
||||
elems.append(item)
|
||||
elems = numpy.array(elems)
|
||||
return elems.reshape(shape)
|
||||
|
||||
|
||||
def _receive_range(kernel, embedding_map):
|
||||
start = kernel._receive_rpc_value(embedding_map)
|
||||
stop = kernel._receive_rpc_value(embedding_map)
|
||||
step = kernel._receive_rpc_value(embedding_map)
|
||||
return range(start, stop, step)
|
||||
|
||||
|
||||
def _receive_keyword(kernel, embedding_map):
|
||||
name = kernel._read_string()
|
||||
value = kernel._receive_rpc_value(embedding_map)
|
||||
return RPCKeyword(name, value)
|
||||
|
||||
|
||||
receivers = {
|
||||
"\x00": lambda kernel, embedding_map: kernel._rpc_sentinel,
|
||||
"t": lambda kernel, embedding_map:
|
||||
tuple(kernel._receive_rpc_value(embedding_map)
|
||||
for _ in range(kernel._read_int8())),
|
||||
"n": lambda kernel, embedding_map: None,
|
||||
"b": lambda kernel, embedding_map: bool(kernel._read_int8()),
|
||||
"i": lambda kernel, embedding_map: numpy.int32(kernel._read_int32()),
|
||||
"I": lambda kernel, embedding_map: numpy.int32(kernel._read_int64()),
|
||||
"f": lambda kernel, embedding_map: kernel._read_float64(),
|
||||
"s": lambda kernel, embedding_map: kernel._read_string(),
|
||||
"B": lambda kernel, embedding_map: kernel._read_bytes(),
|
||||
"A": lambda kernel, embedding_map: kernel._read_bytes(),
|
||||
"O": lambda kernel, embedding_map:
|
||||
embedding_map.retrieve_object(kernel._read_int32()),
|
||||
"F": _receive_fraction,
|
||||
"l": _receive_list,
|
||||
"a": _receive_array,
|
||||
"r": _receive_range,
|
||||
"k": _receive_keyword
|
||||
}
|
||||
|
||||
|
||||
class CommKernelDummy:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
@ -247,50 +348,8 @@ class CommKernel:
|
|||
# See rpc_proto.rs and compiler/ir.py:rpc_tag.
|
||||
def _receive_rpc_value(self, embedding_map):
|
||||
tag = chr(self._read_int8())
|
||||
if tag == "\x00":
|
||||
return self._rpc_sentinel
|
||||
elif tag == "t":
|
||||
length = self._read_int8()
|
||||
return tuple(self._receive_rpc_value(embedding_map) for _ in range(length))
|
||||
elif tag == "n":
|
||||
return None
|
||||
elif tag == "b":
|
||||
return bool(self._read_int8())
|
||||
elif tag == "i":
|
||||
return numpy.int32(self._read_int32())
|
||||
elif tag == "I":
|
||||
return numpy.int64(self._read_int64())
|
||||
elif tag == "f":
|
||||
return self._read_float64()
|
||||
elif tag == "F":
|
||||
numerator = self._read_int64()
|
||||
denominator = self._read_int64()
|
||||
return Fraction(numerator, denominator)
|
||||
elif tag == "s":
|
||||
return self._read_string()
|
||||
elif tag == "B":
|
||||
return self._read_bytes()
|
||||
elif tag == "A":
|
||||
return self._read_bytes()
|
||||
elif tag == "l":
|
||||
length = self._read_int32()
|
||||
return [self._receive_rpc_value(embedding_map) for _ in range(length)]
|
||||
elif tag == "a":
|
||||
num_dims = self._read_int8()
|
||||
shape = tuple(self._read_int32() for _ in range(num_dims))
|
||||
elems = [self._receive_rpc_value(embedding_map) for _ in range(numpy.prod(shape))]
|
||||
return numpy.array(elems).reshape(shape)
|
||||
elif tag == "r":
|
||||
start = self._receive_rpc_value(embedding_map)
|
||||
stop = self._receive_rpc_value(embedding_map)
|
||||
step = self._receive_rpc_value(embedding_map)
|
||||
return range(start, stop, step)
|
||||
elif tag == "k":
|
||||
name = self._read_string()
|
||||
value = self._receive_rpc_value(embedding_map)
|
||||
return RPCKeyword(name, value)
|
||||
elif tag == "O":
|
||||
return embedding_map.retrieve_object(self._read_int32())
|
||||
if tag in receivers:
|
||||
return receivers.get(tag)(self, embedding_map)
|
||||
else:
|
||||
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
|
||||
|
||||
|
@ -378,6 +437,19 @@ class CommKernel:
|
|||
check(isinstance(value, list),
|
||||
lambda: "list")
|
||||
self._write_int32(len(value))
|
||||
tag_element = chr(tags[0])
|
||||
if tag_element == "b":
|
||||
self._write(bytes(value))
|
||||
elif tag_element == "i":
|
||||
array = numpy.array(value, '>i4')
|
||||
self._write(array.tobytes())
|
||||
elif tag_element == "I":
|
||||
array = numpy.array(value, '>i8')
|
||||
self._write(array.tobytes())
|
||||
elif tag_element == "f":
|
||||
array = numpy.array(value, '>d')
|
||||
self._write(array.tobytes())
|
||||
else:
|
||||
for elt in value:
|
||||
tags_copy = bytearray(tags)
|
||||
self._send_rpc_value(tags_copy, elt, root, function)
|
||||
|
@ -390,6 +462,19 @@ class CommKernel:
|
|||
lambda: "{}-dimensional numpy.ndarray".format(num_dims))
|
||||
for s in value.shape:
|
||||
self._write_int32(s)
|
||||
tag_element = chr(tags[0])
|
||||
if tag_element == "b":
|
||||
self._write(value.reshape((-1,), order="C").tobytes())
|
||||
elif tag_element == "i":
|
||||
array = value.reshape((-1,), order="C").astype('>i4')
|
||||
self._write(array.tobytes())
|
||||
elif tag_element == "I":
|
||||
array = value.reshape((-1,), order="C").astype('>i8')
|
||||
self._write(array.tobytes())
|
||||
elif tag_element == "f":
|
||||
array = value.reshape((-1,), order="C").astype('>d')
|
||||
self._write(array.tobytes())
|
||||
else:
|
||||
for elt in value.reshape((-1,), order="C"):
|
||||
tags_copy = bytearray(tags)
|
||||
self._send_rpc_value(tags_copy, elt, root, function)
|
||||
|
@ -420,7 +505,7 @@ class CommKernel:
|
|||
return_tags = self._read_bytes()
|
||||
|
||||
if service_id == 0:
|
||||
service = lambda obj, attr, value: setattr(obj, attr, value)
|
||||
def service(obj, attr, value): return setattr(obj, attr, value)
|
||||
else:
|
||||
service = embedding_map.retrieve_object(service_id)
|
||||
logger.debug("rpc service: [%d]%r%s %r %r -> %s", service_id, service,
|
||||
|
@ -432,15 +517,18 @@ class CommKernel:
|
|||
|
||||
try:
|
||||
result = service(*args, **kwargs)
|
||||
logger.debug("rpc service: %d %r %r = %r", service_id, args, kwargs, result)
|
||||
logger.debug("rpc service: %d %r %r = %r",
|
||||
service_id, args, kwargs, result)
|
||||
|
||||
self._write_header(Request.RPCReply)
|
||||
self._write_bytes(return_tags)
|
||||
self._send_rpc_value(bytearray(return_tags), result, result, service)
|
||||
self._send_rpc_value(bytearray(return_tags),
|
||||
result, result, service)
|
||||
except RPCReturnValueError as exn:
|
||||
raise
|
||||
except Exception as exn:
|
||||
logger.debug("rpc service: %d %r %r ! %r", service_id, args, kwargs, exn)
|
||||
logger.debug("rpc service: %d %r %r ! %r",
|
||||
service_id, args, kwargs, exn)
|
||||
|
||||
self._write_header(Request.RPCException)
|
||||
|
||||
|
|
Loading…
Reference in New Issue