compiler: improved rpc performance for list and array

1. Removed duplicated tags before each elements.
2. Use numpy functions to speedup parsing.
This commit is contained in:
pca006132 2020-08-18 17:01:28 +08:00 committed by Sébastien Bourdeauducq
parent cfddc13294
commit 7181ff66a6
1 changed files with 158 additions and 70 deletions

View File

@ -43,9 +43,11 @@ class Reply(Enum):
class UnsupportedDevice(Exception): class UnsupportedDevice(Exception):
pass pass
class LoadError(Exception): class LoadError(Exception):
pass pass
class RPCReturnValueError(ValueError): class RPCReturnValueError(ValueError):
pass pass
@ -53,6 +55,105 @@ class RPCReturnValueError(ValueError):
RPCKeyword = namedtuple('RPCKeyword', ['name', 'value']) RPCKeyword = namedtuple('RPCKeyword', ['name', 'value'])
def _receive_fraction(kernel, embedding_map):
numerator = kernel._read_int64()
denominator = kernel._read_int64()
return Fraction(numerator, denominator)
def _receive_list(kernel, embedding_map):
length = kernel._read_int32()
tag = chr(kernel._read_int8())
if tag == "b":
buffer = kernel._read(length)
return numpy.ndarray((length, ), 'B', buffer).tolist()
elif tag == "i":
buffer = kernel._read(4 * length)
return numpy.ndarray((length, ), '>i4', buffer).tolist()
elif tag == "I":
buffer = kernel._read(8 * length)
return numpy.ndarray((length, ), '>i8', buffer).tolist()
elif tag == "f":
buffer = kernel._read(8 * length)
return numpy.ndarray((length, ), '>d', buffer).tolist()
else:
fn = receivers[tag]
elems = []
for _ in range(length):
# discard tag, as our device would still send the tag for each
# non-primitive elements.
kernel._read_int8()
item = fn(kernel, embedding_map)
elems.append(item)
return elems
def _receive_array(kernel, embedding_map):
num_dims = kernel._read_int8()
shape = tuple(kernel._read_int32() for _ in range(num_dims))
tag = chr(kernel._read_int8())
fn = receivers[tag]
length = numpy.prod(shape)
if tag == "b":
buffer = kernel._read(length)
elems = numpy.ndarray((length, ), 'B', buffer)
elif tag == "i":
buffer = kernel._read(4 * length)
elems = numpy.ndarray((length, ), '>i4', buffer)
elif tag == "I":
buffer = kernel._read(8 * length)
elems = numpy.ndarray((length, ), '>i8', buffer)
elif tag == "f":
buffer = kernel._read(8 * length)
elems = numpy.ndarray((length, ), '>d', buffer)
else:
fn = receivers[tag]
elems = []
for _ in range(numpy.prod(shape)):
# discard the tag
kernel._read_int8()
item = fn(kernel, embedding_map)
elems.append(item)
elems = numpy.array(elems)
return elems.reshape(shape)
def _receive_range(kernel, embedding_map):
start = kernel._receive_rpc_value(embedding_map)
stop = kernel._receive_rpc_value(embedding_map)
step = kernel._receive_rpc_value(embedding_map)
return range(start, stop, step)
def _receive_keyword(kernel, embedding_map):
name = kernel._read_string()
value = kernel._receive_rpc_value(embedding_map)
return RPCKeyword(name, value)
receivers = {
"\x00": lambda kernel, embedding_map: kernel._rpc_sentinel,
"t": lambda kernel, embedding_map:
tuple(kernel._receive_rpc_value(embedding_map)
for _ in range(kernel._read_int8())),
"n": lambda kernel, embedding_map: None,
"b": lambda kernel, embedding_map: bool(kernel._read_int8()),
"i": lambda kernel, embedding_map: numpy.int32(kernel._read_int32()),
"I": lambda kernel, embedding_map: numpy.int32(kernel._read_int64()),
"f": lambda kernel, embedding_map: kernel._read_float64(),
"s": lambda kernel, embedding_map: kernel._read_string(),
"B": lambda kernel, embedding_map: kernel._read_bytes(),
"A": lambda kernel, embedding_map: kernel._read_bytes(),
"O": lambda kernel, embedding_map:
embedding_map.retrieve_object(kernel._read_int32()),
"F": _receive_fraction,
"l": _receive_list,
"a": _receive_array,
"r": _receive_range,
"k": _receive_keyword
}
class CommKernelDummy: class CommKernelDummy:
def __init__(self): def __init__(self):
pass pass
@ -247,50 +348,8 @@ class CommKernel:
# See rpc_proto.rs and compiler/ir.py:rpc_tag. # See rpc_proto.rs and compiler/ir.py:rpc_tag.
def _receive_rpc_value(self, embedding_map): def _receive_rpc_value(self, embedding_map):
tag = chr(self._read_int8()) tag = chr(self._read_int8())
if tag == "\x00": if tag in receivers:
return self._rpc_sentinel return receivers.get(tag)(self, embedding_map)
elif tag == "t":
length = self._read_int8()
return tuple(self._receive_rpc_value(embedding_map) for _ in range(length))
elif tag == "n":
return None
elif tag == "b":
return bool(self._read_int8())
elif tag == "i":
return numpy.int32(self._read_int32())
elif tag == "I":
return numpy.int64(self._read_int64())
elif tag == "f":
return self._read_float64()
elif tag == "F":
numerator = self._read_int64()
denominator = self._read_int64()
return Fraction(numerator, denominator)
elif tag == "s":
return self._read_string()
elif tag == "B":
return self._read_bytes()
elif tag == "A":
return self._read_bytes()
elif tag == "l":
length = self._read_int32()
return [self._receive_rpc_value(embedding_map) for _ in range(length)]
elif tag == "a":
num_dims = self._read_int8()
shape = tuple(self._read_int32() for _ in range(num_dims))
elems = [self._receive_rpc_value(embedding_map) for _ in range(numpy.prod(shape))]
return numpy.array(elems).reshape(shape)
elif tag == "r":
start = self._receive_rpc_value(embedding_map)
stop = self._receive_rpc_value(embedding_map)
step = self._receive_rpc_value(embedding_map)
return range(start, stop, step)
elif tag == "k":
name = self._read_string()
value = self._receive_rpc_value(embedding_map)
return RPCKeyword(name, value)
elif tag == "O":
return embedding_map.retrieve_object(self._read_int32())
else: else:
raise IOError("Unknown RPC value tag: {}".format(repr(tag))) raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
@ -357,8 +416,8 @@ class CommKernel:
self._write_float64(value) self._write_float64(value)
elif tag == "F": elif tag == "F":
check(isinstance(value, Fraction) and check(isinstance(value, Fraction) and
(-2**63 < value.numerator < 2**63-1) and (-2**63 < value.numerator < 2**63-1) and
(-2**63 < value.denominator < 2**63-1), (-2**63 < value.denominator < 2**63-1),
lambda: "64-bit Fraction") lambda: "64-bit Fraction")
self._write_int64(value.numerator) self._write_int64(value.numerator)
self._write_int64(value.denominator) self._write_int64(value.denominator)
@ -378,21 +437,47 @@ class CommKernel:
check(isinstance(value, list), check(isinstance(value, list),
lambda: "list") lambda: "list")
self._write_int32(len(value)) self._write_int32(len(value))
for elt in value: tag_element = chr(tags[0])
tags_copy = bytearray(tags) if tag_element == "b":
self._send_rpc_value(tags_copy, elt, root, function) self._write(bytes(value))
elif tag_element == "i":
array = numpy.array(value, '>i4')
self._write(array.tobytes())
elif tag_element == "I":
array = numpy.array(value, '>i8')
self._write(array.tobytes())
elif tag_element == "f":
array = numpy.array(value, '>d')
self._write(array.tobytes())
else:
for elt in value:
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, elt, root, function)
self._skip_rpc_value(tags) self._skip_rpc_value(tags)
elif tag == "a": elif tag == "a":
check(isinstance(value, numpy.ndarray), check(isinstance(value, numpy.ndarray),
lambda: "numpy.ndarray") lambda: "numpy.ndarray")
num_dims = tags.pop(0) num_dims = tags.pop(0)
check(num_dims == len(value.shape), check(num_dims == len(value.shape),
lambda: "{}-dimensional numpy.ndarray".format(num_dims)) lambda: "{}-dimensional numpy.ndarray".format(num_dims))
for s in value.shape: for s in value.shape:
self._write_int32(s) self._write_int32(s)
for elt in value.reshape((-1,), order="C"): tag_element = chr(tags[0])
tags_copy = bytearray(tags) if tag_element == "b":
self._send_rpc_value(tags_copy, elt, root, function) self._write(value.reshape((-1,), order="C").tobytes())
elif tag_element == "i":
array = value.reshape((-1,), order="C").astype('>i4')
self._write(array.tobytes())
elif tag_element == "I":
array = value.reshape((-1,), order="C").astype('>i8')
self._write(array.tobytes())
elif tag_element == "f":
array = value.reshape((-1,), order="C").astype('>d')
self._write(array.tobytes())
else:
for elt in value.reshape((-1,), order="C"):
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, elt, root, function)
self._skip_rpc_value(tags) self._skip_rpc_value(tags)
elif tag == "r": elif tag == "r":
check(isinstance(value, range), check(isinstance(value, range),
@ -414,15 +499,15 @@ class CommKernel:
return msg return msg
def _serve_rpc(self, embedding_map): def _serve_rpc(self, embedding_map):
is_async = self._read_bool() is_async = self._read_bool()
service_id = self._read_int32() service_id = self._read_int32()
args, kwargs = self._receive_rpc_args(embedding_map) args, kwargs = self._receive_rpc_args(embedding_map)
return_tags = self._read_bytes() return_tags = self._read_bytes()
if service_id == 0: if service_id == 0:
service = lambda obj, attr, value: setattr(obj, attr, value) def service(obj, attr, value): return setattr(obj, attr, value)
else: else:
service = embedding_map.retrieve_object(service_id) service = embedding_map.retrieve_object(service_id)
logger.debug("rpc service: [%d]%r%s %r %r -> %s", service_id, service, logger.debug("rpc service: [%d]%r%s %r %r -> %s", service_id, service,
(" (async)" if is_async else ""), args, kwargs, return_tags) (" (async)" if is_async else ""), args, kwargs, return_tags)
@ -432,15 +517,18 @@ class CommKernel:
try: try:
result = service(*args, **kwargs) result = service(*args, **kwargs)
logger.debug("rpc service: %d %r %r = %r", service_id, args, kwargs, result) logger.debug("rpc service: %d %r %r = %r",
service_id, args, kwargs, result)
self._write_header(Request.RPCReply) self._write_header(Request.RPCReply)
self._write_bytes(return_tags) self._write_bytes(return_tags)
self._send_rpc_value(bytearray(return_tags), result, result, service) self._send_rpc_value(bytearray(return_tags),
result, result, service)
except RPCReturnValueError as exn: except RPCReturnValueError as exn:
raise raise
except Exception as exn: except Exception as exn:
logger.debug("rpc service: %d %r %r ! %r", service_id, args, kwargs, exn) logger.debug("rpc service: %d %r %r ! %r",
service_id, args, kwargs, exn)
self._write_header(Request.RPCException) self._write_header(Request.RPCException)
@ -479,23 +567,23 @@ class CommKernel:
assert False assert False
self._write_string(filename) self._write_string(filename)
self._write_int32(line) self._write_int32(line)
self._write_int32(-1) # column not known self._write_int32(-1) # column not known
self._write_string(function) self._write_string(function)
def _serve_exception(self, embedding_map, symbolizer, demangler): def _serve_exception(self, embedding_map, symbolizer, demangler):
name = self._read_string() name = self._read_string()
message = self._read_string() message = self._read_string()
params = [self._read_int64() for _ in range(3)] params = [self._read_int64() for _ in range(3)]
filename = self._read_string() filename = self._read_string()
line = self._read_int32() line = self._read_int32()
column = self._read_int32() column = self._read_int32()
function = self._read_string() function = self._read_string()
backtrace = [self._read_int32() for _ in range(self._read_int32())] backtrace = [self._read_int32() for _ in range(self._read_int32())]
traceback = list(reversed(symbolizer(backtrace))) + \ traceback = list(reversed(symbolizer(backtrace))) + \
[(filename, line, column, *demangler([function]), None)] [(filename, line, column, *demangler([function]), None)]
core_exn = exceptions.CoreException(name, message, params, traceback) core_exn = exceptions.CoreException(name, message, params, traceback)
if core_exn.id == 0: if core_exn.id == 0: