artiq/artiq/coredevice/comm_generic.py

500 lines
16 KiB
Python

import struct
import logging
import traceback
from enum import Enum
from fractions import Fraction
from artiq.language import core as core_language
from artiq import __version__ as software_version
logger = logging.getLogger(__name__)
class _H2DMsgType(Enum):
    """Type codes of the messages CommGeneric sends.

    The member's ``value`` is written as the single type byte of the
    outgoing message header (see ``CommGeneric._write_flush``).
    """
    # Device log.
    LOG_REQUEST = 1
    LOG_CLEAR = 2

    # Identification and clocking.
    IDENT_REQUEST = 3
    SWITCH_CLOCK = 4

    # Kernel loading and execution.
    LOAD_LIBRARY = 5
    RUN_KERNEL = 6

    # Replies to an RPC_REQUEST received from the device.
    RPC_REPLY = 7
    RPC_EXCEPTION = 8

    # Flash storage.
    FLASH_READ_REQUEST = 9
    FLASH_WRITE_REQUEST = 10
    FLASH_ERASE_REQUEST = 11
    FLASH_REMOVE_REQUEST = 12
class _D2HMsgType(Enum):
    """Type codes of the messages CommGeneric receives.

    ``CommGeneric._read_header`` converts the type byte of every incoming
    message header into a member of this enumeration.
    """
    # Device log and identification.
    LOG_REPLY = 1
    IDENT_REPLY = 2

    # Outcome of a SWITCH_CLOCK request.
    CLOCK_SWITCH_COMPLETED = 3
    CLOCK_SWITCH_FAILED = 4

    # Outcome of a LOAD_LIBRARY request.
    LOAD_COMPLETED = 5
    LOAD_FAILED = 6

    # Kernel execution.
    KERNEL_FINISHED = 7
    KERNEL_STARTUP_FAILED = 8
    KERNEL_EXCEPTION = 9
    RPC_REQUEST = 10

    # Flash storage.
    FLASH_READ_REPLY = 11
    FLASH_OK_REPLY = 12
    FLASH_ERROR_REPLY = 13
class UnsupportedDevice(Exception):
    """Raised when the identification reply carries an unexpected runtime ID."""
class RPCReturnValueError(ValueError):
    """Raised when an RPC return value does not match the expected type tag."""
