forked from M-Labs/artiq
RPC: optimization by caching
This reduced the calls needed for socket send/recv.
This commit is contained in:
parent
69f0699ebd
commit
b2572003ac
@ -66,16 +66,16 @@ def _receive_list(kernel, embedding_map):
|
|||||||
tag = chr(kernel._read_int8())
|
tag = chr(kernel._read_int8())
|
||||||
if tag == "b":
|
if tag == "b":
|
||||||
buffer = kernel._read(length)
|
buffer = kernel._read(length)
|
||||||
return numpy.ndarray((length, ), 'B', buffer).tolist()
|
return list(buffer)
|
||||||
elif tag == "i":
|
elif tag == "i":
|
||||||
buffer = kernel._read(4 * length)
|
buffer = kernel._read(4 * length)
|
||||||
return numpy.ndarray((length, ), '>i4', buffer).tolist()
|
return list(struct.unpack(">%sl" % length, buffer))
|
||||||
elif tag == "I":
|
elif tag == "I":
|
||||||
buffer = kernel._read(8 * length)
|
buffer = kernel._read(8 * length)
|
||||||
return numpy.ndarray((length, ), '>i8', buffer).tolist()
|
return list(struct.unpack(">%sq" % length, buffer))
|
||||||
elif tag == "f":
|
elif tag == "f":
|
||||||
buffer = kernel._read(8 * length)
|
buffer = kernel._read(8 * length)
|
||||||
return numpy.ndarray((length, ), '>d', buffer).tolist()
|
return list(struct.unpack(">%sd" % length, buffer))
|
||||||
else:
|
else:
|
||||||
fn = receivers[tag]
|
fn = receivers[tag]
|
||||||
elems = []
|
elems = []
|
||||||
@ -178,6 +178,17 @@ class CommKernel:
|
|||||||
self._read_type = None
|
self._read_type = None
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
|
self.read_buffer = bytearray()
|
||||||
|
self.write_buffer = bytearray()
|
||||||
|
|
||||||
|
self.unpack_int32 = struct.Struct(">l").unpack
|
||||||
|
self.unpack_int64 = struct.Struct(">q").unpack
|
||||||
|
self.unpack_float64 = struct.Struct(">d").unpack
|
||||||
|
|
||||||
|
self.pack_header = struct.Struct(">lB").pack
|
||||||
|
self.pack_int32 = struct.Struct(">l").pack
|
||||||
|
self.pack_int64 = struct.Struct(">q").pack
|
||||||
|
self.pack_float64 = struct.Struct(">d").pack
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
if hasattr(self, "socket"):
|
if hasattr(self, "socket"):
|
||||||
@ -198,13 +209,18 @@ class CommKernel:
|
|||||||
#
|
#
|
||||||
|
|
||||||
def _read(self, length):
|
def _read(self, length):
|
||||||
r = bytes()
|
# cache the reads to avoid frequent call to recv
|
||||||
while len(r) < length:
|
while len(self.read_buffer) < length:
|
||||||
rn = self.socket.recv(min(8192, length - len(r)))
|
# the number is just the maximum amount
|
||||||
if not rn:
|
# when there is not much data, it would return earlier
|
||||||
raise ConnectionResetError("Connection closed")
|
diff = length - len(self.read_buffer)
|
||||||
r += rn
|
flag = 0
|
||||||
return r
|
if diff > 8192:
|
||||||
|
flag |= socket.MSG_WAITALL
|
||||||
|
self.read_buffer += self.socket.recv(8192, flag)
|
||||||
|
result = self.read_buffer[:length]
|
||||||
|
self.read_buffer = self.read_buffer[length:]
|
||||||
|
return result
|
||||||
|
|
||||||
def _read_header(self):
|
def _read_header(self):
|
||||||
self.open()
|
self.open()
|
||||||
@ -212,14 +228,14 @@ class CommKernel:
|
|||||||
# Wait for a synchronization sequence, 5a 5a 5a 5a.
|
# Wait for a synchronization sequence, 5a 5a 5a 5a.
|
||||||
sync_count = 0
|
sync_count = 0
|
||||||
while sync_count < 4:
|
while sync_count < 4:
|
||||||
(sync_byte, ) = struct.unpack("B", self._read(1))
|
sync_byte = self._read(1)[0]
|
||||||
if sync_byte == 0x5a:
|
if sync_byte == 0x5a:
|
||||||
sync_count += 1
|
sync_count += 1
|
||||||
else:
|
else:
|
||||||
sync_count = 0
|
sync_count = 0
|
||||||
|
|
||||||
# Read message header.
|
# Read message header.
|
||||||
(raw_type, ) = struct.unpack("B", self._read(1))
|
raw_type = self._read(1)[0]
|
||||||
self._read_type = Reply(raw_type)
|
self._read_type = Reply(raw_type)
|
||||||
|
|
||||||
logger.debug("receiving message: type=%r",
|
logger.debug("receiving message: type=%r",
|
||||||
@ -235,19 +251,18 @@ class CommKernel:
|
|||||||
self._read_expect(ty)
|
self._read_expect(ty)
|
||||||
|
|
||||||
def _read_int8(self):
|
def _read_int8(self):
|
||||||
(value, ) = struct.unpack("B", self._read(1))
|
return self._read(1)[0]
|
||||||
return value
|
|
||||||
|
|
||||||
def _read_int32(self):
|
def _read_int32(self):
|
||||||
(value, ) = struct.unpack(">l", self._read(4))
|
(value, ) = self.unpack_int32(self._read(4))
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def _read_int64(self):
|
def _read_int64(self):
|
||||||
(value, ) = struct.unpack(">q", self._read(8))
|
(value, ) = self.unpack_int64(self._read(8))
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def _read_float64(self):
|
def _read_float64(self):
|
||||||
(value, ) = struct.unpack(">d", self._read(8))
|
(value, ) = self.unpack_float64(self._read(8))
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def _read_bool(self):
|
def _read_bool(self):
|
||||||
@ -264,7 +279,15 @@ class CommKernel:
|
|||||||
#
|
#
|
||||||
|
|
||||||
def _write(self, data):
|
def _write(self, data):
|
||||||
self.socket.sendall(data)
|
self.write_buffer += data
|
||||||
|
# if the buffer is already pretty large, send it
|
||||||
|
# the block size is arbitrary, tuning it may improve performance
|
||||||
|
if len(self.write_buffer) > 4096:
|
||||||
|
self._flush()
|
||||||
|
|
||||||
|
def _flush(self):
|
||||||
|
self.socket.sendall(self.write_buffer)
|
||||||
|
self.write_buffer.clear()
|
||||||
|
|
||||||
def _write_header(self, ty):
|
def _write_header(self, ty):
|
||||||
self.open()
|
self.open()
|
||||||
@ -272,7 +295,7 @@ class CommKernel:
|
|||||||
logger.debug("sending message: type=%r", ty)
|
logger.debug("sending message: type=%r", ty)
|
||||||
|
|
||||||
# Write synchronization sequence and header.
|
# Write synchronization sequence and header.
|
||||||
self._write(struct.pack(">lB", 0x5a5a5a5a, ty.value))
|
self._write(self.pack_header(0x5a5a5a5a, ty.value))
|
||||||
|
|
||||||
def _write_empty(self, ty):
|
def _write_empty(self, ty):
|
||||||
self._write_header(ty)
|
self._write_header(ty)
|
||||||
@ -281,19 +304,19 @@ class CommKernel:
|
|||||||
self._write(chunk)
|
self._write(chunk)
|
||||||
|
|
||||||
def _write_int8(self, value):
|
def _write_int8(self, value):
|
||||||
self._write(struct.pack("B", value))
|
self._write(value)
|
||||||
|
|
||||||
def _write_int32(self, value):
|
def _write_int32(self, value):
|
||||||
self._write(struct.pack(">l", value))
|
self._write(self.pack_int32(value))
|
||||||
|
|
||||||
def _write_int64(self, value):
|
def _write_int64(self, value):
|
||||||
self._write(struct.pack(">q", value))
|
self._write(self.pack_int64(value))
|
||||||
|
|
||||||
def _write_float64(self, value):
|
def _write_float64(self, value):
|
||||||
self._write(struct.pack(">d", value))
|
self._write(self.pack_float64(value))
|
||||||
|
|
||||||
def _write_bool(self, value):
|
def _write_bool(self, value):
|
||||||
self._write(struct.pack("B", value))
|
self._write(1 if value == True else 0)
|
||||||
|
|
||||||
def _write_bytes(self, value):
|
def _write_bytes(self, value):
|
||||||
self._write_int32(len(value))
|
self._write_int32(len(value))
|
||||||
@ -308,6 +331,7 @@ class CommKernel:
|
|||||||
|
|
||||||
def check_system_info(self):
|
def check_system_info(self):
|
||||||
self._write_empty(Request.SystemInfo)
|
self._write_empty(Request.SystemInfo)
|
||||||
|
self._flush()
|
||||||
|
|
||||||
self._read_header()
|
self._read_header()
|
||||||
self._read_expect(Reply.SystemInfo)
|
self._read_expect(Reply.SystemInfo)
|
||||||
@ -332,6 +356,7 @@ class CommKernel:
|
|||||||
def load(self, kernel_library):
|
def load(self, kernel_library):
|
||||||
self._write_header(Request.LoadKernel)
|
self._write_header(Request.LoadKernel)
|
||||||
self._write_bytes(kernel_library)
|
self._write_bytes(kernel_library)
|
||||||
|
self._flush()
|
||||||
|
|
||||||
self._read_header()
|
self._read_header()
|
||||||
if self._read_type == Reply.LoadFailed:
|
if self._read_type == Reply.LoadFailed:
|
||||||
@ -341,6 +366,7 @@ class CommKernel:
|
|||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
self._write_empty(Request.RunKernel)
|
self._write_empty(Request.RunKernel)
|
||||||
|
self._flush()
|
||||||
logger.debug("running kernel")
|
logger.debug("running kernel")
|
||||||
|
|
||||||
_rpc_sentinel = object()
|
_rpc_sentinel = object()
|
||||||
@ -441,14 +467,11 @@ class CommKernel:
|
|||||||
if tag_element == "b":
|
if tag_element == "b":
|
||||||
self._write(bytes(value))
|
self._write(bytes(value))
|
||||||
elif tag_element == "i":
|
elif tag_element == "i":
|
||||||
array = numpy.array(value, '>i4')
|
self._write(struct.pack(">%sl" % len(value), *value))
|
||||||
self._write(array.tobytes())
|
|
||||||
elif tag_element == "I":
|
elif tag_element == "I":
|
||||||
array = numpy.array(value, '>i8')
|
self._write(struct.pack(">%sq" % len(value), *value))
|
||||||
self._write(array.tobytes())
|
|
||||||
elif tag_element == "f":
|
elif tag_element == "f":
|
||||||
array = numpy.array(value, '>d')
|
self._write(struct.pack(">%sd" % len(value), *value))
|
||||||
self._write(array.tobytes())
|
|
||||||
else:
|
else:
|
||||||
for elt in value:
|
for elt in value:
|
||||||
tags_copy = bytearray(tags)
|
tags_copy = bytearray(tags)
|
||||||
@ -524,6 +547,7 @@ class CommKernel:
|
|||||||
self._write_bytes(return_tags)
|
self._write_bytes(return_tags)
|
||||||
self._send_rpc_value(bytearray(return_tags),
|
self._send_rpc_value(bytearray(return_tags),
|
||||||
result, result, service)
|
result, result, service)
|
||||||
|
self._flush()
|
||||||
except RPCReturnValueError as exn:
|
except RPCReturnValueError as exn:
|
||||||
raise
|
raise
|
||||||
except Exception as exn:
|
except Exception as exn:
|
||||||
@ -569,6 +593,7 @@ class CommKernel:
|
|||||||
self._write_int32(line)
|
self._write_int32(line)
|
||||||
self._write_int32(-1) # column not known
|
self._write_int32(-1) # column not known
|
||||||
self._write_string(function)
|
self._write_string(function)
|
||||||
|
self._flush()
|
||||||
|
|
||||||
def _serve_exception(self, embedding_map, symbolizer, demangler):
|
def _serve_exception(self, embedding_map, symbolizer, demangler):
|
||||||
name = self._read_string()
|
name = self._read_string()
|
||||||
|
Loading…
Reference in New Issue
Block a user