diff --git a/artiq/coredevice/comm_generic.py b/artiq/coredevice/comm_generic.py new file mode 100644 index 000000000..76548c992 --- /dev/null +++ b/artiq/coredevice/comm_generic.py @@ -0,0 +1,202 @@ +import struct +import zlib +import logging +from enum import Enum +from fractions import Fraction + +from artiq.language import units +from artiq.coredevice.runtime import Environment +from artiq.coredevice import runtime_exceptions +from artiq.language import core as core_language +from artiq.coredevice.rpc_wrapper import RPCWrapper + + +logger = logging.getLogger(__name__) + + +class _H2DMsgType(Enum): + LINK_MESSAGE = 1 + + REQUEST_IDENT = 2 + SWITCH_CLOCK = 3 + + LOAD_OBJECT = 4 + RUN_KERNEL = 5 + + +class _D2HMsgType(Enum): + MESSAGE_UNRECOGNIZED = 1 + LOG = 2 + + IDENT = 3 + CLOCK_SWITCH_COMPLETED = 4 + CLOCK_SWITCH_FAILED = 5 + + OBJECT_LOADED = 6 + OBJECT_INCORRECT_LENGTH = 7 + OBJECT_CRC_FAILED = 8 + OBJECT_UNRECOGNIZED = 9 + + KERNEL_FINISHED = 10 + KERNEL_STARTUP_FAILED = 11 + KERNEL_EXCEPTION = 12 + + RPC_REQUEST = 13 + + +class UnsupportedDevice(Exception): + pass + + +class CommGeneric: + # methods for derived classes to implement + def open(self): + """Opens the communication channel. + Must do nothing if already opened.""" + raise NotImplementedError + + def close(self): + """Closes the communication channel. + Must do nothing if already closed.""" + raise NotImplementedError + + def read(self, length): + """Reads exactly length bytes from the communication channel. + The channel is assumed to be opened.""" + raise NotImplementedError + + def write(self, data): + """Writes exactly length bytes to the communication channel. + The channel is assumed to be opened.""" + raise NotImplementedError + # + + def _read(self, length): + self.open() + return self.read(length) + + def _write(self, data): + self.open() + self.write(data) + + def _get_device_msg(self): + while True: + (reply, ) = struct.unpack("B", self._read(1)) + msg = _D2HMsgType(reply) + if msg == _D2HMsgType.LOG: + (length, ) = struct.unpack(">h", self._read(2)) + log_message = "" + for i in range(length): + (c, ) = struct.unpack("B", self._read(1)) + log_message += chr(c) + logger.info("DEVICE LOG: %s", log_message) + else: + logger.debug("message received: %r", msg) + return msg + + def get_runtime_env(self): + self._write(struct.pack(">lb", 0x5a5a5a5a, + _H2DMsgType.REQUEST_IDENT.value)) + msg = self._get_device_msg() + if msg != _D2HMsgType.IDENT: + raise IOError("Incorrect reply from device: {}".format(msg)) + (reply, ) = struct.unpack("B", self._read(1)) + runtime_id = chr(reply) + for i in range(3): + (reply, ) = struct.unpack("B", self._read(1)) + runtime_id += chr(reply) + if runtime_id != "AROR": + raise UnsupportedDevice("Unsupported runtime ID: {}" + .format(runtime_id)) + ref_freq_i, ref_freq_fn, ref_freq_fd = struct.unpack( + ">lBB", self._read(6)) + ref_freq = (ref_freq_i + Fraction(ref_freq_fn, ref_freq_fd))*units.Hz + ref_period = 1/ref_freq + logger.debug("environment ref_period: %s", ref_period) + return Environment(ref_period) + + def switch_clock(self, external): + self._write(struct.pack( + ">lbb", 0x5a5a5a5a, _H2DMsgType.SWITCH_CLOCK.value, + int(external))) + msg = self._get_device_msg() + if msg != _D2HMsgType.CLOCK_SWITCH_COMPLETED: + raise IOError("Incorrect reply from device: {}".format(msg)) + + def load(self, kcode): + self._write(struct.pack( + ">lblL", + 0x5a5a5a5a, _H2DMsgType.LOAD_OBJECT.value, + len(kcode), zlib.crc32(kcode))) + self._write(kcode) + msg = self._get_device_msg() + if msg != _D2HMsgType.OBJECT_LOADED: + raise IOError("Incorrect reply from device: "+str(msg)) + + def run(self, kname): + self._write(struct.pack( + ">lbl", 0x5a5a5a5a, _H2DMsgType.RUN_KERNEL.value, len(kname))) + for c in kname: + self._write(struct.pack(">B", ord(c))) + logger.debug("running kernel: %s", kname) + + + + def _receive_rpc_values(self): + r = [] + while True: + type_tag = chr(struct.unpack(">B", self._read(1))[0]) + if type_tag == "\x00": + return r + if type_tag == "n": + r.append(None) + if type_tag == "b": + r.append(bool(struct.unpack(">B", self._read(1))[0])) + if type_tag == "i": + r.append(struct.unpack(">l", self._read(4))[0]) + if type_tag == "I": + r.append(struct.unpack(">q", self._read(8))[0]) + if type_tag == "f": + r.append(struct.unpack(">d", self._read(8))[0]) + if type_tag == "F": + n, d = struct.unpack(">qq", self._read(16)) + r.append(Fraction(n, d)) + if type_tag == "l": + r.append(self._receive_rpc_values()) + + def _serve_rpc(self, rpc_wrapper, rpc_map, user_exception_map): + rpc_num = struct.unpack(">h", self._read(2))[0] + args = self._receive_rpc_values() + logger.debug("rpc service: %d %r", rpc_num, args) + eid, r = rpc_wrapper.run_rpc( + user_exception_map, rpc_map[rpc_num], args) + self._write(struct.pack(">ll", eid, r)) + logger.debug("rpc service: %d %r == %r", rpc_num, args, r) + + def _serve_exception(self, rpc_wrapper, user_exception_map): + eid, p0, p1, p2 = struct.unpack(">lqqq", self._read(4+3*8)) + rpc_wrapper.filter_rpc_exception(eid) + if eid < core_language.first_user_eid: + exception = runtime_exceptions.exception_map[eid] + raise exception(self.core, p0, p1, p2) + else: + exception = user_exception_map[eid] + raise exception + + def serve(self, rpc_map, user_exception_map): + rpc_wrapper = RPCWrapper() + while True: + msg = self._get_device_msg() + if msg == _D2HMsgType.RPC_REQUEST: + self._serve_rpc(rpc_wrapper, rpc_map, user_exception_map) + elif msg == _D2HMsgType.KERNEL_EXCEPTION: + self._serve_exception(rpc_wrapper, user_exception_map) + elif msg == _D2HMsgType.KERNEL_FINISHED: + return + else: + raise IOError("Incorrect request from device: "+str(msg)) + + def send_link_message(self, data): + self._write(struct.pack( + ">lb", 0x5a5a5a5a, _H2DMsgType.LINK_MESSAGE.value)) + self._write(data) diff --git a/artiq/coredevice/comm_serial.py b/artiq/coredevice/comm_serial.py index 322813297..b12cc03dd 100644 --- a/artiq/coredevice/comm_serial.py +++ b/artiq/coredevice/comm_serial.py @@ -1,89 +1,59 @@ +import logging import serial import struct -import zlib -from enum import Enum -from fractions import Fraction -import logging -from artiq.language import core as core_language -from artiq.language import units +from artiq.coredevice.comm_generic import CommGeneric from artiq.language.db import * -from artiq.coredevice.runtime import Environment -from artiq.coredevice import runtime_exceptions -from artiq.coredevice.rpc_wrapper import RPCWrapper logger = logging.getLogger(__name__) -class UnsupportedDevice(Exception): - pass - - -class _H2DMsgType(Enum): - REQUEST_IDENT = 1 - LOAD_OBJECT = 2 - RUN_KERNEL = 3 - SET_BAUD_RATE = 4 - SWITCH_CLOCK = 5 - - -class _D2HMsgType(Enum): - LOG = 1 - MESSAGE_UNRECOGNIZED = 2 - IDENT = 3 - OBJECT_LOADED = 4 - INCORRECT_LENGTH = 5 - CRC_FAILED = 6 - OBJECT_UNRECOGNIZED = 7 - KERNEL_FINISHED = 8 - KERNEL_EXCEPTION = 9 - KERNEL_STARTUP_FAILED = 10 - RPC_REQUEST = 11 - CLOCK_SWITCH_COMPLETED = 12 - CLOCK_SWITCH_FAILED = 13 - - -def _write_exactly(f, data): - remaining = len(data) - pos = 0 - while remaining: - written = f.write(data[pos:]) - remaining -= written - pos += written - - -def _read_exactly(f, n): - r = bytes() - while(len(r) < n): - r += f.read(n - len(r)) - return r - - -class Comm(AutoDB): +class Comm(CommGeneric, AutoDB): class DBKeys: - serial_dev = Parameter("/dev/ttyUSB1") + serial_dev = Parameter() baud_rate = Parameter(115200) - def build(self): + def open(self): + if hasattr(self, "port"): + return self.port = serial.serial_for_url(self.serial_dev, baudrate=115200) self.port.flush() self.set_remote_baud(self.baud_rate) self.set_baud(self.baud_rate) - self.rpc_wrapper = RPCWrapper() + + def close(self): + if not hasattr(self, "port"): + return + self.set_remote_baud(115200) + self.port.close() + del self.port + + def read(self, length): + r = bytes() + while(len(r) < length): + r += self.port.read(length - len(r)) + return r + + def write(self, data): + remaining = len(data) + pos = 0 + while remaining: + written = self.port.write(data[pos:]) + remaining -= written + pos += written def set_baud(self, baud): self.port.baudrate = baud self.port.flush() - logger.debug("baud rate set to %d", baud) + logger.debug("local baud rate set to %d", baud) def set_remote_baud(self, baud): - _write_exactly(self.port, struct.pack( - ">lbl", 0x5a5a5a5a, _H2DMsgType.SET_BAUD_RATE.value, baud)) + self.send_link_message(struct.pack(">l", baud)) handshake = 0 fails = 0 while handshake < 4: - (recv, ) = struct.unpack("B", _read_exactly(self.port, 1)) + (recv, ) = struct.unpack("B", self.read(1)) if recv == 0x5a: handshake += 1 else: @@ -95,124 +65,4 @@ class Comm(AutoDB): fails += 1 if fails > 3: raise IOError("Baudrate ack failed") - self.set_baud(baud) - logger.debug("synchronized") - - def close(self): - self.set_remote_baud(115200) - self.port.close() - - def _get_device_msg(self): - while True: - (reply, ) = struct.unpack("B", _read_exactly(self.port, 1)) - msg = _D2HMsgType(reply) - if msg == _D2HMsgType.LOG: - (length, ) = struct.unpack(">h", _read_exactly(self.port, 2)) - log_message = "" - for i in range(length): - (c, ) = struct.unpack("B", _read_exactly(self.port, 1)) - log_message += chr(c) - logger.info("DEVICE LOG: %s", log_message) - else: - logger.debug("message received: %r", msg) - return msg - - def get_runtime_env(self): - _write_exactly(self.port, struct.pack( - ">lb", 0x5a5a5a5a, _H2DMsgType.REQUEST_IDENT.value)) - msg = self._get_device_msg() - if msg != _D2HMsgType.IDENT: - raise IOError("Incorrect reply from device: "+str(msg)) - (reply, ) = struct.unpack("B", _read_exactly(self.port, 1)) - runtime_id = chr(reply) - for i in range(3): - (reply, ) = struct.unpack("B", _read_exactly(self.port, 1)) - runtime_id += chr(reply) - if runtime_id != "AROR": - raise UnsupportedDevice("Unsupported runtime ID: "+runtime_id) - ref_freq_i, ref_freq_fn, ref_freq_fd = struct.unpack( - ">lBB", _read_exactly(self.port, 6)) - ref_freq = (ref_freq_i + Fraction(ref_freq_fn, ref_freq_fd))*units.Hz - ref_period = 1/ref_freq - logger.debug("environment ref_period: %s", ref_period) - return Environment(ref_period) - - def switch_clock(self, external): - _write_exactly(self.port, struct.pack( - ">lbb", 0x5a5a5a5a, _H2DMsgType.SWITCH_CLOCK.value, - int(external))) - msg = self._get_device_msg() - if msg != _D2HMsgType.CLOCK_SWITCH_COMPLETED: - raise IOError("Incorrect reply from device: "+str(msg)) - - def load(self, kcode): - _write_exactly(self.port, struct.pack( - ">lblL", - 0x5a5a5a5a, _H2DMsgType.LOAD_OBJECT.value, - len(kcode), zlib.crc32(kcode))) - _write_exactly(self.port, kcode) - msg = self._get_device_msg() - if msg != _D2HMsgType.OBJECT_LOADED: - raise IOError("Incorrect reply from device: "+str(msg)) - - def run(self, kname): - _write_exactly(self.port, struct.pack( - ">lbl", 0x5a5a5a5a, _H2DMsgType.RUN_KERNEL.value, len(kname))) - for c in kname: - _write_exactly(self.port, struct.pack(">B", ord(c))) - logger.debug("running kernel: %s", kname) - - def _receive_rpc_values(self): - r = [] - while True: - type_tag = chr(struct.unpack(">B", _read_exactly(self.port, 1))[0]) - if type_tag == "\x00": - return r - if type_tag == "n": - r.append(None) - if type_tag == "b": - r.append(bool(struct.unpack(">B", - _read_exactly(self.port, 1))[0])) - if type_tag == "i": - r.append(struct.unpack(">l", _read_exactly(self.port, 4))[0]) - if type_tag == "I": - r.append(struct.unpack(">q", _read_exactly(self.port, 8))[0]) - if type_tag == "f": - r.append(struct.unpack(">d", _read_exactly(self.port, 8))[0]) - if type_tag == "F": - n, d = struct.unpack(">qq", _read_exactly(self.port, 16)) - r.append(Fraction(n, d)) - if type_tag == "l": - r.append(self._receive_rpc_values()) - - def _serve_rpc(self, rpc_map, user_exception_map): - rpc_num = struct.unpack(">h", _read_exactly(self.port, 2))[0] - args = self._receive_rpc_values() - logger.debug("rpc service: %d %r", rpc_num, args) - eid, r = self.rpc_wrapper.run_rpc( - user_exception_map, rpc_map[rpc_num], args) - _write_exactly(self.port, struct.pack(">ll", eid, r)) - logger.debug("rpc service: %d %r == %r", rpc_num, args, r) - - def _serve_exception(self, user_exception_map): - eid, p0, p1, p2 = struct.unpack(">lqqq", - _read_exactly(self.port, 4+3*8)) - self.rpc_wrapper.filter_rpc_exception(eid) - if eid < core_language.first_user_eid: - exception = runtime_exceptions.exception_map[eid] - raise exception(self.core, p0, p1, p2) - else: - exception = user_exception_map[eid] - raise exception - - def serve(self, rpc_map, user_exception_map): - while True: - msg = self._get_device_msg() - if msg == _D2HMsgType.RPC_REQUEST: - self._serve_rpc(rpc_map, user_exception_map) - elif msg == _D2HMsgType.KERNEL_EXCEPTION: - self._serve_exception(user_exception_map) - elif msg == _D2HMsgType.KERNEL_FINISHED: - return - else: - raise IOError("Incorrect request from device: "+str(msg)) + logger.debug("remote baud rate set to %d", baud) diff --git a/soc/runtime/comm_serial.c b/soc/runtime/comm_serial.c index 77bbcce95..fc4c7d841 100644 --- a/soc/runtime/comm_serial.c +++ b/soc/runtime/comm_serial.c @@ -12,33 +12,34 @@ /* host to device */ enum { - MSGTYPE_REQUEST_IDENT = 1, + MSGTYPE_SET_BAUD_RATE = 1, + + MSGTYPE_REQUEST_IDENT, + MSGTYPE_SWITCH_CLOCK, + MSGTYPE_LOAD_OBJECT, MSGTYPE_RUN_KERNEL, - MSGTYPE_SET_BAUD_RATE, - MSGTYPE_SWITCH_CLOCK, }; /* device to host */ enum { - MSGTYPE_LOG = 1, - MSGTYPE_MESSAGE_UNRECOGNIZED, + MSGTYPE_MESSAGE_UNRECOGNIZED = 1, + MSGTYPE_LOG, MSGTYPE_IDENT, + MSGTYPE_CLOCK_SWITCH_COMPLETED, + MSGTYPE_CLOCK_SWITCH_FAILED, MSGTYPE_OBJECT_LOADED, - MSGTYPE_INCORRECT_LENGTH, - MSGTYPE_CRC_FAILED, + MSGTYPE_OBJECT_INCORRECT_LENGTH, + MSGTYPE_OBJECT_CRC_FAILED, MSGTYPE_OBJECT_UNRECOGNIZED, MSGTYPE_KERNEL_FINISHED, - MSGTYPE_KERNEL_EXCEPTION, MSGTYPE_KERNEL_STARTUP_FAILED, + MSGTYPE_KERNEL_EXCEPTION, MSGTYPE_RPC_REQUEST, - - MSGTYPE_CLOCK_SWITCH_COMPLETED, - MSGTYPE_CLOCK_SWITCH_FAILED, }; static int receive_int(void) @@ -114,14 +115,14 @@ static void receive_and_load_object(object_loader load_object) length = receive_int(); if(length > sizeof(buffer)) { - send_char(MSGTYPE_INCORRECT_LENGTH); + send_char(MSGTYPE_OBJECT_INCORRECT_LENGTH); return; } crc = receive_int(); for(i=0;i (sizeof(kernel_name)-1)) { - send_char(MSGTYPE_INCORRECT_LENGTH); + send_char(MSGTYPE_OBJECT_INCORRECT_LENGTH); return; } for(i=0;i