From b2572003acfef50b32aa00cd17388f00fed5f87a Mon Sep 17 00:00:00 2001 From: pca006132 Date: Thu, 27 Aug 2020 11:27:40 +0800 Subject: [PATCH] RPC: optimization by caching This reduced the calls needed for socket send/recv. --- artiq/coredevice/comm_kernel.py | 87 +++++++++++++++++++++------------ 1 file changed, 56 insertions(+), 31 deletions(-) diff --git a/artiq/coredevice/comm_kernel.py b/artiq/coredevice/comm_kernel.py index d5096cb69..babf035b3 100644 --- a/artiq/coredevice/comm_kernel.py +++ b/artiq/coredevice/comm_kernel.py @@ -66,16 +66,16 @@ def _receive_list(kernel, embedding_map): tag = chr(kernel._read_int8()) if tag == "b": buffer = kernel._read(length) - return numpy.ndarray((length, ), 'B', buffer).tolist() + return list(buffer) elif tag == "i": buffer = kernel._read(4 * length) - return numpy.ndarray((length, ), '>i4', buffer).tolist() + return list(struct.unpack(">%sl" % length, buffer)) elif tag == "I": buffer = kernel._read(8 * length) - return numpy.ndarray((length, ), '>i8', buffer).tolist() + return list(struct.unpack(">%sq" % length, buffer)) elif tag == "f": buffer = kernel._read(8 * length) - return numpy.ndarray((length, ), '>d', buffer).tolist() + return list(struct.unpack(">%sd" % length, buffer)) else: fn = receivers[tag] elems = [] @@ -178,6 +178,17 @@ class CommKernel: self._read_type = None self.host = host self.port = port + self.read_buffer = bytearray() + self.write_buffer = bytearray() + + self.unpack_int32 = struct.Struct(">l").unpack + self.unpack_int64 = struct.Struct(">q").unpack + self.unpack_float64 = struct.Struct(">d").unpack + + self.pack_header = struct.Struct(">lB").pack + self.pack_int32 = struct.Struct(">l").pack + self.pack_int64 = struct.Struct(">q").pack + self.pack_float64 = struct.Struct(">d").pack def open(self): if hasattr(self, "socket"): @@ -198,13 +209,18 @@ class CommKernel: # def _read(self, length): - r = bytes() - while len(r) < length: - rn = self.socket.recv(min(8192, length - len(r))) - if not rn: - raise ConnectionResetError("Connection closed") - r += rn - return r + # cache the reads to avoid frequent call to recv + while len(self.read_buffer) < length: + # the number is just the maximum amount + # when there is not much data, it would return earlier + diff = length - len(self.read_buffer) + flag = 0 + if diff > 8192: + flag |= socket.MSG_WAITALL + self.read_buffer += self.socket.recv(8192, flag) + result = self.read_buffer[:length] + self.read_buffer = self.read_buffer[length:] + return result def _read_header(self): self.open() @@ -212,14 +228,14 @@ class CommKernel: # Wait for a synchronization sequence, 5a 5a 5a 5a. sync_count = 0 while sync_count < 4: - (sync_byte, ) = struct.unpack("B", self._read(1)) + sync_byte = self._read(1)[0] if sync_byte == 0x5a: sync_count += 1 else: sync_count = 0 # Read message header. - (raw_type, ) = struct.unpack("B", self._read(1)) + raw_type = self._read(1)[0] self._read_type = Reply(raw_type) logger.debug("receiving message: type=%r", @@ -235,19 +251,18 @@ class CommKernel: self._read_expect(ty) def _read_int8(self): - (value, ) = struct.unpack("B", self._read(1)) - return value + return self._read(1)[0] def _read_int32(self): - (value, ) = struct.unpack(">l", self._read(4)) + (value, ) = self.unpack_int32(self._read(4)) return value def _read_int64(self): - (value, ) = struct.unpack(">q", self._read(8)) + (value, ) = self.unpack_int64(self._read(8)) return value def _read_float64(self): - (value, ) = struct.unpack(">d", self._read(8)) + (value, ) = self.unpack_float64(self._read(8)) return value def _read_bool(self): @@ -264,7 +279,15 @@ class CommKernel: # def _write(self, data): - self.socket.sendall(data) + self.write_buffer += data + # if the buffer is already pretty large, send it + # the block size is arbitrary, tuning it may improve performance + if len(self.write_buffer) > 4096: + self._flush() + + def _flush(self): + self.socket.sendall(self.write_buffer) + self.write_buffer.clear() def _write_header(self, ty): self.open() @@ -272,7 +295,7 @@ class CommKernel: logger.debug("sending message: type=%r", ty) # Write synchronization sequence and header. - self._write(struct.pack(">lB", 0x5a5a5a5a, ty.value)) + self._write(self.pack_header(0x5a5a5a5a, ty.value)) def _write_empty(self, ty): self._write_header(ty) @@ -281,19 +304,19 @@ class CommKernel: self._write(chunk) def _write_int8(self, value): - self._write(struct.pack("B", value)) + self._write(value) def _write_int32(self, value): - self._write(struct.pack(">l", value)) + self._write(self.pack_int32(value)) def _write_int64(self, value): - self._write(struct.pack(">q", value)) + self._write(self.pack_int64(value)) def _write_float64(self, value): - self._write(struct.pack(">d", value)) + self._write(self.pack_float64(value)) def _write_bool(self, value): - self._write(struct.pack("B", value)) + self._write(1 if value == True else 0) def _write_bytes(self, value): self._write_int32(len(value)) @@ -308,6 +331,7 @@ class CommKernel: def check_system_info(self): self._write_empty(Request.SystemInfo) + self._flush() self._read_header() self._read_expect(Reply.SystemInfo) @@ -332,6 +356,7 @@ class CommKernel: def load(self, kernel_library): self._write_header(Request.LoadKernel) self._write_bytes(kernel_library) + self._flush() self._read_header() if self._read_type == Reply.LoadFailed: @@ -341,6 +366,7 @@ class CommKernel: def run(self): self._write_empty(Request.RunKernel) + self._flush() logger.debug("running kernel") _rpc_sentinel = object() @@ -441,14 +467,11 @@ class CommKernel: if tag_element == "b": self._write(bytes(value)) elif tag_element == "i": - array = numpy.array(value, '>i4') - self._write(array.tobytes()) + self._write(struct.pack(">%sl" % len(value), *value)) elif tag_element == "I": - array = numpy.array(value, '>i8') - self._write(array.tobytes()) + self._write(struct.pack(">%sq" % len(value), *value)) elif tag_element == "f": - array = numpy.array(value, '>d') - self._write(array.tobytes()) + self._write(struct.pack(">%sd" % len(value), *value)) else: for elt in value: tags_copy = bytearray(tags) @@ -524,6 +547,7 @@ class CommKernel: self._write_bytes(return_tags) self._send_rpc_value(bytearray(return_tags), result, result, service) + self._flush() except RPCReturnValueError as exn: raise except Exception as exn: @@ -569,6 +593,7 @@ class CommKernel: self._write_int32(line) self._write_int32(-1) # column not known self._write_string(function) + self._flush() def _serve_exception(self, embedding_map, symbolizer, demangler): name = self._read_string()