2
0
mirror of https://github.com/m-labs/artiq.git synced 2025-01-08 01:53:34 +08:00
artiq/artiq/coredevice/comm_kernel.py
2019-03-31 17:10:27 +01:00

513 lines
16 KiB
Python

import struct
import logging
import traceback
import numpy
from enum import Enum
from fractions import Fraction
from collections import namedtuple
from artiq.coredevice import exceptions
from artiq.coredevice.comm import initialize_connection
from artiq import __version__ as software_version
logger = logging.getLogger(__name__)
class Request(Enum):
SystemInfo = 3
LoadKernel = 5
RunKernel = 6
RPCReply = 7
RPCException = 8
class Reply(Enum):
SystemInfo = 2
LoadCompleted = 5
LoadFailed = 6
KernelFinished = 7
KernelStartupFailed = 8
KernelException = 9
RPCRequest = 10
WatchdogExpired = 14
ClockFailure = 15
class UnsupportedDevice(Exception):
pass
class LoadError(Exception):
pass
class RPCReturnValueError(ValueError):
pass
RPCKeyword = namedtuple('RPCKeyword', ['name', 'value'])
class CommKernelDummy:
def __init__(self):
pass
def load(self, kernel_library):
pass
def run(self):
pass
def serve(self, embedding_map, symbolizer, demangler):
pass
def check_system_info(self):
pass
class CommKernel:
warned_of_mismatch = False
def __init__(self, host, port=1381):
self._read_type = None
self.host = host
self.port = port
def open(self, **kwargs):
if hasattr(self, "socket"):
return
self.socket = initialize_connection(self.host, self.port, **kwargs)
self.socket.sendall(b"ARTIQ coredev\n")
def close(self):
if not hasattr(self, "socket"):
return
self.socket.close()
del self.socket
logger.debug("disconnected")
def read(self, length):
r = bytes()
while len(r) < length:
rn = self.socket.recv(min(8192, length - len(r)))
if not rn:
raise ConnectionResetError("Connection closed")
r += rn
return r
def write(self, data):
self.socket.sendall(data)
#
# Reader interface
#
def _read_header(self):
self.open()
# Wait for a synchronization sequence, 5a 5a 5a 5a.
sync_count = 0
while sync_count < 4:
(sync_byte, ) = struct.unpack("B", self.read(1))
if sync_byte == 0x5a:
sync_count += 1
else:
sync_count = 0
# Read message header.
(raw_type, ) = struct.unpack("B", self.read(1))
self._read_type = Reply(raw_type)
logger.debug("receiving message: type=%r",
self._read_type)
def _read_expect(self, ty):
if self._read_type != ty:
raise IOError("Incorrect reply from device: {} (expected {})".
format(self._read_type, ty))
def _read_empty(self, ty):
self._read_header()
self._read_expect(ty)
def _read_chunk(self, length):
return self.read(length)
def _read_int8(self):
(value, ) = struct.unpack("B", self._read_chunk(1))
return value
def _read_int32(self):
(value, ) = struct.unpack(">l", self._read_chunk(4))
return value
def _read_int64(self):
(value, ) = struct.unpack(">q", self._read_chunk(8))
return value
def _read_float64(self):
(value, ) = struct.unpack(">d", self._read_chunk(8))
return value
def _read_bool(self):
return True if self._read_int8() else False
def _read_bytes(self):
return self._read_chunk(self._read_int32())
def _read_string(self):
return self._read_bytes().decode("utf-8")
#
# Writer interface
#
def _write_header(self, ty):
self.open()
logger.debug("sending message: type=%r", ty)
# Write synchronization sequence and header.
self.write(struct.pack(">lB", 0x5a5a5a5a, ty.value))
def _write_empty(self, ty):
self._write_header(ty)
def _write_chunk(self, chunk):
self.write(chunk)
def _write_int8(self, value):
self.write(struct.pack("B", value))
def _write_int32(self, value):
self.write(struct.pack(">l", value))
def _write_int64(self, value):
self.write(struct.pack(">q", value))
def _write_float64(self, value):
self.write(struct.pack(">d", value))
def _write_bool(self, value):
self.write(struct.pack("B", value))
def _write_bytes(self, value):
self._write_int32(len(value))
self.write(value)
def _write_string(self, value):
self._write_bytes(value.encode("utf-8"))
#
# Exported APIs
#
def reset_session(self):
self.write(struct.pack(">ll", 0x5a5a5a5a, 0))
def check_system_info(self):
self._write_empty(Request.SystemInfo)
self._read_header()
self._read_expect(Reply.SystemInfo)
runtime_id = self._read_chunk(4)
if runtime_id != b"AROR":
raise UnsupportedDevice("Unsupported runtime ID: {}"
.format(runtime_id))
gateware_version = self._read_string().split(";")[0]
if gateware_version != software_version and not self.warned_of_mismatch:
logger.warning("Mismatch between gateware (%s) "
"and software (%s) versions",
gateware_version, software_version)
CommKernel.warned_of_mismatch = True
finished_cleanly = self._read_bool()
if not finished_cleanly:
logger.warning("Previous kernel did not cleanly finish")
def load(self, kernel_library):
self._write_header(Request.LoadKernel)
self._write_bytes(kernel_library)
self._read_header()
if self._read_type == Reply.LoadFailed:
raise LoadError(self._read_string())
else:
self._read_expect(Reply.LoadCompleted)
def run(self):
self._write_empty(Request.RunKernel)
logger.debug("running kernel")
_rpc_sentinel = object()
# See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
def _receive_rpc_value(self, embedding_map):
tag = chr(self._read_int8())
if tag == "\x00":
return self._rpc_sentinel
elif tag == "t":
length = self._read_int8()
return tuple(self._receive_rpc_value(embedding_map) for _ in range(length))
elif tag == "n":
return None
elif tag == "b":
return bool(self._read_int8())
elif tag == "i":
return numpy.int32(self._read_int32())
elif tag == "I":
return numpy.int64(self._read_int64())
elif tag == "f":
return self._read_float64()
elif tag == "F":
numerator = self._read_int64()
denominator = self._read_int64()
return Fraction(numerator, denominator)
elif tag == "s":
return self._read_string()
elif tag == "B":
return self._read_bytes()
elif tag == "A":
return self._read_bytes()
elif tag == "l":
length = self._read_int32()
return [self._receive_rpc_value(embedding_map) for _ in range(length)]
elif tag == "a":
length = self._read_int32()
return numpy.array([self._receive_rpc_value(embedding_map) for _ in range(length)])
elif tag == "r":
start = self._receive_rpc_value(embedding_map)
stop = self._receive_rpc_value(embedding_map)
step = self._receive_rpc_value(embedding_map)
return range(start, stop, step)
elif tag == "k":
name = self._read_string()
value = self._receive_rpc_value(embedding_map)
return RPCKeyword(name, value)
elif tag == "O":
return embedding_map.retrieve_object(self._read_int32())
else:
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
def _receive_rpc_args(self, embedding_map):
args, kwargs = [], {}
while True:
value = self._receive_rpc_value(embedding_map)
if value is self._rpc_sentinel:
return args, kwargs
elif isinstance(value, RPCKeyword):
kwargs[value.name] = value.value
else:
args.append(value)
def _skip_rpc_value(self, tags):
tag = chr(tags.pop(0))
if tag == "t":
length = tags.pop(0)
for _ in range(length):
self._skip_rpc_value(tags)
elif tag == "l":
self._skip_rpc_value(tags)
elif tag == "r":
self._skip_rpc_value(tags)
else:
pass
def _send_rpc_value(self, tags, value, root, function):
def check(cond, expected):
if not cond:
raise RPCReturnValueError(
"type mismatch: cannot serialize {value} as {type}"
" ({function} has returned {root})".format(
value=repr(value), type=expected(),
function=function, root=root))
tag = chr(tags.pop(0))
if tag == "t":
length = tags.pop(0)
check(isinstance(value, tuple) and length == len(value),
lambda: "tuple of {}".format(length))
for elt in value:
self._send_rpc_value(tags, elt, root, function)
elif tag == "n":
check(value is None,
lambda: "None")
elif tag == "b":
check(isinstance(value, bool),
lambda: "bool")
self._write_int8(value)
elif tag == "i":
check(isinstance(value, (int, numpy.int32)) and
(-2**31 < value < 2**31-1),
lambda: "32-bit int")
self._write_int32(value)
elif tag == "I":
check(isinstance(value, (int, numpy.int32, numpy.int64)) and
(-2**63 < value < 2**63-1),
lambda: "64-bit int")
self._write_int64(value)
elif tag == "f":
check(isinstance(value, float),
lambda: "float")
self._write_float64(value)
elif tag == "F":
check(isinstance(value, Fraction) and
(-2**63 < value.numerator < 2**63-1) and
(-2**63 < value.denominator < 2**63-1),
lambda: "64-bit Fraction")
self._write_int64(value.numerator)
self._write_int64(value.denominator)
elif tag == "s":
check(isinstance(value, str) and "\x00" not in value,
lambda: "str")
self._write_string(value)
elif tag == "B":
check(isinstance(value, bytes),
lambda: "bytes")
self._write_bytes(value)
elif tag == "A":
check(isinstance(value, bytearray),
lambda: "bytearray")
self._write_bytes(value)
elif tag == "l":
check(isinstance(value, list),
lambda: "list")
self._write_int32(len(value))
for elt in value:
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, elt, root, function)
self._skip_rpc_value(tags)
elif tag == "r":
check(isinstance(value, range),
lambda: "range")
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, value.start, root, function)
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, value.stop, root, function)
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, value.step, root, function)
tags = tags_copy
else:
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
def _truncate_message(self, msg, limit=4096):
if len(msg) > limit:
return msg[0:limit] + "... (truncated)"
else:
return msg
def _serve_rpc(self, embedding_map):
is_async = self._read_bool()
service_id = self._read_int32()
args, kwargs = self._receive_rpc_args(embedding_map)
return_tags = self._read_bytes()
if service_id is 0:
service = lambda obj, attr, value: setattr(obj, attr, value)
else:
service = embedding_map.retrieve_object(service_id)
logger.debug("rpc service: [%d]%r%s %r %r -> %s", service_id, service,
(" (async)" if is_async else ""), args, kwargs, return_tags)
if is_async:
service(*args, **kwargs)
return
try:
result = service(*args, **kwargs)
logger.debug("rpc service: %d %r %r = %r", service_id, args, kwargs, result)
self._write_header(Request.RPCReply)
self._write_bytes(return_tags)
self._send_rpc_value(bytearray(return_tags), result, result, service)
except RPCReturnValueError as exn:
raise
except Exception as exn:
logger.debug("rpc service: %d %r %r ! %r", service_id, args, kwargs, exn)
self._write_header(Request.RPCException)
if hasattr(exn, "artiq_core_exception"):
exn = exn.artiq_core_exception
self._write_string(exn.name)
self._write_string(self._truncate_message(exn.message))
for index in range(3):
self._write_int64(exn.param[index])
filename, line, column, function = exn.traceback[-1]
self._write_string(filename)
self._write_int32(line)
self._write_int32(column)
self._write_string(function)
else:
exn_type = type(exn)
if exn_type in (ZeroDivisionError, ValueError, IndexError, RuntimeError) or \
hasattr(exn, "artiq_builtin"):
self._write_string("0:{}".format(exn_type.__name__))
else:
exn_id = embedding_map.store_object(exn_type)
self._write_string("{}:{}.{}".format(exn_id,
exn_type.__module__,
exn_type.__qualname__))
self._write_string(self._truncate_message(str(exn)))
for index in range(3):
self._write_int64(0)
tb = traceback.extract_tb(exn.__traceback__, 2)
if len(tb) == 2:
(_, (filename, line, function, _), ) = tb
elif len(tb) == 1:
((filename, line, function, _), ) = tb
else:
assert False
self._write_string(filename)
self._write_int32(line)
self._write_int32(-1) # column not known
self._write_string(function)
def _serve_exception(self, embedding_map, symbolizer, demangler):
name = self._read_string()
message = self._read_string()
params = [self._read_int64() for _ in range(3)]
filename = self._read_string()
line = self._read_int32()
column = self._read_int32()
function = self._read_string()
backtrace = [self._read_int32() for _ in range(self._read_int32())]
traceback = list(reversed(symbolizer(backtrace))) + \
[(filename, line, column, *demangler([function]), None)]
core_exn = exceptions.CoreException(name, message, params, traceback)
if core_exn.id == 0:
python_exn_type = getattr(exceptions, core_exn.name.split('.')[-1])
else:
python_exn_type = embedding_map.retrieve_object(core_exn.id)
python_exn = python_exn_type(message.format(*params))
python_exn.artiq_core_exception = core_exn
raise python_exn
def serve(self, embedding_map, symbolizer, demangler):
while True:
self._read_header()
if self._read_type == Reply.RPCRequest:
self._serve_rpc(embedding_map)
elif self._read_type == Reply.KernelException:
self._serve_exception(embedding_map, symbolizer, demangler)
elif self._read_type == Reply.WatchdogExpired:
raise exceptions.WatchdogExpired
elif self._read_type == Reply.ClockFailure:
raise exceptions.ClockFailure
else:
self._read_expect(Reply.KernelFinished)
return