class CommGeneric:
    """Transport-independent implementation of the core device protocol.

    Subclasses provide the byte transport by overriding :meth:`open`,
    :meth:`close`, :meth:`read` and :meth:`write`.  This class implements
    message framing, (de)serialization of primitive values and the
    request/reply exchanges built on top of them.

    Every message, in either direction, is framed as::

        5a 5a 5a 5a | length (int32, big endian) | type (uint8) | payload

    where ``length`` counts the whole frame: the 9 header bytes plus the
    payload.
    """

    def __init__(self):
        self._read_type = self._write_type = None
        # Payload bytes of the current incoming message not yet consumed.
        self._read_length = 0
        # Payload chunks (bytes) of the outgoing message being assembled.
        self._write_buffer = []

    def open(self):
        """Opens the communication channel.
        Must do nothing if already opened."""
        raise NotImplementedError

    def close(self):
        """Closes the communication channel.
        Must do nothing if already closed."""
        raise NotImplementedError

    def read(self, length):
        """Reads exactly length bytes from the communication channel.
        The channel is assumed to be opened."""
        raise NotImplementedError

    def write(self, data):
        """Writes all of data to the communication channel.
        The channel is assumed to be opened."""
        raise NotImplementedError

    #
    # Reader interface
    #

    def _read_header(self):
        """Synchronize to the next incoming message and parse its header.

        On return ``self._read_type`` holds the message type and
        ``self._read_length`` the number of payload bytes still to be read.

        Raises IOError if the previous message was not fully consumed or if
        the announced length is shorter than a header, OSError if the peer
        signalled an in-band close (zero length), and ValueError (from the
        enum lookup) if the type byte is unknown.
        """
        self.open()

        if self._read_length > 0:
            raise IOError("Read underrun ({} bytes remaining)".
                          format(self._read_length))

        # Wait for a synchronization sequence, 5a 5a 5a 5a.
        sync_count = 0
        while sync_count < 4:
            (sync_byte, ) = struct.unpack("B", self.read(1))
            if sync_byte == 0x5a:
                sync_count += 1
            else:
                sync_count = 0

        # Read message header.
        (self._read_length, ) = struct.unpack(">l", self.read(4))
        if not self._read_length:  # inband connection close
            raise OSError("Connection closed")

        (raw_type, ) = struct.unpack("B", self.read(1))
        self._read_type = _D2HMsgType(raw_type)

        # The length field includes the 9 header bytes just consumed.
        if self._read_length < 9:
            raise IOError("Read overrun in message header ({} remaining)".
                          format(self._read_length))
        self._read_length -= 9

        logger.debug("receiving message: type=%r length=%d",
                     self._read_type, self._read_length)

    def _read_expect(self, ty):
        """Raise IOError unless the current incoming message has type ty."""
        if self._read_type != ty:
            raise IOError("Incorrect reply from device: {} (expected {})".
                          format(self._read_type, ty))

    def _read_empty(self, ty):
        """Read the next message header and require it to have type ty."""
        self._read_header()
        self._read_expect(ty)

    def _read_chunk(self, length):
        """Read length payload bytes of the current incoming message.

        Raises IOError if fewer than length payload bytes remain.
        """
        if self._read_length < length:
            raise IOError("Read overrun while trying to read {} bytes ({} remaining)"
                          " in packet {}".
                          format(length, self._read_length, self._read_type))

        self._read_length -= length
        return self.read(length)

    def _read_int8(self):
        """Read one unsigned byte from the payload."""
        (value, ) = struct.unpack("B", self._read_chunk(1))
        return value

    def _read_int32(self):
        """Read a big-endian signed 32-bit integer from the payload."""
        (value, ) = struct.unpack(">l", self._read_chunk(4))
        return value

    def _read_int64(self):
        """Read a big-endian signed 64-bit integer from the payload."""
        (value, ) = struct.unpack(">q", self._read_chunk(8))
        return value

    def _read_float64(self):
        """Read a big-endian IEEE 754 double from the payload."""
        (value, ) = struct.unpack(">d", self._read_chunk(8))
        return value

    def _read_bytes(self):
        """Read a byte string prefixed by its int32 length."""
        return self._read_chunk(self._read_int32())

    def _read_string(self):
        """Read a length-prefixed UTF-8 string, dropping its trailing byte.

        The trailing byte is the NUL terminator that _write_string appends.
        """
        return self._read_bytes()[:-1].decode('utf-8')

    #
    # Writer interface
    #

    def _write_header(self, ty):
        """Start assembling an outgoing message of type ty.

        Any payload buffered for a previous, unflushed message is discarded.
        """
        self.open()

        logger.debug("preparing to send message: type=%r", ty)
        self._write_type = ty
        self._write_buffer = []

    def _write_flush(self):
        """Send the header followed by every buffered payload chunk."""
        # Calculate message size.
        length = sum(len(chunk) for chunk in self._write_buffer)
        logger.debug("sending message: type=%r length=%d", self._write_type, length)

        # Write synchronization sequence, header and body.
        self.write(struct.pack(">llB", 0x5a5a5a5a,
                               9 + length, self._write_type.value))
        for chunk in self._write_buffer:
            self.write(chunk)

    def _write_empty(self, ty):
        """Send a message of type ty with no payload."""
        self._write_header(ty)
        self._write_flush()

    def _write_chunk(self, chunk):
        """Append raw bytes to the outgoing payload."""
        self._write_buffer.append(chunk)

    def _write_int8(self, value):
        """Append one unsigned byte to the outgoing payload."""
        self._write_buffer.append(struct.pack("B", value))

    def _write_int32(self, value):
        """Append a big-endian signed 32-bit integer."""
        self._write_buffer.append(struct.pack(">l", value))

    def _write_int64(self, value):
        """Append a big-endian signed 64-bit integer."""
        self._write_buffer.append(struct.pack(">q", value))

    def _write_float64(self, value):
        """Append a big-endian IEEE 754 double."""
        self._write_buffer.append(struct.pack(">d", value))

    def _write_bytes(self, value):
        """Append a byte string prefixed by its int32 length."""
        self._write_int32(len(value))
        self._write_buffer.append(value)

    def _write_string(self, value):
        """Append a length-prefixed, NUL-terminated UTF-8 string."""
        self._write_bytes(value.encode("utf-8") + b"\0")

    #
    # Exported APIs
    #

    def reset_session(self):
        """Write a sync sequence followed by a zero length word.

        This is the same zero-length frame that _read_header treats as an
        in-band connection close.
        NOTE(review): unlike the other requests this does not call open();
        presumably the channel must already be open -- confirm with callers.
        """
        self.write(struct.pack(">ll", 0x5a5a5a5a, 0))

    def check_ident(self):
        """Request the device identification and validate it.

        Raises UnsupportedDevice if the runtime ID is not ``AROR``.  A
        version string differing from the software version only produces a
        warning.
        """
        self._write_empty(_H2DMsgType.IDENT_REQUEST)

        self._read_header()
        self._read_expect(_D2HMsgType.IDENT_REPLY)
        runtime_id = self._read_chunk(4)
        if runtime_id != b"AROR":
            raise UnsupportedDevice("Unsupported runtime ID: {}"
                                    .format(runtime_id))
        # The rest of the payload is the version string.
        gateware_version = self._read_chunk(self._read_length).decode("utf-8")
        if gateware_version != software_version:
            logger.warning("Mismatch between gateware (%s) "
                           "and software (%s) versions",
                           gateware_version, software_version)

    def switch_clock(self, external):
        """Send a SWITCH_CLOCK request and wait for its completion.

        external is sent as a single byte.  Raises IOError if the device
        answers with anything but CLOCK_SWITCH_COMPLETED.
        """
        self._write_header(_H2DMsgType.SWITCH_CLOCK)
        self._write_int8(external)
        self._write_flush()

        self._read_empty(_D2HMsgType.CLOCK_SWITCH_COMPLETED)

    def get_log(self):
        """Return the device log decoded as UTF-8."""
        self._write_empty(_H2DMsgType.LOG_REQUEST)

        self._read_header()
        self._read_expect(_D2HMsgType.LOG_REPLY)
        return self._read_chunk(self._read_length).decode("utf-8")

    def clear_log(self):
        """Send a LOG_CLEAR request and wait for the LOG_REPLY."""
        self._write_empty(_H2DMsgType.LOG_CLEAR)

        self._read_empty(_D2HMsgType.LOG_REPLY)

    def flash_storage_read(self, key):
        """Return the raw bytes stored in flash under key."""
        self._write_header(_H2DMsgType.FLASH_READ_REQUEST)
        self._write_string(key)
        self._write_flush()

        self._read_header()
        self._read_expect(_D2HMsgType.FLASH_READ_REPLY)
        return self._read_chunk(self._read_length)

    def flash_storage_write(self, key, value):
        """Store value (bytes) in flash under key.

        Raises IOError("Flash storage is full") if the device replies with
        FLASH_ERROR_REPLY.
        """
        self._write_header(_H2DMsgType.FLASH_WRITE_REQUEST)
        self._write_string(key)
        self._write_bytes(value)
        self._write_flush()

        self._read_header()
        if self._read_type == _D2HMsgType.FLASH_ERROR_REPLY:
            raise IOError("Flash storage is full")
        else:
            self._read_expect(_D2HMsgType.FLASH_OK_REPLY)

    def flash_storage_erase(self):
        """Send a FLASH_ERASE_REQUEST and wait for FLASH_OK_REPLY."""
        self._write_empty(_H2DMsgType.FLASH_ERASE_REQUEST)

        self._read_empty(_D2HMsgType.FLASH_OK_REPLY)

    def flash_storage_remove(self, key):
        """Send a FLASH_REMOVE_REQUEST for key and wait for FLASH_OK_REPLY."""
        self._write_header(_H2DMsgType.FLASH_REMOVE_REQUEST)
        self._write_string(key)
        self._write_flush()

        self._read_empty(_D2HMsgType.FLASH_OK_REPLY)

    def load(self, kernel_library):
        """Send kernel_library (bytes) and wait for LOAD_COMPLETED."""
        self._write_header(_H2DMsgType.LOAD_LIBRARY)
        self._write_chunk(kernel_library)
        self._write_flush()

        self._read_empty(_D2HMsgType.LOAD_COMPLETED)

    def run(self):
        """Send RUN_KERNEL.  Does not wait for a reply; see serve()."""
        self._write_empty(_H2DMsgType.RUN_KERNEL)
        logger.debug("running kernel")

    # Unique marker returned by _receive_rpc_value for the NUL tag that
    # terminates an argument list.
    _rpc_sentinel = object()

    # See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
    def _receive_rpc_value(self, object_map):
        """Decode one tagged RPC value from the current incoming message.

        Returns ``self._rpc_sentinel`` for the NUL tag.  Raises IOError on
        an unknown tag.
        """
        tag = chr(self._read_int8())
        if tag == "\x00":
            return self._rpc_sentinel
        elif tag == "t":
            length = self._read_int8()
            return tuple(self._receive_rpc_value(object_map) for _ in range(length))
        elif tag == "n":
            return None
        elif tag == "b":
            return bool(self._read_int8())
        elif tag == "i":
            return self._read_int32()
        elif tag == "I":
            return self._read_int64()
        elif tag == "f":
            return self._read_float64()
        elif tag == "F":
            numerator = self._read_int64()
            denominator = self._read_int64()
            return Fraction(numerator, denominator)
        elif tag == "s":
            return self._read_string()
        elif tag == "l":
            length = self._read_int32()
            return [self._receive_rpc_value(object_map) for _ in range(length)]
        elif tag == "r":
            start = self._receive_rpc_value(object_map)
            stop = self._receive_rpc_value(object_map)
            step = self._receive_rpc_value(object_map)
            return range(start, stop, step)
        elif tag == "o":
            # Optional value: a presence byte, then the value if present.
            present = self._read_int8()
            if present:
                return self._receive_rpc_value(object_map)
            return None
        elif tag == "O":
            return object_map.retrieve(self._read_int32())
        else:
            raise IOError("Unknown RPC value tag: {}".format(repr(tag)))

    def _receive_rpc_args(self, object_map):
        """Decode RPC values until the sentinel; return them as a list."""
        args = []
        while True:
            value = self._receive_rpc_value(object_map)
            if value is self._rpc_sentinel:
                return args
            args.append(value)

    def _skip_rpc_value(self, tags):
        """Consume, in place, the tags describing one value.

        tags is a bytearray, so popping yields integers; they are decoded
        with chr() before being compared with the tag characters.
        """
        tag = chr(tags.pop(0))
        if tag == "t":
            # A tuple tag is followed by a raw element count, then one tag
            # sequence per element.
            length = tags.pop(0)
            for _ in range(length):
                self._skip_rpc_value(tags)
        elif tag == "l" or tag == "r":
            # Lists and ranges are followed by a single element type.
            self._skip_rpc_value(tags)
        # Every other tag is self-contained.

    def _send_rpc_value(self, tags, value, root, function):
        """Serialize value according to tags into the outgoing payload.

        tags is a bytearray of type tags; the tags describing value are
        consumed from its front.  root and function are used only in error
        messages: the complete return value and the function that
        produced it.

        Raises RPCReturnValueError if value does not match the expected
        type, and IOError on an unknown tag.
        """
        def check(cond, expected):
            if not cond:
                raise RPCReturnValueError(
                    "type mismatch: cannot serialize {value} as {type}"
                    " ({function} has returned {root})".format(
                        value=repr(value), type=expected(),
                        function=function, root=root))

        tag = chr(tags.pop(0))
        if tag == "t":
            length = tags.pop(0)
            check(isinstance(value, tuple) and length == len(value),
                  lambda: "tuple of {}".format(length))
            for elt in value:
                self._send_rpc_value(tags, elt, root, function)
        elif tag == "n":
            check(value is None,
                  lambda: "None")
        elif tag == "b":
            check(isinstance(value, bool),
                  lambda: "bool")
            self._write_int8(value)
        elif tag == "i":
            # Inclusive bounds: both extremes are representable.
            check(isinstance(value, int) and (-2**31 <= value <= 2**31-1),
                  lambda: "32-bit int")
            self._write_int32(value)
        elif tag == "I":
            check(isinstance(value, int) and (-2**63 <= value <= 2**63-1),
                  lambda: "64-bit int")
            self._write_int64(value)
        elif tag == "f":
            check(isinstance(value, float),
                  lambda: "float")
            self._write_float64(value)
        elif tag == "F":
            check(isinstance(value, Fraction) and
                  (-2**63 <= value.numerator <= 2**63-1) and
                  (-2**63 <= value.denominator <= 2**63-1),
                  lambda: "64-bit Fraction")
            self._write_int64(value.numerator)
            self._write_int64(value.denominator)
        elif tag == "s":
            check(isinstance(value, str) and "\x00" not in value,
                  lambda: "str")
            self._write_string(value)
        elif tag == "l":
            check(isinstance(value, list),
                  lambda: "list")
            self._write_int32(len(value))
            # Every element is described by the same tags, so each one is
            # serialized against a fresh copy.
            for elt in value:
                tags_copy = bytearray(tags)
                self._send_rpc_value(tags_copy, elt, root, function)
            # Advance the caller's tags past the element type exactly once
            # (also covers the empty list).
            self._skip_rpc_value(tags)
        elif tag == "r":
            check(isinstance(value, range),
                  lambda: "range")
            # start, stop and step share one element type.
            for bound in (value.start, value.stop, value.step):
                tags_copy = bytearray(tags)
                self._send_rpc_value(tags_copy, bound, root, function)
            # Rebinding the local name would leave the caller's bytearray
            # untouched; consume the element type in place instead.
            self._skip_rpc_value(tags)
        else:
            raise IOError("Unknown RPC value tag: {}".format(repr(tag)))

    def _serve_rpc(self, object_map):
        """Handle one RPC_REQUEST and send RPC_REPLY or RPC_EXCEPTION.

        The payload is: service number, tagged arguments terminated by the
        sentinel, and the tags of the expected return value.
        """
        service = self._read_int32()
        args = self._receive_rpc_args(object_map)
        return_tags = self._read_bytes()
        logger.debug("rpc service: %d %r -> %s", service, args, return_tags)

        try:
            target = object_map.retrieve(service)
            result = target(*args)
            logger.debug("rpc service: %d %r == %r", service, args, result)

            self._write_header(_H2DMsgType.RPC_REPLY)
            self._write_bytes(return_tags)
            self._send_rpc_value(bytearray(return_tags), result, result,
                                 target)
            self._write_flush()
        except core_language.ARTIQException as exn:
            logger.debug("rpc service: %d %r ! %r", service, args, exn)

            # _write_header discards any partially built RPC_REPLY payload.
            self._write_header(_H2DMsgType.RPC_EXCEPTION)
            self._write_string(exn.name)
            self._write_string(exn.message)
            for index in range(3):
                self._write_int64(exn.param[index])

            self._write_string(exn.filename)
            self._write_int32(exn.line)
            self._write_int32(exn.column)
            self._write_string(exn.function)

            self._write_flush()
        except Exception as exn:
            logger.debug("rpc service: %d %r ! %r", service, args, exn)

            self._write_header(_H2DMsgType.RPC_EXCEPTION)
            self._write_string(type(exn).__name__)
            self._write_string(str(exn))
            for index in range(3):
                self._write_int64(0)

            # The first traceback entry is this method; the second, when
            # present, is the first frame below it.  An exception raised
            # directly by the call (e.g. a wrong argument count) has only
            # one entry, so use the innermost available.
            frames = traceback.extract_tb(exn.__traceback__, 2)
            filename, line, function, _ = frames[-1]
            self._write_string(filename)
            self._write_int32(line)
            self._write_int32(-1)  # column not known
            self._write_string(function)

            self._write_flush()

    def _serve_exception(self, symbolizer):
        """Decode a KERNEL_EXCEPTION payload and raise it as ARTIQException.

        symbolizer is called with the list of backtrace addresses; its
        result is reversed and followed by the raise location reported by
        the device.
        """
        name = self._read_string()
        message = self._read_string()
        params = [self._read_int64() for _ in range(3)]

        filename = self._read_string()
        line = self._read_int32()
        column = self._read_int32()
        function = self._read_string()

        # The range() argument is evaluated first, so the count is read
        # before the addresses.
        backtrace = [self._read_int32() for _ in range(self._read_int32())]

        # Named "tb" so the traceback module is not shadowed.
        tb = list(reversed(symbolizer(backtrace))) + \
            [(filename, line, column, function, None)]
        raise core_language.ARTIQException(name, message, params, tb)

    def serve(self, object_map, symbolizer):
        """Process device messages until the kernel finishes.

        RPC requests are answered through object_map.  A kernel exception
        is re-raised as ARTIQException.  Any other message must be
        KERNEL_FINISHED, otherwise IOError is raised.
        """
        while True:
            self._read_header()
            if self._read_type == _D2HMsgType.RPC_REQUEST:
                self._serve_rpc(object_map)
            elif self._read_type == _D2HMsgType.KERNEL_EXCEPTION:
                self._serve_exception(symbolizer)
            else:
                self._read_expect(_D2HMsgType.KERNEL_FINISHED)
                return