coredevice/test: add unittests for exceptions

This commit is contained in:
abdul124 2024-08-12 13:01:23 +08:00 committed by Sébastien Bourdeauducq
parent de6f83b009
commit d389f8e25b
2 changed files with 62 additions and 0 deletions

View File

@ -52,6 +52,9 @@ def rtio_get_destination_status(linkno: TInt32) -> TBool:
def rtio_get_counter() -> TInt64: def rtio_get_counter() -> TInt64:
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")
@syscall
def raise_exception(id: TInt32) -> TNone:
raise NotImplementedError("syscall not simulated")
class Core: class Core:
"""Core device driver. """Core device driver.

View File

@ -0,0 +1,59 @@
import unittest
import artiq.coredevice.exceptions as exceptions
from artiq.experiment import *
from artiq.test.hardware_testbench import ExperimentCase
from artiq.compiler.embedding import EmbeddingMap
from artiq.coredevice.core import raise_exception
"""
Test sync in exceptions raised between host and kernel
Check artiq.compiler.embedding and artiq.frontend.ksupport.eh_artiq
Considers the following two cases:
1) Exception raised on kernel and passed to host
2) Exception raised in host function called from kernel
Ensures integirty of exceptions is maintained as it passes between kernel and host
"""
exception_names = EmbeddingMap().str_reverse_map
class _TestExceptionSync(EnvExperiment):
def build(self):
self.setattr_device("core")
@rpc
def _raise_exception_host(self, id):
exn = exception_names[id].split('.')[-1].split(':')[-1]
exn = getattr(exceptions, exn)
raise exn
@kernel
def raise_exception_host(self, id):
self._raise_exception_host(id)
@kernel
def raise_exception_kernel(self, id):
raise_exception(id)
class ExceptionTest(ExperimentCase):
def test_raise_exceptions_kernel(self):
exp = self.create(_TestExceptionSync)
for id, name in list(exception_names.items())[::-1]:
name = name.split('.')[-1].split(':')[-1]
with self.assertRaises(getattr(exceptions, name)) as ctx:
exp.raise_exception_kernel(id)
self.assertEqual(str(ctx.exception).strip("'"), name)
def test_raise_exceptions_host(self):
exp = self.create(_TestExceptionSync)
for id, name in exception_names.items():
name = name.split('.')[-1].split(':')[-1]
with self.assertRaises(getattr(exceptions, name)) as ctx:
exp.raise_exception_host(id)