artiq/artiq/coredevice/comm_kernel.py
2020-09-04 13:49:22 +08:00

637 lines
21 KiB
Python

import struct
import logging
import traceback
import numpy
import socket
from enum import Enum
from fractions import Fraction
from collections import namedtuple
from artiq.coredevice import exceptions
from artiq import __version__ as software_version
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class Request(Enum):
    """Type codes of the messages this module writes to the core device.

    The numeric value is the type byte that follows the synchronization
    sequence in an outgoing message header.
    """

    # Connection-level query.
    SystemInfo = 3

    # Kernel lifecycle.
    LoadKernel = 5
    RunKernel = 6

    # Answers to an RPC issued by the running kernel.
    RPCReply = 7
    RPCException = 8
class Reply(Enum):
    """Type codes of the messages this module reads from the core device.

    The numeric value is the type byte that follows the synchronization
    sequence in an incoming message header.
    """

    # Answer to Request.SystemInfo.
    SystemInfo = 2

    # Outcome of Request.LoadKernel.
    LoadCompleted = 5
    LoadFailed = 6

    # Kernel execution outcomes and events.
    KernelFinished = 7
    KernelStartupFailed = 8
    KernelException = 9
    RPCRequest = 10

    # Failure notifications.
    WatchdogExpired = 14
    ClockFailure = 15
class UnsupportedDevice(Exception):
    """Raised when the device reports a runtime ID this module does not know."""
class LoadError(Exception):
    """Raised when the device answers a kernel load request with LoadFailed."""
class RPCReturnValueError(ValueError):
    """Raised when an RPC return value does not match its declared type tags."""
# A keyword argument received over RPC: the argument name and its decoded value.
RPCKeyword = namedtuple("RPCKeyword", ("name", "value"))
def _receive_fraction(kernel, embedding_map):
numerator = kernel._read_int64()
denominator = kernel._read_int64()
return Fraction(numerator, denominator)
def _receive_list(kernel, embedding_map):
length = kernel._read_int32()
tag = chr(kernel._read_int8())
if tag == "b":
buffer = kernel._read(length)
return list(buffer)
elif tag == "i":
buffer = kernel._read(4 * length)
return list(struct.unpack(">%sl" % length, buffer))
elif tag == "I":
buffer = kernel._read(8 * length)
return list(struct.unpack(">%sq" % length, buffer))
elif tag == "f":
buffer = kernel._read(8 * length)
return list(struct.unpack(">%sd" % length, buffer))
else:
fn = receivers[tag]
elems = []
for _ in range(length):
# discard tag, as our device would still send the tag for each
# non-primitive elements.
kernel._read_int8()
item = fn(kernel, embedding_map)
elems.append(item)
return elems
def _receive_array(kernel, embedding_map):
num_dims = kernel._read_int8()
shape = tuple(kernel._read_int32() for _ in range(num_dims))
tag = chr(kernel._read_int8())
fn = receivers[tag]
length = numpy.prod(shape)
if tag == "b":
buffer = kernel._read(length)
elems = numpy.ndarray((length, ), 'B', buffer)
elif tag == "i":
buffer = kernel._read(4 * length)
elems = numpy.ndarray((length, ), '>i4', buffer)
elif tag == "I":
buffer = kernel._read(8 * length)
elems = numpy.ndarray((length, ), '>i8', buffer)
elif tag == "f":
buffer = kernel._read(8 * length)
elems = numpy.ndarray((length, ), '>d', buffer)
else:
fn = receivers[tag]
elems = []
for _ in range(numpy.prod(shape)):
# discard the tag
kernel._read_int8()
item = fn(kernel, embedding_map)
elems.append(item)
elems = numpy.array(elems)
return elems.reshape(shape)
def _receive_range(kernel, embedding_map):
start = kernel._receive_rpc_value(embedding_map)
stop = kernel._receive_rpc_value(embedding_map)
step = kernel._receive_rpc_value(embedding_map)
return range(start, stop, step)
def _receive_keyword(kernel, embedding_map):
    """Read a keyword argument: its name as a string, then a tagged value."""
    # Keyword arguments are evaluated in order: the name is read before the
    # value, matching the order on the wire.
    return RPCKeyword(name=kernel._read_string(),
                      value=kernel._receive_rpc_value(embedding_map))
