1
0
forked from M-Labs/nac3

Add test case for passing strings into kernel invokation breaks

This commit is contained in:
ram 2025-01-10 08:14:42 +00:00
parent 29130d3ef4
commit b4b3980ffb
2 changed files with 8 additions and 72 deletions

Binary file not shown.

View File

@ -1,84 +1,20 @@
from min_artiq import * from min_artiq import *
from numpy import int32 from numpy import int32
# Simulate CacheError from ARTIQ
class CacheError(Exception):
pass
@nac3 @nac3
class _Cache: class StringListTest:
core: KernelInvariant[Core] core: KernelInvariant[Core]
# Properly declare core_cache as a KernelInvariant
core_cache: KernelInvariant[dict[str, list[int32]]]
def __init__(self): def __init__(self):
self.core = Core() self.core = Core()
self.core_cache = {}
def build(self): @kernel
# Simplified version of EnvExperiment.build() def test_string_and_list(self, key: str, value: list[int32]):
# Core and core_cache are already set in __init__ print_int32(int32(42))
pass
@kernel def test_params():
def get(self, key: str) -> list[int32]: exp = StringListTest()
return self.core_cache.get(key, []) exp.test_string_and_list("x4", [])
@kernel
def put(self, key: str, value: list[int32]):
self.core_cache[key] = value
@kernel
def get_put(self, key: str, value: list[int32]):
self.get(key)
self.put(key, value)
class CacheTest:
def create(self, cls):
return cls()
def assertEqual(self, a, b):
assert a == b, f"Expected {a} to equal {b}"
def assertRaises(self, exc_type):
class RaiseContext:
def __init__(self, test_case, expected_exc):
self.test_case = test_case
self.expected_exc = expected_exc
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
raise AssertionError(f"Expected {self.expected_exc}")
if not issubclass(exc_type, self.expected_exc):
raise AssertionError(f"Expected {self.expected_exc}, got {exc_type}")
return True
return RaiseContext(self, exc_type)
def test_get_empty(self):
exp = self.create(_Cache)
self.assertEqual(exp.get("x1"), [])
def test_put_get(self):
exp = self.create(_Cache)
exp.put("x2", [int32(1), int32(2), int32(3)])
self.assertEqual(exp.get("x2"), [int32(1), int32(2), int32(3)])
def test_replace(self):
exp = self.create(_Cache)
exp.put("x3", [int32(1), int32(2), int32(3)])
exp.put("x3", [int32(1), int32(2), int32(3), int32(4), int32(5)])
self.assertEqual(exp.get("x3"), [int32(1), int32(2), int32(3), int32(4), int32(5)])
def test_borrow(self):
exp = self.create(_Cache)
exp.put("x4", [int32(1), int32(2), int32(3)])
with self.assertRaises(CacheError):
exp.get_put("x4", [])
if __name__ == "__main__": if __name__ == "__main__":
test = CacheTest() test_params()
test.test_get_empty()
test.test_put_get()
test.test_replace()
test.test_borrow()