forked from M-Labs/artiq
1
0
Fork 0

coredevice.comm_*: refactor.

This commit is contained in:
whitequark 2015-08-07 16:15:44 +03:00
parent acc97a74f0
commit d6ab567242
4 changed files with 204 additions and 114 deletions

View File

@ -3,7 +3,7 @@ from operator import itemgetter
class Comm: class Comm:
def __init__(self, dmgr): def __init__(self, dmgr):
pass super().__init__()
def switch_clock(self, external): def switch_clock(self, external):
pass pass

View File

@ -50,7 +50,11 @@ class UnsupportedDevice(Exception):
class CommGeneric: class CommGeneric:
# methods for derived classes to implement def __init__(self):
self._read_type = self._write_type = None
self._read_length = 0
self._write_buffer = []
def open(self): def open(self):
"""Opens the communication channel. """Opens the communication channel.
Must do nothing if already opened.""" Must do nothing if already opened."""
@ -70,167 +74,251 @@ class CommGeneric:
"""Writes exactly length bytes to the communication channel. """Writes exactly length bytes to the communication channel.
The channel is assumed to be opened.""" The channel is assumed to be opened."""
raise NotImplementedError raise NotImplementedError
#
# Reader interface
# #
def _read_header(self): def _read_header(self):
self.open() self.open()
if self._read_length > 0:
raise IOError("Read underrun ({} bytes remaining)".
format(self._read_length))
# Wait for a synchronization sequence, 5a 5a 5a 5a.
sync_count = 0 sync_count = 0
while sync_count < 4: while sync_count < 4:
(c, ) = struct.unpack("B", self.read(1)) (sync_byte, ) = struct.unpack("B", self.read(1))
if c == 0x5a: if sync_byte == 0x5a:
sync_count += 1 sync_count += 1
else: else:
sync_count = 0 sync_count = 0
length = struct.unpack(">l", self.read(4))[0]
if not length: # inband connection close
raise OSError("Connection closed")
tyv = struct.unpack("B", self.read(1))[0]
ty = _D2HMsgType(tyv)
logger.debug("receiving message: type=%r length=%d", ty, length)
return length, ty
def _write_header(self, length, ty): # Read message header.
(self._read_length, ) = struct.unpack(">l", self.read(4))
if not self._read_length: # inband connection close
raise OSError("Connection closed")
(raw_type, ) = struct.unpack("B", self.read(1))
self._read_type = _D2HMsgType(raw_type)
if self._read_length < 9:
raise IOError("Read overrun in message header ({} remaining)".
format(self._read_length))
self._read_length -= 9
logger.debug("receiving message: type=%r length=%d",
self._read_type, self._read_length)
def _read_expect(self, ty):
if self._read_type != ty:
raise IOError("Incorrect reply from device: {} (expected {})".
format(self._read_type, ty))
def _read_empty(self, ty):
self._read_header()
self._read_expect(ty)
def _read_chunk(self, length):
if self._read_length < length:
raise IOError("Read overrun while trying to read {} bytes ({} remaining)".
format(length, self._read_length))
self._read_length -= length
return self.read(length)
def _read_int8(self):
(value, ) = struct.unpack("B", self._read_chunk(1))
return value
def _read_int32(self):
(value, ) = struct.unpack(">l", self._read_chunk(4))
return value
def _read_int64(self):
(value, ) = struct.unpack(">q", self._read_chunk(8))
return value
def _read_float64(self):
(value, ) = struct.unpack(">d", self._read_chunk(8))
return value
#
# Writer interface
#
def _write_header(self, ty):
self.open() self.open()
logger.debug("sending message: type=%r length=%d", ty, length)
self.write(struct.pack(">ll", 0x5a5a5a5a, length)) logger.debug("preparing to send message: type=%r", ty)
if ty is not None: self._write_type = ty
self.write(struct.pack("B", ty.value)) self._write_buffer = []
def _write_flush(self):
# Calculate message size.
length = sum([len(chunk) for chunk in self._write_buffer])
logger.debug("sending message: type=%r length=%d", self._write_type, length)
# Write synchronization sequence, header and body.
self.write(struct.pack(">llB", 0x5a5a5a5a,
9 + length, self._write_type.value))
for chunk in self._write_buffer:
self.write(chunk)
def _write_empty(self, ty):
self._write_header(ty)
self._write_flush()
def _write_int8(self, value):
self._write_buffer.append(struct.pack("B", value))
def _write_int32(self, value):
self._write_buffer.append(struct.pack(">l", value))
def _write_int64(self, value):
self._write_buffer.append(struct.pack(">q", value))
def _write_float64(self, value):
self._write_buffer.append(struct.pack(">d", value))
def _write_string(self, value):
self._write_buffer.append(value)
#
# Exported APIs
#
def reset_session(self): def reset_session(self):
self._write_header(0, None) self.write(struct.pack(">ll", 0x5a5a5a5a, 0))
def check_ident(self): def check_ident(self):
self._write_header(9, _H2DMsgType.IDENT_REQUEST) self._write_empty(_H2DMsgType.IDENT_REQUEST)
_, ty = self._read_header()
if ty != _D2HMsgType.IDENT_REPLY: self._read_header()
raise IOError("Incorrect reply from device: {}".format(ty)) self._read_expect(_D2HMsgType.IDENT_REPLY)
(reply, ) = struct.unpack("B", self.read(1)) runtime_id = self._read_chunk(4)
runtime_id = chr(reply) if runtime_id != b"AROR":
for i in range(3):
(reply, ) = struct.unpack("B", self.read(1))
runtime_id += chr(reply)
if runtime_id != "AROR":
raise UnsupportedDevice("Unsupported runtime ID: {}" raise UnsupportedDevice("Unsupported runtime ID: {}"
.format(runtime_id)) .format(runtime_id))
def switch_clock(self, external): def switch_clock(self, external):
self._write_header(10, _H2DMsgType.SWITCH_CLOCK) self._write_header(_H2DMsgType.SWITCH_CLOCK)
self.write(struct.pack("B", int(external))) self._write_int8(external)
_, ty = self._read_header() self._write_flush()
if ty != _D2HMsgType.CLOCK_SWITCH_COMPLETED:
raise IOError("Incorrect reply from device: {}".format(ty))
def load(self, kcode): self._read_empty(_D2HMsgType.CLOCK_SWITCH_COMPLETED)
self._write_header(len(kcode) + 9, _H2DMsgType.LOAD_LIBRARY)
self.write(kcode) def load(self, kernel_library):
_, ty = self._read_header() self._write_header(_H2DMsgType.LOAD_LIBRARY)
if ty != _D2HMsgType.LOAD_COMPLETED: self._write_string(kernel_library)
raise IOError("Incorrect reply from device: "+str(ty)) self._write_flush()
self._read_empty(_D2HMsgType.LOAD_COMPLETED)
def run(self): def run(self):
self._write_header(9, _H2DMsgType.RUN_KERNEL) self._write_empty(_H2DMsgType.RUN_KERNEL)
logger.debug("running kernel") logger.debug("running kernel")
def flash_storage_read(self, key): def flash_storage_read(self, key):
self._write_header(9+len(key), _H2DMsgType.FLASH_READ_REQUEST) self._write_header(_H2DMsgType.FLASH_READ_REQUEST)
self.write(key) self._write_string(key)
length, ty = self._read_header() self._write_flush()
if ty != _D2HMsgType.FLASH_READ_REPLY:
raise IOError("Incorrect reply from device: {}".format(ty)) self._read_header()
value = self.read(length - 9) self._read_expect(_D2HMsgType.FLASH_READ_REPLY)
return value return self._read_chunk(self._read_length)
def flash_storage_write(self, key, value): def flash_storage_write(self, key, value):
self._write_header(9+len(key)+1+len(value), self._write_header(_H2DMsgType.FLASH_WRITE_REQUEST)
_H2DMsgType.FLASH_WRITE_REQUEST) self._write_string(key)
self.write(key) self._write_string(b"\x00")
self.write(b"\x00") self._write_string(value)
self.write(value) self._write_flush()
_, ty = self._read_header()
if ty != _D2HMsgType.FLASH_OK_REPLY: self._read_header()
if ty == _D2HMsgType.FLASH_ERROR_REPLY: if self._read_type == _D2HMsgType.FLASH_ERROR_REPLY:
raise IOError("Flash storage is full") raise IOError("Flash storage is full")
else: else:
raise IOError("Incorrect reply from device: {}".format(ty)) self._read_expect(_D2HMsgType.FLASH_OK_REPLY)
def flash_storage_erase(self): def flash_storage_erase(self):
self._write_header(9, _H2DMsgType.FLASH_ERASE_REQUEST) self._write_empty(_H2DMsgType.FLASH_ERASE_REQUEST)
_, ty = self._read_header()
if ty != _D2HMsgType.FLASH_OK_REPLY: self._read_empty(_D2HMsgType.FLASH_OK_REPLY)
raise IOError("Incorrect reply from device: {}".format(ty))
def flash_storage_remove(self, key): def flash_storage_remove(self, key):
self._write_header(9+len(key), _H2DMsgType.FLASH_REMOVE_REQUEST) self._write_header(_H2DMsgType.FLASH_REMOVE_REQUEST)
self.write(key) self._write_string(key)
_, ty = self._read_header() self._write_flush()
if ty != _D2HMsgType.FLASH_OK_REPLY:
raise IOError("Incorrect reply from device: {}".format(ty))
def _receive_rpc_value(self, type_tag): self._read_empty(_D2HMsgType.FLASH_OK_REPLY)
if type_tag == "n":
def _receive_rpc_value(self, tag):
if tag == "n":
return None return None
if type_tag == "b": elif tag == "b":
return bool(struct.unpack("B", self.read(1))[0]) return bool(self._read_int8())
if type_tag == "i": elif tag == "i":
return struct.unpack(">l", self.read(4))[0] return self._read_int32()
if type_tag == "I": elif tag == "I":
return struct.unpack(">q", self.read(8))[0] return self._read_int64()
if type_tag == "f": elif tag == "f":
return struct.unpack(">d", self.read(8))[0] return self._read_float64()
if type_tag == "F": elif tag == "F":
n, d = struct.unpack(">qq", self.read(16)) numerator = self._read_int64()
return Fraction(n, d) denominator = self._read_int64()
return Fraction(numerator, denominator)
elif tag == "l":
elt_tag = chr(self._read_int8())
length = self._read_int32()
return [self._receive_rpc_value(elt_tag) for _ in range(length)]
else:
raise IOError("Unknown RPC value tag: {}", tag)
def _receive_rpc_values(self): def _receive_rpc_values(self):
r = [] result = []
while True: while True:
type_tag = chr(struct.unpack("B", self.read(1))[0]) tag = chr(self._read_int8())
if type_tag == "\x00": if tag == "\x00":
return r return result
elif type_tag == "l":
elt_type_tag = chr(struct.unpack("B", self.read(1))[0])
length = struct.unpack(">l", self.read(4))[0]
r.append([self._receive_rpc_value(elt_type_tag)
for i in range(length)])
else: else:
r.append(self._receive_rpc_value(type_tag)) result.append(self._receive_rpc_value(tag))
def _serve_rpc(self, rpc_map): def _serve_rpc(self, rpc_map):
rpc_num = struct.unpack(">l", self.read(4))[0] service = self._read_int32()
args = self._receive_rpc_values() args = self._receive_rpc_values()
logger.debug("rpc service: %d %r", rpc_num, args) logger.debug("rpc service: %d %r", service, args)
eid, r = rpc_wrapper.run_rpc(rpc_map[rpc_num], args)
self._write_header(9+2*4, _H2DMsgType.RPC_REPLY) eid, result = rpc_wrapper.run_rpc(rpc_map[rpc_num], args)
self.write(struct.pack(">ll", eid, r)) logger.debug("rpc service: %d %r == %r (eid %d)", service, args,
logger.debug("rpc service: %d %r == %r (eid %d)", rpc_num, args, result, eid)
r, eid)
self._write_header(_H2DMsgType.RPC_REPLY)
self._write_int32(eid)
self._write_int32(result)
self._write_flush()
def _serve_exception(self): def _serve_exception(self):
eid, p0, p1, p2 = struct.unpack(">lqqq", self.read(4+3*8)) eid = self._read_int32()
params = [self._read_int64() for _ in range(3)]
rpc_wrapper.filter_rpc_exception(eid) rpc_wrapper.filter_rpc_exception(eid)
raise exception(self.core, p0, p1, p2) raise exception(self.core, *params)
def serve(self, rpc_map): def serve(self, rpc_map):
while True: while True:
_, ty = self._read_header() self._read_header()
if ty == _D2HMsgType.RPC_REQUEST: if self._read_type == _D2HMsgType.RPC_REQUEST:
self._serve_rpc(rpc_map) self._serve_rpc(rpc_map)
elif ty == _D2HMsgType.KERNEL_EXCEPTION: elif self._read_type == _D2HMsgType.KERNEL_EXCEPTION:
self._serve_exception() self._serve_exception()
elif ty == _D2HMsgType.KERNEL_FINISHED:
return
else: else:
raise IOError("Incorrect request from device: "+str(ty)) self._read_expect(_D2HMsgType.KERNEL_FINISHED)
return
def get_log(self): def get_log(self):
self._write_header(9, _H2DMsgType.LOG_REQUEST) self._write_empty(_H2DMsgType.LOG_REQUEST)
length, ty = self._read_header()
if ty != _D2HMsgType.LOG_REPLY: self._read_header()
raise IOError("Incorrect request from device: "+str(ty)) self._read_expect(_D2HMsgType.LOG_REPLY)
r = "" return self._read_chunk(self._read_length).decode('utf-8')
for i in range(length - 9):
c = struct.unpack("B", self.read(1))[0]
if c:
r += chr(c)
return r

View File

@ -10,6 +10,7 @@ logger = logging.getLogger(__name__)
class Comm(CommGeneric): class Comm(CommGeneric):
def __init__(self, dmgr, serial_dev, baud_rate=115200): def __init__(self, dmgr, serial_dev, baud_rate=115200):
super().__init__()
self.serial_dev = serial_dev self.serial_dev = serial_dev
self.baud_rate = baud_rate self.baud_rate = baud_rate

View File

@ -9,6 +9,7 @@ logger = logging.getLogger(__name__)
class Comm(CommGeneric): class Comm(CommGeneric):
def __init__(self, dmgr, host, port=1381): def __init__(self, dmgr, host, port=1381):
super().__init__()
self.host = host self.host = host
self.port = port self.port = port