diff --git a/artiq/coredevice/core.py b/artiq/coredevice/core.py index 406a7c961..e7b6b4874 100644 --- a/artiq/coredevice/core.py +++ b/artiq/coredevice/core.py @@ -52,6 +52,9 @@ def rtio_get_destination_status(linkno: TInt32) -> TBool: def rtio_get_counter() -> TInt64: raise NotImplementedError("syscall not simulated") +@syscall +def raise_exception(id: TInt32) -> TNone: + raise NotImplementedError("syscall not simulated") class Core: """Core device driver. diff --git a/artiq/test/coredevice/test_exceptions.py b/artiq/test/coredevice/test_exceptions.py new file mode 100644 index 000000000..26d6f8ca3 --- /dev/null +++ b/artiq/test/coredevice/test_exceptions.py @@ -0,0 +1,59 @@ +import unittest +import artiq.coredevice.exceptions as exceptions + +from artiq.experiment import * +from artiq.test.hardware_testbench import ExperimentCase +from artiq.compiler.embedding import EmbeddingMap +from artiq.coredevice.core import raise_exception + +""" +Test sync in exceptions raised between host and kernel +Check artiq.compiler.embedding and artiq.frontend.ksupport.eh_artiq + +Considers the following two cases: + 1) Exception raised on kernel and passed to host + 2) Exception raised in host function called from kernel +Ensures integirty of exceptions is maintained as it passes between kernel and host +""" + +exception_names = EmbeddingMap().str_reverse_map + + +class _TestExceptionSync(EnvExperiment): + def build(self): + self.setattr_device("core") + + @rpc + def _raise_exception_host(self, id): + exn = exception_names[id].split('.')[-1].split(':')[-1] + exn = getattr(exceptions, exn) + raise exn + + @kernel + def raise_exception_host(self, id): + self._raise_exception_host(id) + + @kernel + def raise_exception_kernel(self, id): + raise_exception(id) + + +class ExceptionTest(ExperimentCase): + def test_raise_exceptions_kernel(self): + exp = self.create(_TestExceptionSync) + + for id, name in list(exception_names.items())[::-1]: + name = name.split('.')[-1].split(':')[-1] + with self.assertRaises(getattr(exceptions, name)) as ctx: + exp.raise_exception_kernel(id) + self.assertEqual(str(ctx.exception).strip("'"), name) + + + def test_raise_exceptions_host(self): + exp = self.create(_TestExceptionSync) + + for id, name in exception_names.items(): + name = name.split('.')[-1].split(':')[-1] + with self.assertRaises(getattr(exceptions, name)) as ctx: + exp.raise_exception_host(id) +