# Decoders for incoming RPC values, keyed by the one-character type tag.
# Every decoder takes (kernel, embedding_map) and reads its payload from
# ``kernel``; the tag itself has already been consumed by the caller.
receivers = {
    # End-of-arguments marker.
    "\x00": lambda kernel, embedding_map: kernel._rpc_sentinel,
    # Tuple: an 8-bit element count followed by that many tagged values.
    "t": lambda kernel, embedding_map:
         tuple(kernel._receive_rpc_value(embedding_map)
               for _ in range(kernel._read_int8())),
    "n": lambda kernel, embedding_map: None,
    "b": lambda kernel, embedding_map: bool(kernel._read_int8()),
    "i": lambda kernel, embedding_map: numpy.int32(kernel._read_int32()),
    "I": lambda kernel, embedding_map: numpy.int64(kernel._read_int64()),
    "f": lambda kernel, embedding_map: kernel._read_float64(),
    "s": lambda kernel, embedding_map: kernel._read_string(),
    # _read_bytes() returns a bytearray. Convert for the "B" (bytes) tag so
    # the value has the type that _send_rpc_value requires for the same tag;
    # "A" (bytearray) keeps the mutable buffer as-is.
    "B": lambda kernel, embedding_map: bytes(kernel._read_bytes()),
    "A": lambda kernel, embedding_map: kernel._read_bytes(),
    # Host object, referenced by its 32-bit key in the embedding map.
    "O": lambda kernel, embedding_map:
         embedding_map.retrieve_object(kernel._read_int32()),
    "F": _receive_fraction,
    "l": _receive_list,
    "a": _receive_array,
    "r": _receive_range,
    "k": _receive_keyword
}
class CommKernelDummy:
    """Stand-in whose methods all do nothing and return None.

    It offers ``load``, ``run``, ``serve`` and ``check_system_info`` with the
    same parameters as the corresponding CommKernel methods.
    """

    def __init__(self):
        """Create the dummy; there is no state to set up."""

    def load(self, kernel_library):
        """Ignore ``kernel_library``."""

    def run(self):
        """Do nothing."""

    def serve(self, embedding_map, symbolizer, demangler):
        """Return immediately without serving anything."""

    def check_system_info(self):
        """Do nothing."""
class CommKernel:
warned_of_mismatch = False
def __init__(self, host, port=1381):
self._read_type = None
self.host = host
self.port = port
self.read_buffer = bytearray()
self.write_buffer = bytearray()
self.unpack_int32 = struct.Struct(">l").unpack
self.unpack_int64 = struct.Struct(">q").unpack
self.unpack_float64 = struct.Struct(">d").unpack
self.pack_header = struct.Struct(">lB").pack
self.pack_int32 = struct.Struct(">l").pack
self.pack_int64 = struct.Struct(">q").pack
self.pack_float64 = struct.Struct(">d").pack
def open(self):
    """Connect to the core device, unless a socket already exists.

    After connecting, the greeting line ``ARTIQ coredev`` is sent.
    """
    if not hasattr(self, "socket"):
        address = (self.host, self.port)
        self.socket = socket.create_connection(address)
        logger.debug("connected to %s:%d", self.host, self.port)
        self.socket.sendall(b"ARTIQ coredev\n")
def close(self):
    """Close the socket and forget it; do nothing if there is none."""
    if hasattr(self, "socket"):
        self.socket.close()
        # Remove the attribute so that open() will reconnect.
        del self.socket
        logger.debug("disconnected")
