forked from M-Labs/artiq
transforms/inline: object attribute writeback
This commit is contained in:
parent
f54a2f93d2
commit
e9e12adceb
@ -5,6 +5,7 @@ import types
|
||||
import builtins
|
||||
from fractions import Fraction
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
|
||||
from artiq.language import core as core_language
|
||||
from artiq.language import units
|
||||
@ -371,6 +372,52 @@ class HostObjectMapper:
|
||||
return {encoding: obj for i, (encoding, obj) in self._d.items()}
|
||||
|
||||
|
||||
def get_attr_init(attribute_namespace, loc_node):
|
||||
attr_init = []
|
||||
for (_, attr), attr_info in attribute_namespace.items():
|
||||
if hasattr(attr_info.obj, attr):
|
||||
value = getattr(attr_info.obj, attr)
|
||||
value = ast.copy_location(value_to_ast(value), loc_node)
|
||||
target = ast.copy_location(ast.Name(attr_info.mangled_name,
|
||||
ast.Store()),
|
||||
loc_node)
|
||||
assign = ast.copy_location(ast.Assign([target], value),
|
||||
loc_node)
|
||||
attr_init.append(assign)
|
||||
return attr_init
|
||||
|
||||
|
||||
def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
|
||||
attr_writeback = []
|
||||
for (_, attr), attr_info in attribute_namespace.items():
|
||||
if attr_info.read_write:
|
||||
# HACK/FIXME: since RPC of non-int is not supported yet, skip
|
||||
# writeback of other types for now.
|
||||
# This code breaks if an int is promoted to int64
|
||||
if hasattr(attr_info.obj, attr):
|
||||
val = getattr(attr_info.obj, attr)
|
||||
if (not isinstance(val, int)
|
||||
or isinstance(val, core_language.int64)):
|
||||
continue
|
||||
#
|
||||
|
||||
setter = partial(setattr, attr_info.obj, attr)
|
||||
func = ast.copy_location(
|
||||
ast.Name("syscall", ast.Load()), loc_node)
|
||||
arg1 = ast.copy_location(ast.Str("rpc"), loc_node)
|
||||
arg2 = ast.copy_location(
|
||||
value_to_ast(rpc_mapper.encode(setter)), loc_node)
|
||||
arg3 = ast.copy_location(
|
||||
ast.Name(attr_info.mangled_name, ast.Load()), loc_node)
|
||||
call = ast.copy_location(
|
||||
ast.Call(func=func, args=[arg1, arg2, arg3],
|
||||
keywords=[], starargs=None, kwargs=None),
|
||||
loc_node)
|
||||
expr = ast.copy_location(ast.Expr(call), loc_node)
|
||||
attr_writeback.append(expr)
|
||||
return attr_writeback
|
||||
|
||||
|
||||
def inline(core, k_function, k_args, k_kwargs):
|
||||
if k_kwargs:
|
||||
raise NotImplementedError(
|
||||
@ -392,16 +439,8 @@ def inline(core, k_function, k_args, k_kwargs):
|
||||
func=k_function,
|
||||
args=k_args)
|
||||
|
||||
param_init = []
|
||||
for (_, attr), attr_info in attribute_namespace.items():
|
||||
value = getattr(attr_info.obj, attr)
|
||||
value = ast.copy_location(value_to_ast(value), func_def)
|
||||
target = ast.copy_location(ast.Name(attr_info.mangled_name,
|
||||
ast.Store()),
|
||||
func_def)
|
||||
assign = ast.copy_location(ast.Assign([target], value),
|
||||
func_def)
|
||||
param_init.append(assign)
|
||||
func_def.body[0:0] = param_init
|
||||
func_def.body[0:0] = get_attr_init(attribute_namespace, func_def)
|
||||
func_def.body += get_attr_writeback(attribute_namespace, mappers.rpc,
|
||||
func_def)
|
||||
|
||||
return func_def, mappers.rpc.get_map(), mappers.exception.get_map()
|
||||
|
@ -36,6 +36,15 @@ class _Primes(AutoContext):
|
||||
self.output_list.append(x)
|
||||
|
||||
|
||||
class _Attributes(AutoContext):
|
||||
def build(self):
|
||||
self.input = 84
|
||||
|
||||
@kernel
|
||||
def run(self):
|
||||
self.result = self.input//2
|
||||
|
||||
|
||||
class _PulseLogger(AutoContext):
|
||||
parameters = "output_list name"
|
||||
|
||||
@ -123,13 +132,20 @@ class _Exceptions(AutoContext):
|
||||
self.trace.append(104)
|
||||
|
||||
|
||||
class SimCompareCase(unittest.TestCase):
|
||||
class ExecutionCase(unittest.TestCase):
|
||||
def test_primes(self):
|
||||
l_device, l_host = [], []
|
||||
_run_on_device(_Primes, max=100, output_list=l_device)
|
||||
_run_on_host(_Primes, max=100, output_list=l_host)
|
||||
self.assertEqual(l_device, l_host)
|
||||
|
||||
def test_attributes(self):
|
||||
with comm_serial.Comm() as comm:
|
||||
coredev = core.Core(comm)
|
||||
uut = _Attributes(core=coredev)
|
||||
uut.run()
|
||||
self.assertEqual(uut.result, 42)
|
||||
|
||||
def test_pulses(self):
|
||||
l_device, l_host = [], []
|
||||
_run_on_device(_Pulses, output_list=l_device)
|
||||
|
Loading…
Reference in New Issue
Block a user