forked from M-Labs/artiq
1
0
Fork 0

coredevice/comm: split protocol to allow reuse for Ethernet

This commit is contained in:
Sebastien Bourdeauducq 2015-04-10 00:59:35 +08:00
parent 44304a33b2
commit cb2596bd81
3 changed files with 249 additions and 196 deletions

View File

@ -0,0 +1,202 @@
import struct
import zlib
import logging
from enum import Enum
from fractions import Fraction
from artiq.language import units
from artiq.coredevice.runtime import Environment
from artiq.coredevice import runtime_exceptions
from artiq.language import core as core_language
from artiq.coredevice.rpc_wrapper import RPCWrapper
logger = logging.getLogger(__name__)
class _H2DMsgType(Enum):
LINK_MESSAGE = 1
REQUEST_IDENT = 2
SWITCH_CLOCK = 3
LOAD_OBJECT = 4
RUN_KERNEL = 5
class _D2HMsgType(Enum):
MESSAGE_UNRECOGNIZED = 1
LOG = 2
IDENT = 3
CLOCK_SWITCH_COMPLETED = 4
CLOCK_SWITCH_FAILED = 5
OBJECT_LOADED = 6
OBJECT_INCORRECT_LENGTH = 7
OBJECT_CRC_FAILED = 8
OBJECT_UNRECOGNIZED = 9
KERNEL_FINISHED = 10
KERNEL_STARTUP_FAILED = 11
KERNEL_EXCEPTION = 12
RPC_REQUEST = 13
class UnsupportedDevice(Exception):
pass
class CommGeneric:
# methods for derived classes to implement
def open(self):
"""Opens the communication channel.
Must do nothing if already opened."""
raise NotImplementedError
def close(self):
"""Closes the communication channel.
Must do nothing if already closed."""
raise NotImplementedError
def read(self, length):
"""Reads exactly length bytes from the communication channel.
The channel is assumed to be opened."""
raise NotImplementedError
def write(self, data):
"""Writes exactly length bytes to the communication channel.
The channel is assumed to be opened."""
raise NotImplementedError
#
def _read(self, length):
self.open()
return self.read(length)
def _write(self, data):
self.open()
self.write(data)
def _get_device_msg(self):
while True:
(reply, ) = struct.unpack("B", self._read(1))
msg = _D2HMsgType(reply)
if msg == _D2HMsgType.LOG:
(length, ) = struct.unpack(">h", self._read(2))
log_message = ""
for i in range(length):
(c, ) = struct.unpack("B", self._read(1))
log_message += chr(c)
logger.info("DEVICE LOG: %s", log_message)
else:
logger.debug("message received: %r", msg)
return msg
def get_runtime_env(self):
self._write(struct.pack(">lb", 0x5a5a5a5a,
_H2DMsgType.REQUEST_IDENT.value))
msg = self._get_device_msg()
if msg != _D2HMsgType.IDENT:
raise IOError("Incorrect reply from device: {}".format(msg))
(reply, ) = struct.unpack("B", self._read(1))
runtime_id = chr(reply)
for i in range(3):
(reply, ) = struct.unpack("B", self._read(1))
runtime_id += chr(reply)
if runtime_id != "AROR":
raise UnsupportedDevice("Unsupported runtime ID: {}"
.format(runtime_id))
ref_freq_i, ref_freq_fn, ref_freq_fd = struct.unpack(
">lBB", self._read(6))
ref_freq = (ref_freq_i + Fraction(ref_freq_fn, ref_freq_fd))*units.Hz
ref_period = 1/ref_freq
logger.debug("environment ref_period: %s", ref_period)
return Environment(ref_period)
def switch_clock(self, external):
self._write(struct.pack(
">lbb", 0x5a5a5a5a, _H2DMsgType.SWITCH_CLOCK.value,
int(external)))
msg = self._get_device_msg()
if msg != _D2HMsgType.CLOCK_SWITCH_COMPLETED:
raise IOError("Incorrect reply from device: {}".format(msg))
def load(self, kcode):
self._write(struct.pack(
">lblL",
0x5a5a5a5a, _H2DMsgType.LOAD_OBJECT.value,
len(kcode), zlib.crc32(kcode)))
self._write(kcode)
msg = self._get_device_msg()
if msg != _D2HMsgType.OBJECT_LOADED:
raise IOError("Incorrect reply from device: "+str(msg))
def run(self, kname):
self._write(struct.pack(
">lbl", 0x5a5a5a5a, _H2DMsgType.RUN_KERNEL.value, len(kname)))
for c in kname:
self._write(struct.pack(">B", ord(c)))
logger.debug("running kernel: %s", kname)
def _receive_rpc_values(self):
r = []
while True:
type_tag = chr(struct.unpack(">B", self._read(1))[0])
if type_tag == "\x00":
return r
if type_tag == "n":
r.append(None)
if type_tag == "b":
r.append(bool(struct.unpack(">B", self._read(1))[0]))
if type_tag == "i":
r.append(struct.unpack(">l", self._read(4))[0])
if type_tag == "I":
r.append(struct.unpack(">q", self._read(8))[0])
if type_tag == "f":
r.append(struct.unpack(">d", self._read(8))[0])
if type_tag == "F":
n, d = struct.unpack(">qq", self._read(16))
r.append(Fraction(n, d))
if type_tag == "l":
r.append(self._receive_rpc_values())
def _serve_rpc(self, rpc_wrapper, rpc_map, user_exception_map):
rpc_num = struct.unpack(">h", self._read(2))[0]
args = self._receive_rpc_values()
logger.debug("rpc service: %d %r", rpc_num, args)
eid, r = rpc_wrapper.run_rpc(
user_exception_map, rpc_map[rpc_num], args)
self._write(struct.pack(">ll", eid, r))
logger.debug("rpc service: %d %r == %r", rpc_num, args, r)
def _serve_exception(self, rpc_wrapper, user_exception_map):
eid, p0, p1, p2 = struct.unpack(">lqqq", self._read(4+3*8))
rpc_wrapper.filter_rpc_exception(eid)
if eid < core_language.first_user_eid:
exception = runtime_exceptions.exception_map[eid]
raise exception(self.core, p0, p1, p2)
else:
exception = user_exception_map[eid]
raise exception
def serve(self, rpc_map, user_exception_map):
rpc_wrapper = RPCWrapper()
while True:
msg = self._get_device_msg()
if msg == _D2HMsgType.RPC_REQUEST:
self._serve_rpc(rpc_wrapper, rpc_map, user_exception_map)
elif msg == _D2HMsgType.KERNEL_EXCEPTION:
self._serve_exception(rpc_wrapper, user_exception_map)
elif msg == _D2HMsgType.KERNEL_FINISHED:
return
else:
raise IOError("Incorrect request from device: "+str(msg))
def send_link_message(self, data):
self._write(struct.pack(
">lb", 0x5a5a5a5a, _H2DMsgType.LINK_MESSAGE.value))
self._write(data)

