diff --git a/nac3artiq/demo.py b/nac3artiq/demo.py index 8e9411f..9410b6f 100644 --- a/nac3artiq/demo.py +++ b/nac3artiq/demo.py @@ -63,4 +63,4 @@ class Demo: if __name__ == "__main__": - Demo().run() + run_on_core(Demo().run) diff --git a/nac3artiq/demo_host_obj.py b/nac3artiq/demo_host_obj.py new file mode 100644 index 0000000..b9916d0 --- /dev/null +++ b/nac3artiq/demo_host_obj.py @@ -0,0 +1,43 @@ +from language import * +from artiq_builtins import * +from numpy import int32, int64 + +test_int = 123 +test_float = 123.456 +test_list = [1] * 1 +test_list2 = [[1.1], [], [3.0]] +test_list_fail = [1, 2, 3.0] + +test_tuple = (1, 2, 3.0) + +@kernel +class Test: + a: int32 + + @kernel + def __init__(self, a: int32): + self.a = a + +test = Test(1) +print(test.a) + +@kernel +class Demo: + @kernel + def run(self): + while True: + delay_mu(round64(test_float * 2.0)) + delay_mu(int64(test_int)) + delay_mu(int64(test_list[0])) + # delay_mu(int64(test_list_fail[0])) + delay_mu(int64(test_tuple[0])) + delay_mu(int64(test_tuple[2])) + delay_mu(int64(test_list2[2][0])) + + delay_mu(int64(test.a)) + test.a = 10 + delay_mu(int64(test.a)) + + +if __name__ == "__main__": + run_on_core(Demo().run) diff --git a/nac3artiq/language.py b/nac3artiq/language.py index 73c79df..c479e75 100644 --- a/nac3artiq/language.py +++ b/nac3artiq/language.py @@ -1,12 +1,13 @@ from inspect import isclass, getmodule from functools import wraps +import sys import nac3artiq import device_db -__all__ = ["extern", "kernel"] +__all__ = ["extern", "kernel", "run_on_core"] nac3 = nac3artiq.NAC3(device_db.device_db["core"]["arguments"]["target"]) @@ -35,14 +36,15 @@ def kernel(class_or_function): nac3.register_module(module) registered_ids.add(module_id) - if isclass(class_or_function): - return class_or_function - else: - @wraps(class_or_function) - def run_on_core(self, *args, **kwargs): - global allow_module_registration - if allow_module_registration: - nac3.analyze() - allow_module_registration = False - nac3.compile_method(id(self.__class__), class_or_function.__name__) - return run_on_core + return class_or_function + +def get_defined_class(method): + return vars(sys.modules[method.__module__])[method.__qualname__.split('.')[0]] + +def run_on_core(method, *args, **kwargs): + global allow_module_registration + if allow_module_registration: + nac3.analyze() + allow_module_registration = False + nac3.compile_method(id(get_defined_class(method)), method.__name__) +