From d6ab567242c5e921feeb9e50c827c67924a69e8b Mon Sep 17 00:00:00 2001 From: whitequark Date: Fri, 7 Aug 2015 16:15:44 +0300 Subject: [PATCH] coredevice.comm_*: refactor. --- artiq/coredevice/comm_dummy.py | 2 +- artiq/coredevice/comm_generic.py | 314 ++++++++++++++++++++----------- artiq/coredevice/comm_serial.py | 1 + artiq/coredevice/comm_tcp.py | 1 + 4 files changed, 204 insertions(+), 114 deletions(-) diff --git a/artiq/coredevice/comm_dummy.py b/artiq/coredevice/comm_dummy.py index 5b0c35c46..072c9bdd8 100644 --- a/artiq/coredevice/comm_dummy.py +++ b/artiq/coredevice/comm_dummy.py @@ -3,7 +3,7 @@ from operator import itemgetter class Comm: def __init__(self, dmgr): - pass + super().__init__() def switch_clock(self, external): pass diff --git a/artiq/coredevice/comm_generic.py b/artiq/coredevice/comm_generic.py index da5d91017..b8ac7fa07 100644 --- a/artiq/coredevice/comm_generic.py +++ b/artiq/coredevice/comm_generic.py @@ -50,7 +50,11 @@ class UnsupportedDevice(Exception): class CommGeneric: - # methods for derived classes to implement + def __init__(self): + self._read_type = self._write_type = None + self._read_length = 0 + self._write_buffer = [] + def open(self): """Opens the communication channel. Must do nothing if already opened.""" @@ -70,167 +74,251 @@ class CommGeneric: """Writes exactly length bytes to the communication channel. The channel is assumed to be opened.""" raise NotImplementedError + + # + # Reader interface # def _read_header(self): self.open() + if self._read_length > 0: + raise IOError("Read underrun ({} bytes remaining)". + format(self._read_length)) + + # Wait for a synchronization sequence, 5a 5a 5a 5a. sync_count = 0 while sync_count < 4: - (c, ) = struct.unpack("B", self.read(1)) - if c == 0x5a: + (sync_byte, ) = struct.unpack("B", self.read(1)) + if sync_byte == 0x5a: sync_count += 1 else: sync_count = 0 - length = struct.unpack(">l", self.read(4))[0] - if not length: # inband connection close - raise OSError("Connection closed") - tyv = struct.unpack("B", self.read(1))[0] - ty = _D2HMsgType(tyv) - logger.debug("receiving message: type=%r length=%d", ty, length) - return length, ty - def _write_header(self, length, ty): + # Read message header. + (self._read_length, ) = struct.unpack(">l", self.read(4)) + if not self._read_length: # inband connection close + raise OSError("Connection closed") + + (raw_type, ) = struct.unpack("B", self.read(1)) + self._read_type = _D2HMsgType(raw_type) + + if self._read_length < 9: + raise IOError("Read overrun in message header ({} remaining)". + format(self._read_length)) + self._read_length -= 9 + + logger.debug("receiving message: type=%r length=%d", + self._read_type, self._read_length) + + def _read_expect(self, ty): + if self._read_type != ty: + raise IOError("Incorrect reply from device: {} (expected {})". + format(self._read_type, ty)) + + def _read_empty(self, ty): + self._read_header() + self._read_expect(ty) + + def _read_chunk(self, length): + if self._read_length < length: + raise IOError("Read overrun while trying to read {} bytes ({} remaining)". + format(length, self._read_length)) + + self._read_length -= length + return self.read(length) + + def _read_int8(self): + (value, ) = struct.unpack("B", self._read_chunk(1)) + return value + + def _read_int32(self): + (value, ) = struct.unpack(">l", self._read_chunk(4)) + return value + + def _read_int64(self): + (value, ) = struct.unpack(">q", self._read_chunk(8)) + return value + + def _read_float64(self): + (value, ) = struct.unpack(">d", self._read_chunk(8)) + return value + + # + # Writer interface + # + + def _write_header(self, ty): self.open() - logger.debug("sending message: type=%r length=%d", ty, length) - self.write(struct.pack(">ll", 0x5a5a5a5a, length)) - if ty is not None: - self.write(struct.pack("B", ty.value)) + + logger.debug("preparing to send message: type=%r", ty) + self._write_type = ty + self._write_buffer = [] + + def _write_flush(self): + # Calculate message size. + length = sum([len(chunk) for chunk in self._write_buffer]) + logger.debug("sending message: type=%r length=%d", self._write_type, length) + + # Write synchronization sequence, header and body. + self.write(struct.pack(">llB", 0x5a5a5a5a, + 9 + length, self._write_type.value)) + for chunk in self._write_buffer: + self.write(chunk) + + def _write_empty(self, ty): + self._write_header(ty) + self._write_flush() + + def _write_int8(self, value): + self._write_buffer.append(struct.pack("B", value)) + + def _write_int32(self, value): + self._write_buffer.append(struct.pack(">l", value)) + + def _write_int64(self, value): + self._write_buffer.append(struct.pack(">q", value)) + + def _write_float64(self, value): + self._write_buffer.append(struct.pack(">d", value)) + + def _write_string(self, value): + self._write_buffer.append(value) + + # + # Exported APIs + # def reset_session(self): - self._write_header(0, None) + self.write(struct.pack(">ll", 0x5a5a5a5a, 0)) def check_ident(self): - self._write_header(9, _H2DMsgType.IDENT_REQUEST) - _, ty = self._read_header() - if ty != _D2HMsgType.IDENT_REPLY: - raise IOError("Incorrect reply from device: {}".format(ty)) - (reply, ) = struct.unpack("B", self.read(1)) - runtime_id = chr(reply) - for i in range(3): - (reply, ) = struct.unpack("B", self.read(1)) - runtime_id += chr(reply) - if runtime_id != "AROR": + self._write_empty(_H2DMsgType.IDENT_REQUEST) + + self._read_header() + self._read_expect(_D2HMsgType.IDENT_REPLY) + runtime_id = self._read_chunk(4) + if runtime_id != b"AROR": raise UnsupportedDevice("Unsupported runtime ID: {}" .format(runtime_id)) def switch_clock(self, external): - self._write_header(10, _H2DMsgType.SWITCH_CLOCK) - self.write(struct.pack("B", int(external))) - _, ty = self._read_header() - if ty != _D2HMsgType.CLOCK_SWITCH_COMPLETED: - raise IOError("Incorrect reply from device: {}".format(ty)) + self._write_header(_H2DMsgType.SWITCH_CLOCK) + self._write_int8(external) + self._write_flush() - def load(self, kcode): - self._write_header(len(kcode) + 9, _H2DMsgType.LOAD_LIBRARY) - self.write(kcode) - _, ty = self._read_header() - if ty != _D2HMsgType.LOAD_COMPLETED: - raise IOError("Incorrect reply from device: "+str(ty)) + self._read_empty(_D2HMsgType.CLOCK_SWITCH_COMPLETED) + + def load(self, kernel_library): + self._write_header(_H2DMsgType.LOAD_LIBRARY) + self._write_string(kernel_library) + self._write_flush() + + self._read_empty(_D2HMsgType.LOAD_COMPLETED) def run(self): - self._write_header(9, _H2DMsgType.RUN_KERNEL) + self._write_empty(_H2DMsgType.RUN_KERNEL) logger.debug("running kernel") def flash_storage_read(self, key): - self._write_header(9+len(key), _H2DMsgType.FLASH_READ_REQUEST) - self.write(key) - length, ty = self._read_header() - if ty != _D2HMsgType.FLASH_READ_REPLY: - raise IOError("Incorrect reply from device: {}".format(ty)) - value = self.read(length - 9) - return value + self._write_header(_H2DMsgType.FLASH_READ_REQUEST) + self._write_string(key) + self._write_flush() + + self._read_header() + self._read_expect(_D2HMsgType.FLASH_READ_REPLY) + return self._read_chunk(self._read_length) def flash_storage_write(self, key, value): - self._write_header(9+len(key)+1+len(value), - _H2DMsgType.FLASH_WRITE_REQUEST) - self.write(key) - self.write(b"\x00") - self.write(value) - _, ty = self._read_header() - if ty != _D2HMsgType.FLASH_OK_REPLY: - if ty == _D2HMsgType.FLASH_ERROR_REPLY: - raise IOError("Flash storage is full") - else: - raise IOError("Incorrect reply from device: {}".format(ty)) + self._write_header(_H2DMsgType.FLASH_WRITE_REQUEST) + self._write_string(key) + self._write_string(b"\x00") + self._write_string(value) + self._write_flush() + + self._read_header() + if self._read_type == _D2HMsgType.FLASH_ERROR_REPLY: + raise IOError("Flash storage is full") + else: + self._read_expect(_D2HMsgType.FLASH_OK_REPLY) def flash_storage_erase(self): - self._write_header(9, _H2DMsgType.FLASH_ERASE_REQUEST) - _, ty = self._read_header() - if ty != _D2HMsgType.FLASH_OK_REPLY: - raise IOError("Incorrect reply from device: {}".format(ty)) + self._write_empty(_H2DMsgType.FLASH_ERASE_REQUEST) + + self._read_empty(_D2HMsgType.FLASH_OK_REPLY) def flash_storage_remove(self, key): - self._write_header(9+len(key), _H2DMsgType.FLASH_REMOVE_REQUEST) - self.write(key) - _, ty = self._read_header() - if ty != _D2HMsgType.FLASH_OK_REPLY: - raise IOError("Incorrect reply from device: {}".format(ty)) + self._write_header(_H2DMsgType.FLASH_REMOVE_REQUEST) + self._write_string(key) + self._write_flush() - def _receive_rpc_value(self, type_tag): - if type_tag == "n": + self._read_empty(_D2HMsgType.FLASH_OK_REPLY) + + def _receive_rpc_value(self, tag): + if tag == "n": return None - if type_tag == "b": - return bool(struct.unpack("B", self.read(1))[0]) - if type_tag == "i": - return struct.unpack(">l", self.read(4))[0] - if type_tag == "I": - return struct.unpack(">q", self.read(8))[0] - if type_tag == "f": - return struct.unpack(">d", self.read(8))[0] - if type_tag == "F": - n, d = struct.unpack(">qq", self.read(16)) - return Fraction(n, d) + elif tag == "b": + return bool(self._read_int8()) + elif tag == "i": + return self._read_int32() + elif tag == "I": + return self._read_int64() + elif tag == "f": + return self._read_float64() + elif tag == "F": + numerator = self._read_int64() + denominator = self._read_int64() + return Fraction(numerator, denominator) + elif tag == "l": + elt_tag = chr(self._read_int8()) + length = self._read_int32() + return [self._receive_rpc_value(elt_tag) for _ in range(length)] + else: + raise IOError("Unknown RPC value tag: {}", tag) def _receive_rpc_values(self): - r = [] + result = [] while True: - type_tag = chr(struct.unpack("B", self.read(1))[0]) - if type_tag == "\x00": - return r - elif type_tag == "l": - elt_type_tag = chr(struct.unpack("B", self.read(1))[0]) - length = struct.unpack(">l", self.read(4))[0] - r.append([self._receive_rpc_value(elt_type_tag) - for i in range(length)]) + tag = chr(self._read_int8()) + if tag == "\x00": + return result else: - r.append(self._receive_rpc_value(type_tag)) + result.append(self._receive_rpc_value(tag)) def _serve_rpc(self, rpc_map): - rpc_num = struct.unpack(">l", self.read(4))[0] + service = self._read_int32() args = self._receive_rpc_values() - logger.debug("rpc service: %d %r", rpc_num, args) - eid, r = rpc_wrapper.run_rpc(rpc_map[rpc_num], args) - self._write_header(9+2*4, _H2DMsgType.RPC_REPLY) - self.write(struct.pack(">ll", eid, r)) - logger.debug("rpc service: %d %r == %r (eid %d)", rpc_num, args, - r, eid) + logger.debug("rpc service: %d %r", service, args) + + eid, result = rpc_wrapper.run_rpc(rpc_map[rpc_num], args) + logger.debug("rpc service: %d %r == %r (eid %d)", service, args, + result, eid) + + self._write_header(_H2DMsgType.RPC_REPLY) + self._write_int32(eid) + self._write_int32(result) + self._write_flush() def _serve_exception(self): - eid, p0, p1, p2 = struct.unpack(">lqqq", self.read(4+3*8)) + eid = self._read_int32() + params = [self._read_int64() for _ in range(3)] rpc_wrapper.filter_rpc_exception(eid) - raise exception(self.core, p0, p1, p2) + raise exception(self.core, *params) def serve(self, rpc_map): while True: - _, ty = self._read_header() - if ty == _D2HMsgType.RPC_REQUEST: + self._read_header() + if self._read_type == _D2HMsgType.RPC_REQUEST: self._serve_rpc(rpc_map) - elif ty == _D2HMsgType.KERNEL_EXCEPTION: + elif self._read_type == _D2HMsgType.KERNEL_EXCEPTION: self._serve_exception() - elif ty == _D2HMsgType.KERNEL_FINISHED: - return else: - raise IOError("Incorrect request from device: "+str(ty)) + self._read_expect(_D2HMsgType.KERNEL_FINISHED) + return def get_log(self): - self._write_header(9, _H2DMsgType.LOG_REQUEST) - length, ty = self._read_header() - if ty != _D2HMsgType.LOG_REPLY: - raise IOError("Incorrect request from device: "+str(ty)) - r = "" - for i in range(length - 9): - c = struct.unpack("B", self.read(1))[0] - if c: - r += chr(c) - return r + self._write_empty(_H2DMsgType.LOG_REQUEST) + + self._read_header() + self._read_expect(_D2HMsgType.LOG_REPLY) + return self._read_chunk(self._read_length).decode('utf-8') diff --git a/artiq/coredevice/comm_serial.py b/artiq/coredevice/comm_serial.py index 70218da14..bf710c29a 100644 --- a/artiq/coredevice/comm_serial.py +++ b/artiq/coredevice/comm_serial.py @@ -10,6 +10,7 @@ logger = logging.getLogger(__name__) class Comm(CommGeneric): def __init__(self, dmgr, serial_dev, baud_rate=115200): + super().__init__() self.serial_dev = serial_dev self.baud_rate = baud_rate diff --git a/artiq/coredevice/comm_tcp.py b/artiq/coredevice/comm_tcp.py index eda672750..f5a97658d 100644 --- a/artiq/coredevice/comm_tcp.py +++ b/artiq/coredevice/comm_tcp.py @@ -9,6 +9,7 @@ logger = logging.getLogger(__name__) class Comm(CommGeneric): def __init__(self, dmgr, host, port=1381): + super().__init__() self.host = host self.port = port