#
# Reader interface
#
def _read(self, length):
# cache the reads to avoid frequent call to recv
while len(self.read_buffer) < length:
# the number is just the maximum amount
# when there is not much data, it would return earlier
diff = length - len(self.read_buffer)
flag = 0
if diff > 8192:
flag |= socket.MSG_WAITALL
self.read_buffer += self.socket.recv(8192, flag)
result = self.read_buffer[:length]
self.read_buffer = self.read_buffer[length:]
return result
def _read_header(self):
    """Read the next message header and store its type in ``_read_type``.

    Opens the connection if needed, skips input until four consecutive
    0x5a bytes have been seen, then decodes the following byte as a Reply.
    """
    self.open()

    # Wait for a synchronization sequence, 5a 5a 5a 5a. Any other byte
    # restarts the count.
    matched = 0
    while matched < 4:
        if self._read(1)[0] == 0x5a:
            matched += 1
        else:
            matched = 0

    # Read message header.
    self._read_type = Reply(self._read(1)[0])

    logger.debug("receiving message: type=%r", self._read_type)
def _read_expect(self, ty):
if self._read_type != ty:
raise IOError("Incorrect reply from device: {} (expected {})".
format(self._read_type, ty))
def _read_empty(self, ty):
    """Read a message header and require it to be of type ``ty``.

    Nothing beyond the header is read.
    """
    self._read_header()
    self._read_expect(ty)
def _read_int8(self):
return self._read(1)[0]
def _read_int32(self):
(value, ) = self.unpack_int32(self._read(4))
return value
def _read_int64(self):
(value, ) = self.unpack_int64(self._read(8))
return value
def _read_float64(self):
(value, ) = self.unpack_float64(self._read(8))
return value
def _read_bool(self):
return True if self._read_int8() else False
def _read_bytes(self):
return self._read(self._read_int32())
def _read_string(self):
return self._read_bytes().decode("utf-8")
#
# Writer interface
#
def _write(self, data):
self.write_buffer += data
# if the buffer is already pretty large, send it
# the block size is arbitrary, tuning it may improve performance
if len(self.write_buffer) > 4096:
self._flush()
def _flush(self):
self.socket.sendall(self.write_buffer)
self.write_buffer.clear()
def _write_header(self, ty):
    """Open the connection if needed and queue a header of type ``ty``."""
    self.open()

    logger.debug("sending message: type=%r", ty)

    # Write synchronization sequence and header.
    header = self.pack_header(0x5a5a5a5a, ty.value)
    self._write(header)

def _write_empty(self, ty):
    """Queue a message of type ``ty`` that consists of the header only."""
    self._write_header(ty)

def _write_chunk(self, chunk):
    """Queue ``chunk`` as-is, without a length prefix."""
    self._write(chunk)
def _write_int8(self, value):
self._write(value)
def _write_int32(self, value):
self._write(self.pack_int32(value))
def _write_int64(self, value):
self._write(self.pack_int64(value))
def _write_float64(self, value):
self._write(self.pack_float64(value))
def _write_bool(self, value):
self._write(b'\x01' if value else b'\x00')
def _write_bytes(self, value):
self._write_int32(len(value))
self._write(value)
def _write_string(self, value):
self._write_bytes(value.encode("utf-8"))
#
# Exported APIs
#
def check_system_info(self):
    """Query the device's system information and check the runtime ID.

    For runtime ID ``AROR`` the reported gateware version is compared with
    the software version (a mismatch is logged once per process) and a
    warning is logged if the previous kernel did not finish cleanly.
    Runtime ID ``ARZQ`` is accepted without further checks.

    :raises UnsupportedDevice: for any other runtime ID
    """
    self._write_empty(Request.SystemInfo)
    self._flush()

    self._read_header()
    self._read_expect(Reply.SystemInfo)
    runtime_id = self._read(4)

    if runtime_id == b"ARZQ":
        return
    if runtime_id != b"AROR":
        raise UnsupportedDevice("Unsupported runtime ID: {}"
                                .format(runtime_id))

    # Only the part before the first ";" is compared.
    version_field = self._read_string()
    gateware_version = version_field.split(";")[0]
    mismatch = gateware_version != software_version
    if mismatch and not self.warned_of_mismatch:
        logger.warning("Mismatch between gateware (%s) "
                       "and software (%s) versions",
                       gateware_version, software_version)
        # Set on the class so the warning is not repeated by other instances.
        CommKernel.warned_of_mismatch = True

    if not self._read_bool():
        logger.warning("Previous kernel did not cleanly finish")
