From 37886fcfe3755360c448b1e49c5ae5afe86f2444 Mon Sep 17 00:00:00 2001 From: ram Date: Fri, 10 Jan 2025 06:09:38 +0000 Subject: [PATCH] Add string registration in NAC --- nac3artiq/demo/module.elf | Bin 0 -> 1452 bytes nac3artiq/demo/str_test.py | 20 +++++++++ nac3artiq/demo/test_cache.py | 84 +++++++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+) create mode 100644 nac3artiq/demo/module.elf create mode 100644 nac3artiq/demo/str_test.py create mode 100644 nac3artiq/demo/test_cache.py diff --git a/nac3artiq/demo/module.elf b/nac3artiq/demo/module.elf new file mode 100644 index 0000000000000000000000000000000000000000..21483fee269ff20d20e0f85f9b79e0e5a29d7a53 GIT binary patch literal 1452 zcmah}OKTHR6h61}8mM06o`Q!t8Pk?2Mj!tL$MGzm?nVP=Fx3L1Ri zvf!eEF1qk%2qNe|@CS6ExarOx5WjDn6NjRB;JfGbo%@(`?>xD&cGDOmDkb?T^u7XD zfSr=abY|COM$XHE%x8IoCPFSjo>XK~ukjC#)yh-CTpf5f9*;p!VI&-q6cOrSPrSBE+TtaF zOC5q3y~cgeH@&|!k43K6@x?>qP}IguSa`Cn{~7#-1kbH+s(4@ZyYl?xW#26?;=52$ z-dO~cw+Mr{sw8OT@h&K@_|}xFyi0iv*Ih>eu?(C0p^(Qu@b2(lkjFj}?4JtUOs->O zPsuw#km$qa`N?6giQ5401bMt4#47CKx?KA%z&ft2{~Xq<@V-zU02vSS?B~!kGOzXk oHuu8$ivK9aDCBX@QLf438`5_KHxr%jB|LIe$lFy1ntP)B0{D%hr2qf` literal 0 HcmV?d00001 diff --git a/nac3artiq/demo/str_test.py b/nac3artiq/demo/str_test.py new file mode 100644 index 0000000..30b80a2 --- /dev/null +++ b/nac3artiq/demo/str_test.py @@ -0,0 +1,20 @@ +from min_artiq import * +from numpy import int32 + +@nac3 +class StrTest: + core: KernelInvariant[Core] + + def __init__(self): + self.core = Core() + + @kernel + def test_string(self, s: str): + # Just print the string to verify it works + print_int32(int32(42)) + + def run(self): + self.test_string("hello") + +if __name__ == "__main__": + StrTest().run() diff --git a/nac3artiq/demo/test_cache.py b/nac3artiq/demo/test_cache.py new file mode 100644 index 0000000..75d226c --- /dev/null +++ b/nac3artiq/demo/test_cache.py @@ -0,0 +1,84 @@ +from min_artiq import * +from numpy import int32 + +# Simulate CacheError from ARTIQ +class CacheError(Exception): + pass + +@nac3 +class _Cache: + core: KernelInvariant[Core] + # Properly declare core_cache as a KernelInvariant + core_cache: KernelInvariant[dict[str, list[int32]]] + + def __init__(self): + self.core = Core() + self.core_cache = {} + + def build(self): + # Simplified version of EnvExperiment.build() + # Core and core_cache are already set in __init__ + pass + + @kernel + def get(self, key: str) -> list[int32]: + return self.core_cache.get(key, []) + + @kernel + def put(self, key: str, value: list[int32]): + self.core_cache[key] = value + + @kernel + def get_put(self, key: str, value: list[int32]): + self.get(key) + self.put(key, value) + +class CacheTest: + def create(self, cls): + return cls() + + def assertEqual(self, a, b): + assert a == b, f"Expected {a} to equal {b}" + + def assertRaises(self, exc_type): + class RaiseContext: + def __init__(self, test_case, expected_exc): + self.test_case = test_case + self.expected_exc = expected_exc + def __enter__(self): + return self + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + raise AssertionError(f"Expected {self.expected_exc}") + if not issubclass(exc_type, self.expected_exc): + raise AssertionError(f"Expected {self.expected_exc}, got {exc_type}") + return True + return RaiseContext(self, exc_type) + + def test_get_empty(self): + exp = self.create(_Cache) + self.assertEqual(exp.get("x1"), []) + + def test_put_get(self): + exp = self.create(_Cache) + exp.put("x2", [int32(1), int32(2), int32(3)]) + self.assertEqual(exp.get("x2"), [int32(1), int32(2), int32(3)]) + + def test_replace(self): + exp = self.create(_Cache) + exp.put("x3", [int32(1), int32(2), int32(3)]) + exp.put("x3", [int32(1), int32(2), int32(3), int32(4), int32(5)]) + self.assertEqual(exp.get("x3"), [int32(1), int32(2), int32(3), int32(4), int32(5)]) + + def test_borrow(self): + exp = self.create(_Cache) + exp.put("x4", [int32(1), int32(2), int32(3)]) + with self.assertRaises(CacheError): + exp.get_put("x4", []) + +if __name__ == "__main__": + test = CacheTest() + test.test_get_empty() + test.test_put_get() + test.test_replace() + test.test_borrow()