forked from M-Labs/artiq
transforms/inline: object attribute writeback
This commit is contained in:
parent
f54a2f93d2
commit
e9e12adceb
|
@ -5,6 +5,7 @@ import types
|
||||||
import builtins
|
import builtins
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from artiq.language import core as core_language
|
from artiq.language import core as core_language
|
||||||
from artiq.language import units
|
from artiq.language import units
|
||||||
|
@ -371,6 +372,52 @@ class HostObjectMapper:
|
||||||
return {encoding: obj for i, (encoding, obj) in self._d.items()}
|
return {encoding: obj for i, (encoding, obj) in self._d.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def get_attr_init(attribute_namespace, loc_node):
|
||||||
|
attr_init = []
|
||||||
|
for (_, attr), attr_info in attribute_namespace.items():
|
||||||
|
if hasattr(attr_info.obj, attr):
|
||||||
|
value = getattr(attr_info.obj, attr)
|
||||||
|
value = ast.copy_location(value_to_ast(value), loc_node)
|
||||||
|
target = ast.copy_location(ast.Name(attr_info.mangled_name,
|
||||||
|
ast.Store()),
|
||||||
|
loc_node)
|
||||||
|
assign = ast.copy_location(ast.Assign([target], value),
|
||||||
|
loc_node)
|
||||||
|
attr_init.append(assign)
|
||||||
|
return attr_init
|
||||||
|
|
||||||
|
|
||||||
|
def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
|
||||||
|
attr_writeback = []
|
||||||
|
for (_, attr), attr_info in attribute_namespace.items():
|
||||||
|
if attr_info.read_write:
|
||||||
|
# HACK/FIXME: since RPC of non-int is not supported yet, skip
|
||||||
|
# writeback of other types for now.
|
||||||
|
# This code breaks if an int is promoted to int64
|
||||||
|
if hasattr(attr_info.obj, attr):
|
||||||
|
val = getattr(attr_info.obj, attr)
|
||||||
|
if (not isinstance(val, int)
|
||||||
|
or isinstance(val, core_language.int64)):
|
||||||
|
continue
|
||||||
|
#
|
||||||
|
|
||||||
|
setter = partial(setattr, attr_info.obj, attr)
|
||||||
|
func = ast.copy_location(
|
||||||
|
ast.Name("syscall", ast.Load()), loc_node)
|
||||||
|
arg1 = ast.copy_location(ast.Str("rpc"), loc_node)
|
||||||
|
arg2 = ast.copy_location(
|
||||||
|
value_to_ast(rpc_mapper.encode(setter)), loc_node)
|
||||||
|
arg3 = ast.copy_location(
|
||||||
|
ast.Name(attr_info.mangled_name, ast.Load()), loc_node)
|
||||||
|
call = ast.copy_location(
|
||||||
|
ast.Call(func=func, args=[arg1, arg2, arg3],
|
||||||
|
keywords=[], starargs=None, kwargs=None),
|
||||||
|
loc_node)
|
||||||
|
expr = ast.copy_location(ast.Expr(call), loc_node)
|
||||||
|
attr_writeback.append(expr)
|
||||||
|
return attr_writeback
|
||||||
|
|
||||||
|
|
||||||
def inline(core, k_function, k_args, k_kwargs):
|
def inline(core, k_function, k_args, k_kwargs):
|
||||||
if k_kwargs:
|
if k_kwargs:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
@ -392,16 +439,8 @@ def inline(core, k_function, k_args, k_kwargs):
|
||||||
func=k_function,
|
func=k_function,
|
||||||
args=k_args)
|
args=k_args)
|
||||||
|
|
||||||
param_init = []
|
func_def.body[0:0] = get_attr_init(attribute_namespace, func_def)
|
||||||
for (_, attr), attr_info in attribute_namespace.items():
|
func_def.body += get_attr_writeback(attribute_namespace, mappers.rpc,
|
||||||
value = getattr(attr_info.obj, attr)
|
|
||||||
value = ast.copy_location(value_to_ast(value), func_def)
|
|
||||||
target = ast.copy_location(ast.Name(attr_info.mangled_name,
|
|
||||||
ast.Store()),
|
|
||||||
func_def)
|
func_def)
|
||||||
assign = ast.copy_location(ast.Assign([target], value),
|
|
||||||
func_def)
|
|
||||||
param_init.append(assign)
|
|
||||||
func_def.body[0:0] = param_init
|
|
||||||
|
|
||||||
return func_def, mappers.rpc.get_map(), mappers.exception.get_map()
|
return func_def, mappers.rpc.get_map(), mappers.exception.get_map()
|
||||||
|
|
|
@ -36,6 +36,15 @@ class _Primes(AutoContext):
|
||||||
self.output_list.append(x)
|
self.output_list.append(x)
|
||||||
|
|
||||||
|
|
||||||
|
class _Attributes(AutoContext):
|
||||||
|
def build(self):
|
||||||
|
self.input = 84
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def run(self):
|
||||||
|
self.result = self.input//2
|
||||||
|
|
||||||
|
|
||||||
class _PulseLogger(AutoContext):
|
class _PulseLogger(AutoContext):
|
||||||
parameters = "output_list name"
|
parameters = "output_list name"
|
||||||
|
|
||||||
|
@ -123,13 +132,20 @@ class _Exceptions(AutoContext):
|
||||||
self.trace.append(104)
|
self.trace.append(104)
|
||||||
|
|
||||||
|
|
||||||
class SimCompareCase(unittest.TestCase):
|
class ExecutionCase(unittest.TestCase):
|
||||||
def test_primes(self):
|
def test_primes(self):
|
||||||
l_device, l_host = [], []
|
l_device, l_host = [], []
|
||||||
_run_on_device(_Primes, max=100, output_list=l_device)
|
_run_on_device(_Primes, max=100, output_list=l_device)
|
||||||
_run_on_host(_Primes, max=100, output_list=l_host)
|
_run_on_host(_Primes, max=100, output_list=l_host)
|
||||||
self.assertEqual(l_device, l_host)
|
self.assertEqual(l_device, l_host)
|
||||||
|
|
||||||
|
def test_attributes(self):
|
||||||
|
with comm_serial.Comm() as comm:
|
||||||
|
coredev = core.Core(comm)
|
||||||
|
uut = _Attributes(core=coredev)
|
||||||
|
uut.run()
|
||||||
|
self.assertEqual(uut.result, 42)
|
||||||
|
|
||||||
def test_pulses(self):
|
def test_pulses(self):
|
||||||
l_device, l_host = [], []
|
l_device, l_host = [], []
|
||||||
_run_on_device(_Pulses, output_list=l_device)
|
_run_on_device(_Pulses, output_list=l_device)
|
||||||
|
|
Loading…
Reference in New Issue