def load(self, kernel_library):
    """Send ``kernel_library`` to the device and wait for the load result.

    :param kernel_library: the kernel library as a bytes-like object
    :raises LoadError: if the device replies LoadFailed; the message is the
        string sent with that reply
    :raises IOError: if the reply is neither LoadFailed nor LoadCompleted
    """
    self._write_header(Request.LoadKernel)
    self._write_bytes(kernel_library)
    self._flush()

    self._read_header()
    if self._read_type != Reply.LoadFailed:
        self._read_expect(Reply.LoadCompleted)
        return
    raise LoadError(self._read_string())
def run(self):
    """Send the RunKernel request; does not wait for any reply."""
    self._write_empty(Request.RunKernel)
    self._flush()
    logger.debug("running kernel")
_rpc_sentinel = object()
# See rpc_proto.rs and compiler/ir.py:rpc_tag.
def _receive_rpc_value(self, embedding_map):
    """Read one tagged RPC value and return its decoded form.

    The tag byte selects a decoder from the ``receivers`` table.

    :raises IOError: if the tag has no decoder
    """
    tag = chr(self._read_int8())
    decode = receivers.get(tag)
    if decode is None:
        raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
    return decode(self, embedding_map)
def _receive_rpc_args(self, embedding_map):
    """Read RPC arguments up to the end-of-arguments sentinel.

    :returns: ``(args, kwargs)``; values decoded as RPCKeyword go into the
        ``kwargs`` dict, everything else into the ``args`` list in arrival
        order
    """
    positional = []
    keywords = {}
    value = self._receive_rpc_value(embedding_map)
    while value is not self._rpc_sentinel:
        if isinstance(value, RPCKeyword):
            keywords[value.name] = value.value
        else:
            positional.append(value)
        value = self._receive_rpc_value(embedding_map)
    return positional, keywords
