From 7181ff66a6fd7b3aecddf752b6c3eff5fddf5a7d Mon Sep 17 00:00:00 2001 From: pca006132 Date: Tue, 18 Aug 2020 17:01:28 +0800 Subject: [PATCH] compiler: improved rpc performance for list and array 1. Removed duplicated tags before each elements. 2. Use numpy functions to speedup parsing. --- artiq/coredevice/comm_kernel.py | 228 ++++++++++++++++++++++---------- 1 file changed, 158 insertions(+), 70 deletions(-) diff --git a/artiq/coredevice/comm_kernel.py b/artiq/coredevice/comm_kernel.py index 41bddd553..b28f79272 100644 --- a/artiq/coredevice/comm_kernel.py +++ b/artiq/coredevice/comm_kernel.py @@ -43,9 +43,11 @@ class Reply(Enum): class UnsupportedDevice(Exception): pass + class LoadError(Exception): pass + class RPCReturnValueError(ValueError): pass @@ -53,6 +55,105 @@ class RPCReturnValueError(ValueError): RPCKeyword = namedtuple('RPCKeyword', ['name', 'value']) +def _receive_fraction(kernel, embedding_map): + numerator = kernel._read_int64() + denominator = kernel._read_int64() + return Fraction(numerator, denominator) + + +def _receive_list(kernel, embedding_map): + length = kernel._read_int32() + tag = chr(kernel._read_int8()) + if tag == "b": + buffer = kernel._read(length) + return numpy.ndarray((length, ), 'B', buffer).tolist() + elif tag == "i": + buffer = kernel._read(4 * length) + return numpy.ndarray((length, ), '>i4', buffer).tolist() + elif tag == "I": + buffer = kernel._read(8 * length) + return numpy.ndarray((length, ), '>i8', buffer).tolist() + elif tag == "f": + buffer = kernel._read(8 * length) + return numpy.ndarray((length, ), '>d', buffer).tolist() + else: + fn = receivers[tag] + elems = [] + for _ in range(length): + # discard tag, as our device would still send the tag for each + # non-primitive elements. + kernel._read_int8() + item = fn(kernel, embedding_map) + elems.append(item) + return elems + + +def _receive_array(kernel, embedding_map): + num_dims = kernel._read_int8() + shape = tuple(kernel._read_int32() for _ in range(num_dims)) + tag = chr(kernel._read_int8()) + fn = receivers[tag] + length = numpy.prod(shape) + if tag == "b": + buffer = kernel._read(length) + elems = numpy.ndarray((length, ), 'B', buffer) + elif tag == "i": + buffer = kernel._read(4 * length) + elems = numpy.ndarray((length, ), '>i4', buffer) + elif tag == "I": + buffer = kernel._read(8 * length) + elems = numpy.ndarray((length, ), '>i8', buffer) + elif tag == "f": + buffer = kernel._read(8 * length) + elems = numpy.ndarray((length, ), '>d', buffer) + else: + fn = receivers[tag] + elems = [] + for _ in range(numpy.prod(shape)): + # discard the tag + kernel._read_int8() + item = fn(kernel, embedding_map) + elems.append(item) + elems = numpy.array(elems) + return elems.reshape(shape) + + +def _receive_range(kernel, embedding_map): + start = kernel._receive_rpc_value(embedding_map) + stop = kernel._receive_rpc_value(embedding_map) + step = kernel._receive_rpc_value(embedding_map) + return range(start, stop, step) + + +def _receive_keyword(kernel, embedding_map): + name = kernel._read_string() + value = kernel._receive_rpc_value(embedding_map) + return RPCKeyword(name, value) + + +receivers = { + "\x00": lambda kernel, embedding_map: kernel._rpc_sentinel, + "t": lambda kernel, embedding_map: + tuple(kernel._receive_rpc_value(embedding_map) + for _ in range(kernel._read_int8())), + "n": lambda kernel, embedding_map: None, + "b": lambda kernel, embedding_map: bool(kernel._read_int8()), + "i": lambda kernel, embedding_map: numpy.int32(kernel._read_int32()), + "I": lambda kernel, embedding_map: numpy.int32(kernel._read_int64()), + "f": lambda kernel, embedding_map: kernel._read_float64(), + "s": lambda kernel, embedding_map: kernel._read_string(), + "B": lambda kernel, embedding_map: kernel._read_bytes(), + "A": lambda kernel, embedding_map: kernel._read_bytes(), + "O": lambda kernel, embedding_map: + embedding_map.retrieve_object(kernel._read_int32()), + "F": _receive_fraction, + "l": _receive_list, + "a": _receive_array, + "r": _receive_range, + "k": _receive_keyword +} + + class CommKernelDummy: def __init__(self): pass @@ -247,50 +348,8 @@ class CommKernel: # See rpc_proto.rs and compiler/ir.py:rpc_tag. def _receive_rpc_value(self, embedding_map): tag = chr(self._read_int8()) - if tag == "\x00": - return self._rpc_sentinel - elif tag == "t": - length = self._read_int8() - return tuple(self._receive_rpc_value(embedding_map) for _ in range(length)) - elif tag == "n": - return None - elif tag == "b": - return bool(self._read_int8()) - elif tag == "i": - return numpy.int32(self._read_int32()) - elif tag == "I": - return numpy.int64(self._read_int64()) - elif tag == "f": - return self._read_float64() - elif tag == "F": - numerator = self._read_int64() - denominator = self._read_int64() - return Fraction(numerator, denominator) - elif tag == "s": - return self._read_string() - elif tag == "B": - return self._read_bytes() - elif tag == "A": - return self._read_bytes() - elif tag == "l": - length = self._read_int32() - return [self._receive_rpc_value(embedding_map) for _ in range(length)] - elif tag == "a": - num_dims = self._read_int8() - shape = tuple(self._read_int32() for _ in range(num_dims)) - elems = [self._receive_rpc_value(embedding_map) for _ in range(numpy.prod(shape))] - return numpy.array(elems).reshape(shape) - elif tag == "r": - start = self._receive_rpc_value(embedding_map) - stop = self._receive_rpc_value(embedding_map) - step = self._receive_rpc_value(embedding_map) - return range(start, stop, step) - elif tag == "k": - name = self._read_string() - value = self._receive_rpc_value(embedding_map) - return RPCKeyword(name, value) - elif tag == "O": - return embedding_map.retrieve_object(self._read_int32()) + if tag in receivers: + return receivers.get(tag)(self, embedding_map) else: raise IOError("Unknown RPC value tag: {}".format(repr(tag))) @@ -357,8 +416,8 @@ class CommKernel: self._write_float64(value) elif tag == "F": check(isinstance(value, Fraction) and - (-2**63 < value.numerator < 2**63-1) and - (-2**63 < value.denominator < 2**63-1), + (-2**63 < value.numerator < 2**63-1) and + (-2**63 < value.denominator < 2**63-1), lambda: "64-bit Fraction") self._write_int64(value.numerator) self._write_int64(value.denominator) @@ -378,21 +437,47 @@ class CommKernel: check(isinstance(value, list), lambda: "list") self._write_int32(len(value)) - for elt in value: - tags_copy = bytearray(tags) - self._send_rpc_value(tags_copy, elt, root, function) + tag_element = chr(tags[0]) + if tag_element == "b": + self._write(bytes(value)) + elif tag_element == "i": + array = numpy.array(value, '>i4') + self._write(array.tobytes()) + elif tag_element == "I": + array = numpy.array(value, '>i8') + self._write(array.tobytes()) + elif tag_element == "f": + array = numpy.array(value, '>d') + self._write(array.tobytes()) + else: + for elt in value: + tags_copy = bytearray(tags) + self._send_rpc_value(tags_copy, elt, root, function) self._skip_rpc_value(tags) elif tag == "a": check(isinstance(value, numpy.ndarray), lambda: "numpy.ndarray") num_dims = tags.pop(0) check(num_dims == len(value.shape), - lambda: "{}-dimensional numpy.ndarray".format(num_dims)) + lambda: "{}-dimensional numpy.ndarray".format(num_dims)) for s in value.shape: self._write_int32(s) - for elt in value.reshape((-1,), order="C"): - tags_copy = bytearray(tags) - self._send_rpc_value(tags_copy, elt, root, function) + tag_element = chr(tags[0]) + if tag_element == "b": + self._write(value.reshape((-1,), order="C").tobytes()) + elif tag_element == "i": + array = value.reshape((-1,), order="C").astype('>i4') + self._write(array.tobytes()) + elif tag_element == "I": + array = value.reshape((-1,), order="C").astype('>i8') + self._write(array.tobytes()) + elif tag_element == "f": + array = value.reshape((-1,), order="C").astype('>d') + self._write(array.tobytes()) + else: + for elt in value.reshape((-1,), order="C"): + tags_copy = bytearray(tags) + self._send_rpc_value(tags_copy, elt, root, function) self._skip_rpc_value(tags) elif tag == "r": check(isinstance(value, range), @@ -414,15 +499,15 @@ class CommKernel: return msg def _serve_rpc(self, embedding_map): - is_async = self._read_bool() - service_id = self._read_int32() + is_async = self._read_bool() + service_id = self._read_int32() args, kwargs = self._receive_rpc_args(embedding_map) - return_tags = self._read_bytes() + return_tags = self._read_bytes() if service_id == 0: - service = lambda obj, attr, value: setattr(obj, attr, value) + def service(obj, attr, value): return setattr(obj, attr, value) else: - service = embedding_map.retrieve_object(service_id) + service = embedding_map.retrieve_object(service_id) logger.debug("rpc service: [%d]%r%s %r %r -> %s", service_id, service, (" (async)" if is_async else ""), args, kwargs, return_tags) @@ -432,15 +517,18 @@ class CommKernel: try: result = service(*args, **kwargs) - logger.debug("rpc service: %d %r %r = %r", service_id, args, kwargs, result) + logger.debug("rpc service: %d %r %r = %r", + service_id, args, kwargs, result) self._write_header(Request.RPCReply) self._write_bytes(return_tags) - self._send_rpc_value(bytearray(return_tags), result, result, service) + self._send_rpc_value(bytearray(return_tags), + result, result, service) except RPCReturnValueError as exn: raise except Exception as exn: - logger.debug("rpc service: %d %r %r ! %r", service_id, args, kwargs, exn) + logger.debug("rpc service: %d %r %r ! %r", + service_id, args, kwargs, exn) self._write_header(Request.RPCException) @@ -479,23 +567,23 @@ class CommKernel: assert False self._write_string(filename) self._write_int32(line) - self._write_int32(-1) # column not known + self._write_int32(-1) # column not known self._write_string(function) def _serve_exception(self, embedding_map, symbolizer, demangler): - name = self._read_string() - message = self._read_string() - params = [self._read_int64() for _ in range(3)] + name = self._read_string() + message = self._read_string() + params = [self._read_int64() for _ in range(3)] - filename = self._read_string() - line = self._read_int32() - column = self._read_int32() - function = self._read_string() + filename = self._read_string() + line = self._read_int32() + column = self._read_int32() + function = self._read_string() backtrace = [self._read_int32() for _ in range(self._read_int32())] traceback = list(reversed(symbolizer(backtrace))) + \ - [(filename, line, column, *demangler([function]), None)] + [(filename, line, column, *demangler([function]), None)] core_exn = exceptions.CoreException(name, message, params, traceback) if core_exn.id == 0: