compiler: improved rpc performance for list and array

1. Removed duplicated tags before each elements.
2. Use numpy functions to speedup parsing.
This commit is contained in:
pca006132 2020-08-18 17:01:28 +08:00 committed by Sébastien Bourdeauducq
parent cfddc13294
commit 7181ff66a6

View File

@ -43,9 +43,11 @@ class Reply(Enum):
class UnsupportedDevice(Exception):
pass
class LoadError(Exception):
pass
class RPCReturnValueError(ValueError):
pass
@ -53,6 +55,105 @@ class RPCReturnValueError(ValueError):
RPCKeyword = namedtuple('RPCKeyword', ['name', 'value'])
def _receive_fraction(kernel, embedding_map):
numerator = kernel._read_int64()
denominator = kernel._read_int64()
return Fraction(numerator, denominator)
def _receive_list(kernel, embedding_map):
length = kernel._read_int32()
tag = chr(kernel._read_int8())
if tag == "b":
buffer = kernel._read(length)
return numpy.ndarray((length, ), 'B', buffer).tolist()
elif tag == "i":
buffer = kernel._read(4 * length)
return numpy.ndarray((length, ), '>i4', buffer).tolist()
elif tag == "I":
buffer = kernel._read(8 * length)
return numpy.ndarray((length, ), '>i8', buffer).tolist()
elif tag == "f":
buffer = kernel._read(8 * length)
return numpy.ndarray((length, ), '>d', buffer).tolist()
else:
fn = receivers[tag]
elems = []
for _ in range(length):
# discard tag, as our device would still send the tag for each
# non-primitive elements.
kernel._read_int8()
item = fn(kernel, embedding_map)
elems.append(item)
return elems
def _receive_array(kernel, embedding_map):
num_dims = kernel._read_int8()
shape = tuple(kernel._read_int32() for _ in range(num_dims))
tag = chr(kernel._read_int8())
fn = receivers[tag]
length = numpy.prod(shape)
if tag == "b":
buffer = kernel._read(length)
elems = numpy.ndarray((length, ), 'B', buffer)
elif tag == "i":
buffer = kernel._read(4 * length)
elems = numpy.ndarray((length, ), '>i4', buffer)
elif tag == "I":
buffer = kernel._read(8 * length)
elems = numpy.ndarray((length, ), '>i8', buffer)
elif tag == "f":
buffer = kernel._read(8 * length)
elems = numpy.ndarray((length, ), '>d', buffer)
else:
fn = receivers[tag]
elems = []
for _ in range(numpy.prod(shape)):
# discard the tag
kernel._read_int8()
item = fn(kernel, embedding_map)
elems.append(item)
elems = numpy.array(elems)
return elems.reshape(shape)
def _receive_range(kernel, embedding_map):
start = kernel._receive_rpc_value(embedding_map)
stop = kernel._receive_rpc_value(embedding_map)
step = kernel._receive_rpc_value(embedding_map)
return range(start, stop, step)
def _receive_keyword(kernel, embedding_map):
name = kernel._read_string()
value = kernel._receive_rpc_value(embedding_map)
return RPCKeyword(name, value)
receivers = {
"\x00": lambda kernel, embedding_map: kernel._rpc_sentinel,
"t": lambda kernel, embedding_map:
tuple(kernel._receive_rpc_value(embedding_map)
for _ in range(kernel._read_int8())),
"n": lambda kernel, embedding_map: None,
"b": lambda kernel, embedding_map: bool(kernel._read_int8()),
"i": lambda kernel, embedding_map: numpy.int32(kernel._read_int32()),
"I": lambda kernel, embedding_map: numpy.int32(kernel._read_int64()),
"f": lambda kernel, embedding_map: kernel._read_float64(),
"s": lambda kernel, embedding_map: kernel._read_string(),
"B": lambda kernel, embedding_map: kernel._read_bytes(),
"A": lambda kernel, embedding_map: kernel._read_bytes(),
"O": lambda kernel, embedding_map:
embedding_map.retrieve_object(kernel._read_int32()),
"F": _receive_fraction,
"l": _receive_list,
"a": _receive_array,
"r": _receive_range,
"k": _receive_keyword
}
class CommKernelDummy:
def __init__(self):
pass
@ -247,50 +348,8 @@ class CommKernel:
# See rpc_proto.rs and compiler/ir.py:rpc_tag.
def _receive_rpc_value(self, embedding_map):
tag = chr(self._read_int8())
if tag == "\x00":
return self._rpc_sentinel
elif tag == "t":
length = self._read_int8()
return tuple(self._receive_rpc_value(embedding_map) for _ in range(length))
elif tag == "n":
return None
elif tag == "b":
return bool(self._read_int8())
elif tag == "i":
return numpy.int32(self._read_int32())
elif tag == "I":
return numpy.int64(self._read_int64())
elif tag == "f":
return self._read_float64()
elif tag == "F":
numerator = self._read_int64()
denominator = self._read_int64()
return Fraction(numerator, denominator)
elif tag == "s":
return self._read_string()
elif tag == "B":
return self._read_bytes()
elif tag == "A":
return self._read_bytes()
elif tag == "l":
length = self._read_int32()
return [self._receive_rpc_value(embedding_map) for _ in range(length)]
elif tag == "a":
num_dims = self._read_int8()
shape = tuple(self._read_int32() for _ in range(num_dims))
elems = [self._receive_rpc_value(embedding_map) for _ in range(numpy.prod(shape))]
return numpy.array(elems).reshape(shape)
elif tag == "r":
start = self._receive_rpc_value(embedding_map)
stop = self._receive_rpc_value(embedding_map)
step = self._receive_rpc_value(embedding_map)
return range(start, stop, step)
elif tag == "k":
name = self._read_string()
value = self._receive_rpc_value(embedding_map)
return RPCKeyword(name, value)
elif tag == "O":
return embedding_map.retrieve_object(self._read_int32())
if tag in receivers:
return receivers.get(tag)(self, embedding_map)
else:
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
@ -378,6 +437,19 @@ class CommKernel:
check(isinstance(value, list),
lambda: "list")
self._write_int32(len(value))
tag_element = chr(tags[0])
if tag_element == "b":
self._write(bytes(value))
elif tag_element == "i":
array = numpy.array(value, '>i4')
self._write(array.tobytes())
elif tag_element == "I":
array = numpy.array(value, '>i8')
self._write(array.tobytes())
elif tag_element == "f":
array = numpy.array(value, '>d')
self._write(array.tobytes())
else:
for elt in value:
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, elt, root, function)
@ -390,6 +462,19 @@ class CommKernel:
lambda: "{}-dimensional numpy.ndarray".format(num_dims))
for s in value.shape:
self._write_int32(s)
tag_element = chr(tags[0])
if tag_element == "b":
self._write(value.reshape((-1,), order="C").tobytes())
elif tag_element == "i":
array = value.reshape((-1,), order="C").astype('>i4')
self._write(array.tobytes())
elif tag_element == "I":
array = value.reshape((-1,), order="C").astype('>i8')
self._write(array.tobytes())
elif tag_element == "f":
array = value.reshape((-1,), order="C").astype('>d')
self._write(array.tobytes())
else:
for elt in value.reshape((-1,), order="C"):
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, elt, root, function)
@ -420,7 +505,7 @@ class CommKernel:
return_tags = self._read_bytes()
if service_id == 0:
service = lambda obj, attr, value: setattr(obj, attr, value)
def service(obj, attr, value): return setattr(obj, attr, value)
else:
service = embedding_map.retrieve_object(service_id)
logger.debug("rpc service: [%d]%r%s %r %r -> %s", service_id, service,
@ -432,15 +517,18 @@ class CommKernel:
try:
result = service(*args, **kwargs)
logger.debug("rpc service: %d %r %r = %r", service_id, args, kwargs, result)
logger.debug("rpc service: %d %r %r = %r",
service_id, args, kwargs, result)
self._write_header(Request.RPCReply)
self._write_bytes(return_tags)
self._send_rpc_value(bytearray(return_tags), result, result, service)
self._send_rpc_value(bytearray(return_tags),
result, result, service)
except RPCReturnValueError as exn:
raise
except Exception as exn:
logger.debug("rpc service: %d %r %r ! %r", service_id, args, kwargs, exn)
logger.debug("rpc service: %d %r %r ! %r",
service_id, args, kwargs, exn)
self._write_header(Request.RPCException)