def _skip_rpc_value(self, tags):
tag = chr(tags.pop(0))
if tag == "t":
length = tags.pop(0)
for _ in range(length):
self._skip_rpc_value(tags)
elif tag == "l":
self._skip_rpc_value(tags)
elif tag == "r":
self._skip_rpc_value(tags)
else:
pass
def _send_rpc_value(self, tags, value, root, function):
def check(cond, expected):
if not cond:
raise RPCReturnValueError(
"type mismatch: cannot serialize {value} as {type}"
" ({function} has returned {root})".format(
value=repr(value), type=expected(),
function=function, root=root))
tag = chr(tags.pop(0))
if tag == "t":
length = tags.pop(0)
check(isinstance(value, tuple) and length == len(value),
lambda: "tuple of {}".format(length))
for elt in value:
self._send_rpc_value(tags, elt, root, function)
elif tag == "n":
check(value is None,
lambda: "None")
elif tag == "b":
check(isinstance(value, bool),
lambda: "bool")
self._write_bool(value)
elif tag == "i":
check(isinstance(value, (int, numpy.int32)) and
(-2**31 < value < 2**31-1),
lambda: "32-bit int")
self._write_int32(value)
elif tag == "I":
check(isinstance(value, (int, numpy.int32, numpy.int64)) and
(-2**63 < value < 2**63-1),
lambda: "64-bit int")
self._write_int64(value)
elif tag == "f":
check(isinstance(value, float),
lambda: "float")
self._write_float64(value)
elif tag == "F":
check(isinstance(value, Fraction) and
(-2**63 < value.numerator < 2**63-1) and
(-2**63 < value.denominator < 2**63-1),
lambda: "64-bit Fraction")
self._write_int64(value.numerator)
self._write_int64(value.denominator)
elif tag == "s":
check(isinstance(value, str) and "\x00" not in value,
lambda: "str")
self._write_string(value)
elif tag == "B":
check(isinstance(value, bytes),
lambda: "bytes")
self._write_bytes(value)
elif tag == "A":
check(isinstance(value, bytearray),
lambda: "bytearray")
self._write_bytes(value)
elif tag == "l":
check(isinstance(value, list),
lambda: "list")
self._write_int32(len(value))
tag_element = chr(tags[0])
if tag_element == "b":
self._write(bytes(value))
elif tag_element == "i":
self._write(struct.pack(">%sl" % len(value), *value))
elif tag_element == "I":
self._write(struct.pack(">%sq" % len(value), *value))
elif tag_element == "f":
self._write(struct.pack(">%sd" % len(value), *value))
else:
for elt in value:
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, elt, root, function)
self._skip_rpc_value(tags)
elif tag == "a":
check(isinstance(value, numpy.ndarray),
lambda: "numpy.ndarray")
num_dims = tags.pop(0)
check(num_dims == len(value.shape),
lambda: "{}-dimensional numpy.ndarray".format(num_dims))
for s in value.shape:
self._write_int32(s)
tag_element = chr(tags[0])
if tag_element == "b":
self._write(value.reshape((-1,), order="C").tobytes())
elif tag_element == "i":
array = value.reshape((-1,), order="C").astype('>i4')
self._write(array.tobytes())
elif tag_element == "I":
array = value.reshape((-1,), order="C").astype('>i8')
self._write(array.tobytes())
elif tag_element == "f":
array = value.reshape((-1,), order="C").astype('>d')
self._write(array.tobytes())
else:
for elt in value.reshape((-1,), order="C"):
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, elt, root, function)
self._skip_rpc_value(tags)
elif tag == "r":
check(isinstance(value, range),
lambda: "range")
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, value.start, root, function)
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, value.stop, root, function)
tags_copy = bytearray(tags)
self._send_rpc_value(tags_copy, value.step, root, function)
tags = tags_copy
else:
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
def _truncate_message(self, msg, limit=4096):
if len(msg) > limit:
return msg[0:limit] + "... (truncated)"
else:
return msg
def _serve_rpc(self, embedding_map):
    """Handle one RPC request from the kernel.

    Reads the request (async flag, service ID, arguments, return type
    tags), calls the service and, for synchronous calls, sends back either
    the serialized return value or a description of the exception raised.

    Service ID 0 is a built-in ``setattr(obj, attr, value)``; any other ID
    is looked up in ``embedding_map``.

    :raises RPCReturnValueError: if the return value does not match the
        return type tags; this is not reported to the device
    """
    is_async = self._read_bool()
    service_id = self._read_int32()
    args, kwargs = self._receive_rpc_args(embedding_map)
    return_tags = self._read_bytes()

    if service_id == 0:
        def service(obj, attr, value):
            return setattr(obj, attr, value)
    else:
        service = embedding_map.retrieve_object(service_id)
    logger.debug("rpc service: [%d]%r%s %r %r -> %s", service_id, service,
                 (" (async)" if is_async else ""), args, kwargs, return_tags)

    if is_async:
        # No reply is sent for async calls; an exception raised by the
        # service propagates to the caller of this method.
        service(*args, **kwargs)
        return

    try:
        result = service(*args, **kwargs)
        logger.debug("rpc service: %d %r %r = %r",
                     service_id, args, kwargs, result)

        self._write_header(Request.RPCReply)
        self._write_bytes(return_tags)
        self._send_rpc_value(bytearray(return_tags),
                             result, result, service)
        self._flush()
    except RPCReturnValueError:
        raise
    except Exception as exn:
        logger.debug("rpc service: %d %r %r ! %r",
                     service_id, args, kwargs, exn)

        self._write_header(Request.RPCException)

        if hasattr(exn, "artiq_core_exception"):
            # The exception originated on the device; send back the
            # original device-side description.
            core_exn = exn.artiq_core_exception
            self._write_string(core_exn.name)
            self._write_string(self._truncate_message(core_exn.message))
            for index in range(3):
                self._write_int64(core_exn.param[index])

            # _serve_exception builds traceback entries with a fifth
            # element after the function name; take only the first four so
            # that both 4- and 5-element entries can be unpacked.
            filename, line, column, function = core_exn.traceback[-1][:4]
            self._write_string(filename)
            self._write_int32(line)
            self._write_int32(column)
            self._write_string(function)
        else:
            exn_type = type(exn)
            if exn_type in (ZeroDivisionError, ValueError, IndexError, RuntimeError) or \
                    hasattr(exn, "artiq_builtin"):
                self._write_string("0:{}".format(exn_type.__name__))
            else:
                exn_id = embedding_map.store_object(exn_type)
                self._write_string("{}:{}.{}".format(exn_id,
                                                     exn_type.__module__,
                                                     exn_type.__qualname__))
            self._write_string(self._truncate_message(str(exn)))
            for _ in range(3):
                self._write_int64(0)

            # Look at no more than the first two traceback entries: this
            # method's own frame and, if the call got that far, the frame
            # it called. The last of these is reported. The traceback of an
            # exception caught here always contains this frame, so the
            # list is never empty.
            tb = traceback.extract_tb(exn.__traceback__, 2)
            filename, line, function, _ = tb[-1]
            self._write_string(filename)
            self._write_int32(line)
            self._write_int32(-1)  # column not known
            self._write_string(function)
        self._flush()