View File

@ -1,89 +1,59 @@
import logging
import serial import serial
import struct import struct
import zlib
from enum import Enum
from fractions import Fraction
import logging
from artiq.language import core as core_language from artiq.coredevice.comm_generic import CommGeneric
from artiq.language import units
from artiq.language.db import * from artiq.language.db import *
from artiq.coredevice.runtime import Environment
from artiq.coredevice import runtime_exceptions
from artiq.coredevice.rpc_wrapper import RPCWrapper
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UnsupportedDevice(Exception): class Comm(CommGeneric, AutoDB):
pass
class _H2DMsgType(Enum):
REQUEST_IDENT = 1
LOAD_OBJECT = 2
RUN_KERNEL = 3
SET_BAUD_RATE = 4
SWITCH_CLOCK = 5
class _D2HMsgType(Enum):
LOG = 1
MESSAGE_UNRECOGNIZED = 2
IDENT = 3
OBJECT_LOADED = 4
INCORRECT_LENGTH = 5
CRC_FAILED = 6
OBJECT_UNRECOGNIZED = 7
KERNEL_FINISHED = 8
KERNEL_EXCEPTION = 9
KERNEL_STARTUP_FAILED = 10
RPC_REQUEST = 11
CLOCK_SWITCH_COMPLETED = 12
CLOCK_SWITCH_FAILED = 13
def _write_exactly(f, data):
remaining = len(data)
pos = 0
while remaining:
written = f.write(data[pos:])
remaining -= written
pos += written
def _read_exactly(f, n):
r = bytes()
while(len(r) < n):
r += f.read(n - len(r))
return r
class Comm(AutoDB):
class DBKeys: class DBKeys:
serial_dev = Parameter("/dev/ttyUSB1") serial_dev = Parameter()
baud_rate = Parameter(115200) baud_rate = Parameter(115200)
def build(self): def open(self):
if hasattr(self, "port"):
return
self.port = serial.serial_for_url(self.serial_dev, baudrate=115200) self.port = serial.serial_for_url(self.serial_dev, baudrate=115200)
self.port.flush() self.port.flush()
self.set_remote_baud(self.baud_rate) self.set_remote_baud(self.baud_rate)
self.set_baud(self.baud_rate) self.set_baud(self.baud_rate)
self.rpc_wrapper = RPCWrapper()
def close(self):
if not hasattr(self, "port"):
return
self.set_remote_baud(115200)
self.port.close()
del self.port
def read(self, length):
r = bytes()
while(len(r) < length):
r += self.port.read(length - len(r))
return r
def write(self, data):
remaining = len(data)
pos = 0
while remaining:
written = self.port.write(data[pos:])
remaining -= written
pos += written
def set_baud(self, baud): def set_baud(self, baud):
self.port.baudrate = baud self.port.baudrate = baud
self.port.flush() self.port.flush()
logger.debug("baud rate set to %d", baud) logger.debug("local baud rate set to %d", baud)
def set_remote_baud(self, baud): def set_remote_baud(self, baud):
_write_exactly(self.port, struct.pack( self.send_link_message(struct.pack(">l", baud))
">lbl", 0x5a5a5a5a, _H2DMsgType.SET_BAUD_RATE.value, baud))
handshake = 0 handshake = 0
fails = 0 fails = 0
while handshake < 4: while handshake < 4:
(recv, ) = struct.unpack("B", _read_exactly(self.port, 1)) (recv, ) = struct.unpack("B", self.read(1))
if recv == 0x5a: if recv == 0x5a:
handshake += 1 handshake += 1
else: else:
@ -95,124 +65,4 @@ class Comm(AutoDB):
fails += 1 fails += 1
if fails > 3: if fails > 3:
raise IOError("Baudrate ack failed") raise IOError("Baudrate ack failed")
self.set_baud(baud) logger.debug("remote baud rate set to %d", baud)
logger.debug("synchronized")
def close(self):
self.set_remote_baud(115200)
self.port.close()
def _get_device_msg(self):
while True:
(reply, ) = struct.unpack("B", _read_exactly(self.port, 1))
msg = _D2HMsgType(reply)
if msg == _D2HMsgType.LOG:
(length, ) = struct.unpack(">h", _read_exactly(self.port, 2))
log_message = ""
for i in range(length):
(c, ) = struct.unpack("B", _read_exactly(self.port, 1))
log_message += chr(c)
logger.info("DEVICE LOG: %s", log_message)
else:
logger.debug("message received: %r", msg)
return msg
def get_runtime_env(self):
_write_exactly(self.port, struct.pack(
">lb", 0x5a5a5a5a, _H2DMsgType.REQUEST_IDENT.value))
msg = self._get_device_msg()
if msg != _D2HMsgType.IDENT:
raise IOError("Incorrect reply from device: "+str(msg))
(reply, ) = struct.unpack("B", _read_exactly(self.port, 1))
runtime_id = chr(reply)
for i in range(3):
(reply, ) = struct.unpack("B", _read_exactly(self.port, 1))
runtime_id += chr(reply)
if runtime_id != "AROR":
raise UnsupportedDevice("Unsupported runtime ID: "+runtime_id)
ref_freq_i, ref_freq_fn, ref_freq_fd = struct.unpack(
">lBB", _read_exactly(self.port, 6))
ref_freq = (ref_freq_i + Fraction(ref_freq_fn, ref_freq_fd))*units.Hz
ref_period = 1/ref_freq
logger.debug("environment ref_period: %s", ref_period)
return Environment(ref_period)
def switch_clock(self, external):
_write_exactly(self.port, struct.pack(
">lbb", 0x5a5a5a5a, _H2DMsgType.SWITCH_CLOCK.value,
int(external)))
msg = self._get_device_msg()
if msg != _D2HMsgType.CLOCK_SWITCH_COMPLETED:
raise IOError("Incorrect reply from device: "+str(msg))
def load(self, kcode):
_write_exactly(self.port, struct.pack(
">lblL",
0x5a5a5a5a, _H2DMsgType.LOAD_OBJECT.value,
len(kcode), zlib.crc32(kcode)))
_write_exactly(self.port, kcode)
msg = self._get_device_msg()
if msg != _D2HMsgType.OBJECT_LOADED:
raise IOError("Incorrect reply from device: "+str(msg))
def run(self, kname):
_write_exactly(self.port, struct.pack(
">lbl", 0x5a5a5a5a, _H2DMsgType.RUN_KERNEL.value, len(kname)))
for c in kname:
_write_exactly(self.port, struct.pack(">B", ord(c)))
logger.debug("running kernel: %s", kname)
def _receive_rpc_values(self):
r = []
while True:
type_tag = chr(struct.unpack(">B", _read_exactly(self.port, 1))[0])
if type_tag == "\x00":
return r
if type_tag == "n":
r.append(None)
if type_tag == "b":
r.append(bool(struct.unpack(">B",
_read_exactly(self.port, 1))[0]))
if type_tag == "i":
r.append(struct.unpack(">l", _read_exactly(self.port, 4))[0])
if type_tag == "I":
r.append(struct.unpack(">q", _read_exactly(self.port, 8))[0])
if type_tag == "f":
r.append(struct.unpack(">d", _read_exactly(self.port, 8))[0])
if type_tag == "F":
n, d = struct.unpack(">qq", _read_exactly(self.port, 16))
r.append(Fraction(n, d))
if type_tag == "l":
r.append(self._receive_rpc_values())
def _serve_rpc(self, rpc_map, user_exception_map):
rpc_num = struct.unpack(">h", _read_exactly(self.port, 2))[0]
args = self._receive_rpc_values()
logger.debug("rpc service: %d %r", rpc_num, args)
eid, r = self.rpc_wrapper.run_rpc(
user_exception_map, rpc_map[rpc_num], args)
_write_exactly(self.port, struct.pack(">ll", eid, r))
logger.debug("rpc service: %d %r == %r", rpc_num, args, r)
def _serve_exception(self, user_exception_map):
eid, p0, p1, p2 = struct.unpack(">lqqq",
_read_exactly(self.port, 4+3*8))
self.rpc_wrapper.filter_rpc_exception(eid)
if eid < core_language.first_user_eid:
exception = runtime_exceptions.exception_map[eid]
raise exception(self.core, p0, p1, p2)
else:
exception = user_exception_map[eid]
raise exception
def serve(self, rpc_map, user_exception_map):
while True:
msg = self._get_device_msg()
if msg == _D2HMsgType.RPC_REQUEST:
self._serve_rpc(rpc_map, user_exception_map)
elif msg == _D2HMsgType.KERNEL_EXCEPTION:
self._serve_exception(user_exception_map)
elif msg == _D2HMsgType.KERNEL_FINISHED:
return
else:
raise IOError("Incorrect request from device: "+str(msg))

