forked from M-Labs/artiq
coredevice.comm_*: refactor.
This commit is contained in:
parent
acc97a74f0
commit
d6ab567242
|
@ -3,7 +3,7 @@ from operator import itemgetter
|
|||
|
||||
class Comm:
|
||||
def __init__(self, dmgr):
|
||||
pass
|
||||
super().__init__()
|
||||
|
||||
def switch_clock(self, external):
|
||||
pass
|
||||
|
|
|
@ -50,7 +50,11 @@ class UnsupportedDevice(Exception):
|
|||
|
||||
|
||||
class CommGeneric:
|
||||
# methods for derived classes to implement
|
||||
def __init__(self):
|
||||
self._read_type = self._write_type = None
|
||||
self._read_length = 0
|
||||
self._write_buffer = []
|
||||
|
||||
def open(self):
|
||||
"""Opens the communication channel.
|
||||
Must do nothing if already opened."""
|
||||
|
@ -70,167 +74,251 @@ class CommGeneric:
|
|||
"""Writes exactly length bytes to the communication channel.
|
||||
The channel is assumed to be opened."""
|
||||
raise NotImplementedError
|
||||
|
||||
#
|
||||
# Reader interface
|
||||
#
|
||||
|
||||
def _read_header(self):
|
||||
self.open()
|
||||
|
||||
if self._read_length > 0:
|
||||
raise IOError("Read underrun ({} bytes remaining)".
|
||||
format(self._read_length))
|
||||
|
||||
# Wait for a synchronization sequence, 5a 5a 5a 5a.
|
||||
sync_count = 0
|
||||
while sync_count < 4:
|
||||
(c, ) = struct.unpack("B", self.read(1))
|
||||
if c == 0x5a:
|
||||
(sync_byte, ) = struct.unpack("B", self.read(1))
|
||||
if sync_byte == 0x5a:
|
||||
sync_count += 1
|
||||
else:
|
||||
sync_count = 0
|
||||
length = struct.unpack(">l", self.read(4))[0]
|
||||
if not length: # inband connection close
|
||||
raise OSError("Connection closed")
|
||||
tyv = struct.unpack("B", self.read(1))[0]
|
||||
ty = _D2HMsgType(tyv)
|
||||
logger.debug("receiving message: type=%r length=%d", ty, length)
|
||||
return length, ty
|
||||
|
||||
def _write_header(self, length, ty):
|
||||
# Read message header.
|
||||
(self._read_length, ) = struct.unpack(">l", self.read(4))
|
||||
if not self._read_length: # inband connection close
|
||||
raise OSError("Connection closed")
|
||||
|
||||
(raw_type, ) = struct.unpack("B", self.read(1))
|
||||
self._read_type = _D2HMsgType(raw_type)
|
||||
|
||||
if self._read_length < 9:
|
||||
raise IOError("Read overrun in message header ({} remaining)".
|
||||
format(self._read_length))
|
||||
self._read_length -= 9
|
||||
|
||||
logger.debug("receiving message: type=%r length=%d",
|
||||
self._read_type, self._read_length)
|
||||
|
||||
def _read_expect(self, ty):
|
||||
if self._read_type != ty:
|
||||
raise IOError("Incorrect reply from device: {} (expected {})".
|
||||
format(self._read_type, ty))
|
||||
|
||||
def _read_empty(self, ty):
|
||||
self._read_header()
|
||||
self._read_expect(ty)
|
||||
|
||||
def _read_chunk(self, length):
|
||||
if self._read_length < length:
|
||||
raise IOError("Read overrun while trying to read {} bytes ({} remaining)".
|
||||
format(length, self._read_length))
|
||||
|
||||
self._read_length -= length
|
||||
return self.read(length)
|
||||
|
||||
def _read_int8(self):
|
||||
(value, ) = struct.unpack("B", self._read_chunk(1))
|
||||
return value
|
||||
|
||||
def _read_int32(self):
|
||||
(value, ) = struct.unpack(">l", self._read_chunk(4))
|
||||
return value
|
||||
|
||||
def _read_int64(self):
|
||||
(value, ) = struct.unpack(">q", self._read_chunk(8))
|
||||
return value
|
||||
|
||||
def _read_float64(self):
|
||||
(value, ) = struct.unpack(">d", self._read_chunk(8))
|
||||
return value
|
||||
|
||||
#
|
||||
# Writer interface
|
||||
#
|
||||
|
||||
def _write_header(self, ty):
|
||||
self.open()
|
||||
logger.debug("sending message: type=%r length=%d", ty, length)
|
||||
self.write(struct.pack(">ll", 0x5a5a5a5a, length))
|
||||
if ty is not None:
|
||||
self.write(struct.pack("B", ty.value))
|
||||
|
||||
logger.debug("preparing to send message: type=%r", ty)
|
||||
self._write_type = ty
|
||||
self._write_buffer = []
|
||||
|
||||
def _write_flush(self):
|
||||
# Calculate message size.
|
||||
length = sum([len(chunk) for chunk in self._write_buffer])
|
||||
logger.debug("sending message: type=%r length=%d", self._write_type, length)
|
||||
|
||||
# Write synchronization sequence, header and body.
|
||||
self.write(struct.pack(">llB", 0x5a5a5a5a,
|
||||
9 + length, self._write_type.value))
|
||||
for chunk in self._write_buffer:
|
||||
self.write(chunk)
|
||||
|
||||
def _write_empty(self, ty):
|
||||
self._write_header(ty)
|
||||
self._write_flush()
|
||||
|
||||
def _write_int8(self, value):
|
||||
self._write_buffer.append(struct.pack("B", value))
|
||||
|
||||
def _write_int32(self, value):
|
||||
self._write_buffer.append(struct.pack(">l", value))
|
||||
|
||||
def _write_int64(self, value):
|
||||
self._write_buffer.append(struct.pack(">q", value))
|
||||
|
||||
def _write_float64(self, value):
|
||||
self._write_buffer.append(struct.pack(">d", value))
|
||||
|
||||
def _write_string(self, value):
|
||||
self._write_buffer.append(value)
|
||||
|
||||
#
|
||||
# Exported APIs
|
||||
#
|
||||
|
||||
def reset_session(self):
|
||||
self._write_header(0, None)
|
||||
self.write(struct.pack(">ll", 0x5a5a5a5a, 0))
|
||||
|
||||
def check_ident(self):
|
||||
self._write_header(9, _H2DMsgType.IDENT_REQUEST)
|
||||
_, ty = self._read_header()
|
||||
if ty != _D2HMsgType.IDENT_REPLY:
|
||||
raise IOError("Incorrect reply from device: {}".format(ty))
|
||||
(reply, ) = struct.unpack("B", self.read(1))
|
||||
runtime_id = chr(reply)
|
||||
for i in range(3):
|
||||
(reply, ) = struct.unpack("B", self.read(1))
|
||||
runtime_id += chr(reply)
|
||||
if runtime_id != "AROR":
|
||||
self._write_empty(_H2DMsgType.IDENT_REQUEST)
|
||||
|
||||
self._read_header()
|
||||
self._read_expect(_D2HMsgType.IDENT_REPLY)
|
||||
runtime_id = self._read_chunk(4)
|
||||
if runtime_id != b"AROR":
|
||||
raise UnsupportedDevice("Unsupported runtime ID: {}"
|
||||
.format(runtime_id))
|
||||
|
||||
def switch_clock(self, external):
|
||||
self._write_header(10, _H2DMsgType.SWITCH_CLOCK)
|
||||
self.write(struct.pack("B", int(external)))
|
||||
_, ty = self._read_header()
|
||||
if ty != _D2HMsgType.CLOCK_SWITCH_COMPLETED:
|
||||
raise IOError("Incorrect reply from device: {}".format(ty))
|
||||
self._write_header(_H2DMsgType.SWITCH_CLOCK)
|
||||
self._write_int8(external)
|
||||
self._write_flush()
|
||||
|
||||
def load(self, kcode):
|
||||
self._write_header(len(kcode) + 9, _H2DMsgType.LOAD_LIBRARY)
|
||||
self.write(kcode)
|
||||
_, ty = self._read_header()
|
||||
if ty != _D2HMsgType.LOAD_COMPLETED:
|
||||
raise IOError("Incorrect reply from device: "+str(ty))
|
||||
self._read_empty(_D2HMsgType.CLOCK_SWITCH_COMPLETED)
|
||||
|
||||
def load(self, kernel_library):
|
||||
self._write_header(_H2DMsgType.LOAD_LIBRARY)
|
||||
self._write_string(kernel_library)
|
||||
self._write_flush()
|
||||
|
||||
self._read_empty(_D2HMsgType.LOAD_COMPLETED)
|
||||
|
||||
def run(self):
|
||||
self._write_header(9, _H2DMsgType.RUN_KERNEL)
|
||||
self._write_empty(_H2DMsgType.RUN_KERNEL)
|
||||
logger.debug("running kernel")
|
||||
|
||||
def flash_storage_read(self, key):
|
||||
self._write_header(9+len(key), _H2DMsgType.FLASH_READ_REQUEST)
|
||||
self.write(key)
|
||||
length, ty = self._read_header()
|
||||
if ty != _D2HMsgType.FLASH_READ_REPLY:
|
||||
raise IOError("Incorrect reply from device: {}".format(ty))
|
||||
value = self.read(length - 9)
|
||||
return value
|
||||
self._write_header(_H2DMsgType.FLASH_READ_REQUEST)
|
||||
self._write_string(key)
|
||||
self._write_flush()
|
||||
|
||||
self._read_header()
|
||||
self._read_expect(_D2HMsgType.FLASH_READ_REPLY)
|
||||
return self._read_chunk(self._read_length)
|
||||
|
||||
def flash_storage_write(self, key, value):
|
||||
self._write_header(9+len(key)+1+len(value),
|
||||
_H2DMsgType.FLASH_WRITE_REQUEST)
|
||||
self.write(key)
|
||||
self.write(b"\x00")
|
||||
self.write(value)
|
||||
_, ty = self._read_header()
|
||||
if ty != _D2HMsgType.FLASH_OK_REPLY:
|
||||
if ty == _D2HMsgType.FLASH_ERROR_REPLY:
|
||||
self._write_header(_H2DMsgType.FLASH_WRITE_REQUEST)
|
||||
self._write_string(key)
|
||||
self._write_string(b"\x00")
|
||||
self._write_string(value)
|
||||
self._write_flush()
|
||||
|
||||
self._read_header()
|
||||
if self._read_type == _D2HMsgType.FLASH_ERROR_REPLY:
|
||||
raise IOError("Flash storage is full")
|
||||
else:
|
||||
raise IOError("Incorrect reply from device: {}".format(ty))
|
||||
self._read_expect(_D2HMsgType.FLASH_OK_REPLY)
|
||||
|
||||
def flash_storage_erase(self):
|
||||
self._write_header(9, _H2DMsgType.FLASH_ERASE_REQUEST)
|
||||
_, ty = self._read_header()
|
||||
if ty != _D2HMsgType.FLASH_OK_REPLY:
|
||||
raise IOError("Incorrect reply from device: {}".format(ty))
|
||||
self._write_empty(_H2DMsgType.FLASH_ERASE_REQUEST)
|
||||
|
||||
self._read_empty(_D2HMsgType.FLASH_OK_REPLY)
|
||||
|
||||
def flash_storage_remove(self, key):
|
||||
self._write_header(9+len(key), _H2DMsgType.FLASH_REMOVE_REQUEST)
|
||||
self.write(key)
|
||||
_, ty = self._read_header()
|
||||
if ty != _D2HMsgType.FLASH_OK_REPLY:
|
||||
raise IOError("Incorrect reply from device: {}".format(ty))
|
||||
self._write_header(_H2DMsgType.FLASH_REMOVE_REQUEST)
|
||||
self._write_string(key)
|
||||
self._write_flush()
|
||||
|
||||
def _receive_rpc_value(self, type_tag):
|
||||
if type_tag == "n":
|
||||
self._read_empty(_D2HMsgType.FLASH_OK_REPLY)
|
||||
|
||||
def _receive_rpc_value(self, tag):
|
||||
if tag == "n":
|
||||
return None
|
||||
if type_tag == "b":
|
||||
return bool(struct.unpack("B", self.read(1))[0])
|
||||
if type_tag == "i":
|
||||
return struct.unpack(">l", self.read(4))[0]
|
||||
if type_tag == "I":
|
||||
return struct.unpack(">q", self.read(8))[0]
|
||||
if type_tag == "f":
|
||||
return struct.unpack(">d", self.read(8))[0]
|
||||
if type_tag == "F":
|
||||
n, d = struct.unpack(">qq", self.read(16))
|
||||
return Fraction(n, d)
|
||||
elif tag == "b":
|
||||
return bool(self._read_int8())
|
||||
elif tag == "i":
|
||||
return self._read_int32()
|
||||
elif tag == "I":
|
||||
return self._read_int64()
|
||||
elif tag == "f":
|
||||
return self._read_float64()
|
||||
elif tag == "F":
|
||||
numerator = self._read_int64()
|
||||
denominator = self._read_int64()
|
||||
return Fraction(numerator, denominator)
|
||||
elif tag == "l":
|
||||
elt_tag = chr(self._read_int8())
|
||||
length = self._read_int32()
|
||||
return [self._receive_rpc_value(elt_tag) for _ in range(length)]
|
||||
else:
|
||||
raise IOError("Unknown RPC value tag: {}", tag)
|
||||
|
||||
def _receive_rpc_values(self):
|
||||
r = []
|
||||
result = []
|
||||
while True:
|
||||
type_tag = chr(struct.unpack("B", self.read(1))[0])
|
||||
if type_tag == "\x00":
|
||||
return r
|
||||
elif type_tag == "l":
|
||||
elt_type_tag = chr(struct.unpack("B", self.read(1))[0])
|
||||
length = struct.unpack(">l", self.read(4))[0]
|
||||
r.append([self._receive_rpc_value(elt_type_tag)
|
||||
for i in range(length)])
|
||||
tag = chr(self._read_int8())
|
||||
if tag == "\x00":
|
||||
return result
|
||||
else:
|
||||
r.append(self._receive_rpc_value(type_tag))
|
||||
result.append(self._receive_rpc_value(tag))
|
||||
|
||||
def _serve_rpc(self, rpc_map):
|
||||
rpc_num = struct.unpack(">l", self.read(4))[0]
|
||||
service = self._read_int32()
|
||||
args = self._receive_rpc_values()
|
||||
logger.debug("rpc service: %d %r", rpc_num, args)
|
||||
eid, r = rpc_wrapper.run_rpc(rpc_map[rpc_num], args)
|
||||
self._write_header(9+2*4, _H2DMsgType.RPC_REPLY)
|
||||
self.write(struct.pack(">ll", eid, r))
|
||||
logger.debug("rpc service: %d %r == %r (eid %d)", rpc_num, args,
|
||||
r, eid)
|
||||
logger.debug("rpc service: %d %r", service, args)
|
||||
|
||||
eid, result = rpc_wrapper.run_rpc(rpc_map[rpc_num], args)
|
||||
logger.debug("rpc service: %d %r == %r (eid %d)", service, args,
|
||||
result, eid)
|
||||
|
||||
self._write_header(_H2DMsgType.RPC_REPLY)
|
||||
self._write_int32(eid)
|
||||
self._write_int32(result)
|
||||
self._write_flush()
|
||||
|
||||
def _serve_exception(self):
|
||||
eid, p0, p1, p2 = struct.unpack(">lqqq", self.read(4+3*8))
|
||||
eid = self._read_int32()
|
||||
params = [self._read_int64() for _ in range(3)]
|
||||
rpc_wrapper.filter_rpc_exception(eid)
|
||||
raise exception(self.core, p0, p1, p2)
|
||||
raise exception(self.core, *params)
|
||||
|
||||
def serve(self, rpc_map):
|
||||
while True:
|
||||
_, ty = self._read_header()
|
||||
if ty == _D2HMsgType.RPC_REQUEST:
|
||||
self._read_header()
|
||||
if self._read_type == _D2HMsgType.RPC_REQUEST:
|
||||
self._serve_rpc(rpc_map)
|
||||
elif ty == _D2HMsgType.KERNEL_EXCEPTION:
|
||||
elif self._read_type == _D2HMsgType.KERNEL_EXCEPTION:
|
||||
self._serve_exception()
|
||||
elif ty == _D2HMsgType.KERNEL_FINISHED:
|
||||
return
|
||||
else:
|
||||
raise IOError("Incorrect request from device: "+str(ty))
|
||||
self._read_expect(_D2HMsgType.KERNEL_FINISHED)
|
||||
return
|
||||
|
||||
def get_log(self):
|
||||
self._write_header(9, _H2DMsgType.LOG_REQUEST)
|
||||
length, ty = self._read_header()
|
||||
if ty != _D2HMsgType.LOG_REPLY:
|
||||
raise IOError("Incorrect request from device: "+str(ty))
|
||||
r = ""
|
||||
for i in range(length - 9):
|
||||
c = struct.unpack("B", self.read(1))[0]
|
||||
if c:
|
||||
r += chr(c)
|
||||
return r
|
||||
self._write_empty(_H2DMsgType.LOG_REQUEST)
|
||||
|
||||
self._read_header()
|
||||
self._read_expect(_D2HMsgType.LOG_REPLY)
|
||||
return self._read_chunk(self._read_length).decode('utf-8')
|
||||
|
|
|
@ -10,6 +10,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class Comm(CommGeneric):
|
||||
def __init__(self, dmgr, serial_dev, baud_rate=115200):
|
||||
super().__init__()
|
||||
self.serial_dev = serial_dev
|
||||
self.baud_rate = baud_rate
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class Comm(CommGeneric):
|
||||
def __init__(self, dmgr, host, port=1381):
|
||||
super().__init__()
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
|
|
Loading…
Reference in New Issue