def _serve_exception(self, embedding_map, symbolizer, demangler):
    """Read a kernel exception report and raise it as a Python exception.

    The report consists of the exception name, message, three 64-bit
    parameters, the source location (file, line, column, function) and a
    backtrace of 32-bit addresses.

    The raised exception carries the CoreException built from the report in
    its ``artiq_core_exception`` attribute. This method always raises.

    :param embedding_map: used to look up the exception class when the
        CoreException's ``id`` is non-zero
    :param symbolizer: callable applied to the list of backtrace addresses
    :param demangler: callable applied to a one-element list holding the
        function name
    """
    # The fields must be read in exactly this order.
    name = self._read_string()
    message = self._read_string()
    params = [self._read_int64() for _ in range(3)]

    filename = self._read_string()
    line = self._read_int32()
    column = self._read_int32()
    function = self._read_string()

    depth = self._read_int32()
    backtrace = [self._read_int32() for _ in range(depth)]

    # Symbolized backtrace in reverse order, followed by the location where
    # the exception was raised.
    frames = list(reversed(symbolizer(backtrace)))
    frames.append((filename, line, column, *demangler([function]), None))
    core_exn = exceptions.CoreException(name, message, params, frames)

    if core_exn.id == 0:
        # ID 0: resolve the class by its unqualified name in the
        # ``exceptions`` module.
        unqualified_name = core_exn.name.split('.')[-1]
        python_exn_type = getattr(exceptions, unqualified_name)
    else:
        python_exn_type = embedding_map.retrieve_object(core_exn.id)

    python_exn = python_exn_type(message.format(*params))
    python_exn.artiq_core_exception = core_exn
    raise python_exn
def serve(self, embedding_map, symbolizer, demangler):
    """Process messages from the device until the kernel finishes.

    RPC requests are served and the loop continues. A kernel exception, an
    expired watchdog or a clock failure is raised as a Python exception.
    Any other message must be KernelFinished, which ends the loop.

    :raises IOError: on a message type that is not handled here
    """
    while True:
        self._read_header()
        reply = self._read_type

        if reply == Reply.RPCRequest:
            self._serve_rpc(embedding_map)
            continue
        if reply == Reply.KernelException:
            self._serve_exception(embedding_map, symbolizer, demangler)
            continue
        if reply == Reply.WatchdogExpired:
            raise exceptions.WatchdogExpired
        if reply == Reply.ClockFailure:
            raise exceptions.ClockFailure

        self._read_expect(Reply.KernelFinished)
        return