View File

@ -12,33 +12,34 @@
/* host to device */ /* host to device */
enum { enum {
MSGTYPE_REQUEST_IDENT = 1, MSGTYPE_SET_BAUD_RATE = 1,
MSGTYPE_REQUEST_IDENT,
MSGTYPE_SWITCH_CLOCK,
MSGTYPE_LOAD_OBJECT, MSGTYPE_LOAD_OBJECT,
MSGTYPE_RUN_KERNEL, MSGTYPE_RUN_KERNEL,
MSGTYPE_SET_BAUD_RATE,
MSGTYPE_SWITCH_CLOCK,
}; };
/* device to host */ /* device to host */
enum { enum {
MSGTYPE_LOG = 1, MSGTYPE_MESSAGE_UNRECOGNIZED = 1,
MSGTYPE_MESSAGE_UNRECOGNIZED, MSGTYPE_LOG,
MSGTYPE_IDENT, MSGTYPE_IDENT,
MSGTYPE_CLOCK_SWITCH_COMPLETED,
MSGTYPE_CLOCK_SWITCH_FAILED,
MSGTYPE_OBJECT_LOADED, MSGTYPE_OBJECT_LOADED,
MSGTYPE_INCORRECT_LENGTH, MSGTYPE_OBJECT_INCORRECT_LENGTH,
MSGTYPE_CRC_FAILED, MSGTYPE_OBJECT_CRC_FAILED,
MSGTYPE_OBJECT_UNRECOGNIZED, MSGTYPE_OBJECT_UNRECOGNIZED,
MSGTYPE_KERNEL_FINISHED, MSGTYPE_KERNEL_FINISHED,
MSGTYPE_KERNEL_EXCEPTION,
MSGTYPE_KERNEL_STARTUP_FAILED, MSGTYPE_KERNEL_STARTUP_FAILED,
MSGTYPE_KERNEL_EXCEPTION,
MSGTYPE_RPC_REQUEST, MSGTYPE_RPC_REQUEST,
MSGTYPE_CLOCK_SWITCH_COMPLETED,
MSGTYPE_CLOCK_SWITCH_FAILED,
}; };
static int receive_int(void) static int receive_int(void)
@ -114,14 +115,14 @@ static void receive_and_load_object(object_loader load_object)
length = receive_int(); length = receive_int();
if(length > sizeof(buffer)) { if(length > sizeof(buffer)) {
send_char(MSGTYPE_INCORRECT_LENGTH); send_char(MSGTYPE_OBJECT_INCORRECT_LENGTH);
return; return;
} }
crc = receive_int(); crc = receive_int();
for(i=0;i<length;i++) for(i=0;i<length;i++)
buffer[i] = receive_char(); buffer[i] = receive_char();
if(crc32(buffer, length) != crc) { if(crc32(buffer, length) != crc) {
send_char(MSGTYPE_CRC_FAILED); send_char(MSGTYPE_OBJECT_CRC_FAILED);
return; return;
} }
if(load_object(buffer, length)) if(load_object(buffer, length))
@ -140,7 +141,7 @@ static void receive_and_run_kernel(kernel_runner run_kernel)
length = receive_int(); length = receive_int();
if(length > (sizeof(kernel_name)-1)) { if(length > (sizeof(kernel_name)-1)) {
send_char(MSGTYPE_INCORRECT_LENGTH); send_char(MSGTYPE_OBJECT_INCORRECT_LENGTH);
return; return;
} }
for(i=0;i<length;i++) for(i=0;i<length;i++)