diff --git a/artiq/transforms/inline.py b/artiq/transforms/inline.py index d74253838..105e895e0 100644 --- a/artiq/transforms/inline.py +++ b/artiq/transforms/inline.py @@ -5,6 +5,7 @@ import types import builtins from fractions import Fraction from collections import OrderedDict +from functools import partial from artiq.language import core as core_language from artiq.language import units @@ -371,6 +372,52 @@ class HostObjectMapper: return {encoding: obj for i, (encoding, obj) in self._d.items()} +def get_attr_init(attribute_namespace, loc_node): + attr_init = [] + for (_, attr), attr_info in attribute_namespace.items(): + if hasattr(attr_info.obj, attr): + value = getattr(attr_info.obj, attr) + value = ast.copy_location(value_to_ast(value), loc_node) + target = ast.copy_location(ast.Name(attr_info.mangled_name, + ast.Store()), + loc_node) + assign = ast.copy_location(ast.Assign([target], value), + loc_node) + attr_init.append(assign) + return attr_init + + +def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node): + attr_writeback = [] + for (_, attr), attr_info in attribute_namespace.items(): + if attr_info.read_write: + # HACK/FIXME: since RPC of non-int is not supported yet, skip + # writeback of other types for now. + # This code breaks if an int is promoted to int64 + if hasattr(attr_info.obj, attr): + val = getattr(attr_info.obj, attr) + if (not isinstance(val, int) + or isinstance(val, core_language.int64)): + continue + # + + setter = partial(setattr, attr_info.obj, attr) + func = ast.copy_location( + ast.Name("syscall", ast.Load()), loc_node) + arg1 = ast.copy_location(ast.Str("rpc"), loc_node) + arg2 = ast.copy_location( + value_to_ast(rpc_mapper.encode(setter)), loc_node) + arg3 = ast.copy_location( + ast.Name(attr_info.mangled_name, ast.Load()), loc_node) + call = ast.copy_location( + ast.Call(func=func, args=[arg1, arg2, arg3], + keywords=[], starargs=None, kwargs=None), + loc_node) + expr = ast.copy_location(ast.Expr(call), loc_node) + attr_writeback.append(expr) + return attr_writeback + + def inline(core, k_function, k_args, k_kwargs): if k_kwargs: raise NotImplementedError( @@ -392,16 +439,8 @@ def inline(core, k_function, k_args, k_kwargs): func=k_function, args=k_args) - param_init = [] - for (_, attr), attr_info in attribute_namespace.items(): - value = getattr(attr_info.obj, attr) - value = ast.copy_location(value_to_ast(value), func_def) - target = ast.copy_location(ast.Name(attr_info.mangled_name, - ast.Store()), - func_def) - assign = ast.copy_location(ast.Assign([target], value), - func_def) - param_init.append(assign) - func_def.body[0:0] = param_init + func_def.body[0:0] = get_attr_init(attribute_namespace, func_def) + func_def.body += get_attr_writeback(attribute_namespace, mappers.rpc, + func_def) return func_def, mappers.rpc.get_map(), mappers.exception.get_map() diff --git a/test/full_stack.py b/test/full_stack.py index 5f67cd6df..b5038ae87 100644 --- a/test/full_stack.py +++ b/test/full_stack.py @@ -36,6 +36,15 @@ class _Primes(AutoContext): self.output_list.append(x) +class _Attributes(AutoContext): + def build(self): + self.input = 84 + + @kernel + def run(self): + self.result = self.input//2 + + class _PulseLogger(AutoContext): parameters = "output_list name" @@ -123,13 +132,20 @@ class _Exceptions(AutoContext): self.trace.append(104) -class SimCompareCase(unittest.TestCase): +class ExecutionCase(unittest.TestCase): def test_primes(self): l_device, l_host = [], [] _run_on_device(_Primes, max=100, output_list=l_device) _run_on_host(_Primes, max=100, output_list=l_host) self.assertEqual(l_device, l_host) + def test_attributes(self): + with comm_serial.Comm() as comm: + coredev = core.Core(comm) + uut = _Attributes(core=coredev) + uut.run() + self.assertEqual(uut.result, 42) + def test_pulses(self): l_device, l_host = [], [] _run_on_device(_Pulses, output_list=l_device)