From f7232fd3d157d36c76c9d039fd2e0689eb331be6 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Sat, 20 Dec 2014 21:33:22 +0800 Subject: [PATCH] support exceptions raised by RPCs --- artiq/coredevice/comm_serial.py | 17 +++++------ artiq/coredevice/rpc_wrapper.py | 40 ++++++++++++++++++++++++++ artiq/coredevice/runtime_exceptions.py | 14 +++++---- artiq/test/full_stack.py | 30 +++++++++++++++++++ soc/runtime/comm_serial.c | 11 ++++++- soc/runtime/exceptions.h | 10 ++++--- 6 files changed, 103 insertions(+), 19 deletions(-) create mode 100644 artiq/coredevice/rpc_wrapper.py diff --git a/artiq/coredevice/comm_serial.py b/artiq/coredevice/comm_serial.py index 6c7d0d35d..ce4ec135a 100644 --- a/artiq/coredevice/comm_serial.py +++ b/artiq/coredevice/comm_serial.py @@ -10,6 +10,7 @@ from artiq.language import units from artiq.language.context import * from artiq.coredevice.runtime import Environment from artiq.coredevice import runtime_exceptions +from artiq.coredevice.rpc_wrapper import RPCWrapper logger = logging.getLogger(__name__) @@ -69,6 +70,7 @@ class Comm(AutoContext): self.port.flush() self.set_remote_baud(self.baud_rate) self.set_baud(self.baud_rate) + self.rpc_wrapper = RPCWrapper() def set_baud(self, baud): self.port.baudrate = baud @@ -183,19 +185,18 @@ class Comm(AutoContext): if type_tag == "l": r.append(self._receive_rpc_values()) - def _serve_rpc(self, rpc_map): - (rpc_num, ) = struct.unpack(">h", _read_exactly(self.port, 2)) + def _serve_rpc(self, rpc_map, user_exception_map): + rpc_num = struct.unpack(">h", _read_exactly(self.port, 2))[0] args = self._receive_rpc_values() logger.debug("rpc service: {} ({})".format(rpc_num, args)) - r = rpc_map[rpc_num](*args) - if r is None: - r = 0 - _write_exactly(self.port, struct.pack(">l", r)) + eid, r = self.rpc_wrapper.run_rpc(user_exception_map, rpc_map[rpc_num], args) + _write_exactly(self.port, struct.pack(">ll", eid, r)) logger.debug("rpc service: {} ({}) == {}".format( rpc_num, args, r)) def _serve_exception(self, user_exception_map): - (eid, ) = struct.unpack(">l", _read_exactly(self.port, 4)) + eid = struct.unpack(">l", _read_exactly(self.port, 4))[0] + self.rpc_wrapper.filter_rpc_exception(eid) if eid < core_language.first_user_eid: exception = runtime_exceptions.exception_map[eid] else: @@ -206,7 +207,7 @@ class Comm(AutoContext): while True: msg = self._get_device_msg() if msg == _D2HMsgType.RPC_REQUEST: - self._serve_rpc(rpc_map) + self._serve_rpc(rpc_map, user_exception_map) elif msg == _D2HMsgType.KERNEL_EXCEPTION: self._serve_exception(user_exception_map) elif msg == _D2HMsgType.KERNEL_FINISHED: diff --git a/artiq/coredevice/rpc_wrapper.py b/artiq/coredevice/rpc_wrapper.py new file mode 100644 index 000000000..eeae17286 --- /dev/null +++ b/artiq/coredevice/rpc_wrapper.py @@ -0,0 +1,40 @@ +from artiq.coredevice.runtime_exceptions import exception_map, _RPCException + + +def _lookup_exception(d, e): + for eid, exception in d.items(): + if isinstance(e, exception): + return eid + return 0 + + +class RPCWrapper: + def __init__(self): + self.last_exception = None + + def run_rpc(self, user_exception_map, fn, args): + eid = 0 + r = None + + try: + r = fn(*args) + except Exception as e: + eid = _lookup_exception(user_exception_map, e) + if not eid: + eid = _lookup_exception(exception_map, e) + if eid: + self.last_exception = None + else: + self.last_exception = e + eid = _RPCException.eid + + if r is None: + r = 0 + else: + r = int(r) + + return eid, r + + def filter_rpc_exception(self, eid): + if eid == _RPCException.eid: + raise self.last_exception diff --git a/artiq/coredevice/runtime_exceptions.py b/artiq/coredevice/runtime_exceptions.py index c130a1488..45f912800 100644 --- a/artiq/coredevice/runtime_exceptions.py +++ b/artiq/coredevice/runtime_exceptions.py @@ -9,7 +9,11 @@ class OutOfMemory(RuntimeException): """Raised when the runtime fails to allocate memory. """ - eid = 0 + eid = 1 + + +class _RPCException(RuntimeException): + eid = 2 class RTIOUnderflow(RuntimeException): @@ -19,11 +23,9 @@ class RTIOUnderflow(RuntimeException): The offending event is discarded and the RTIO core keeps operating. """ - eid = 1 + eid = 3 -# Raised by RTIO driver for regular RTIO. -# Raised by runtime for DDS FUD. class RTIOSequenceError(RuntimeException): """Raised when an event is submitted on a given channel with a timestamp not larger than the previous one. @@ -31,7 +33,7 @@ class RTIOSequenceError(RuntimeException): The offending event is discarded and the RTIO core keeps operating. """ - eid = 2 + eid = 4 class RTIOOverflow(RuntimeException): @@ -43,7 +45,7 @@ class RTIOOverflow(RuntimeException): the exception is caught, and events will be partially retrieved. """ - eid = 3 + eid = 5 exception_map = {e.eid: e for e in globals().values() diff --git a/artiq/test/full_stack.py b/artiq/test/full_stack.py index 3e89063ec..395be1e46 100644 --- a/artiq/test/full_stack.py +++ b/artiq/test/full_stack.py @@ -170,6 +170,25 @@ class _Exceptions(AutoContext): self.trace.append(104) +class _RPCExceptions(AutoContext): + def build(self): + self.success = False + + def exception_raiser(self): + raise _MyException + + @kernel + def do_not_catch(self): + self.exception_raiser() + + @kernel + def catch(self): + try: + self.exception_raiser() + except _MyException: + self.success = True + + @unittest.skipIf(no_hardware, "no hardware") class ExecutionCase(unittest.TestCase): def test_primes(self): @@ -219,6 +238,17 @@ class ExecutionCase(unittest.TestCase): _run_on_host(_Exceptions, trace=t_host) self.assertEqual(t_device, t_host) + def test_rpc_exceptions(self): + comm = comm_serial.Comm() + try: + uut = _RPCExceptions(core=core.Core(comm=comm)) + with self.assertRaises(_MyException): + uut.do_not_catch() + uut.catch() + self.assertTrue(uut.success) + finally: + comm.close() + class _RTIOLoopback(AutoContext): i = Device("ttl_in") diff --git a/soc/runtime/comm_serial.c b/soc/runtime/comm_serial.c index b234555c4..47d041bfc 100644 --- a/soc/runtime/comm_serial.c +++ b/soc/runtime/comm_serial.c @@ -4,6 +4,7 @@ #include #include "comm.h" +#include "exceptions.h" /* host to device */ enum { @@ -224,6 +225,8 @@ static int send_value(int type_tag, void *value) int comm_rpc(int rpc_num, ...) { int type_tag; + int eid; + int retval; send_char(MSGTYPE_RPC_REQUEST); send_sint(rpc_num); @@ -235,7 +238,13 @@ int comm_rpc(int rpc_num, ...) va_end(args); send_char(0); - return receive_int(); + eid = receive_int(); + retval = receive_int(); + + if(eid != EID_NONE) + exception_raise(eid); + + return retval; } void comm_log(const char *fmt, ...) diff --git a/soc/runtime/exceptions.h b/soc/runtime/exceptions.h index 39a2ef9e8..d55d6cfe5 100644 --- a/soc/runtime/exceptions.h +++ b/soc/runtime/exceptions.h @@ -2,10 +2,12 @@ #define __EXCEPTIONS_H enum { - EID_OUT_OF_MEMORY = 0, - EID_RTIO_UNDERFLOW = 1, - EID_RTIO_SEQUENCE_ERROR = 2, - EID_RTIO_OVERFLOW = 3, + EID_NONE = 0, + EID_OUT_OF_MEMORY = 1, + EID_RPC_EXCEPTION = 2, + EID_RTIO_UNDERFLOW = 3, + EID_RTIO_SEQUENCE_ERROR = 4, + EID_RTIO_OVERFLOW = 5, }; int exception_setjmp(void *jb) __attribute__((returns_twice));