transforms/inline: object attribute writeback

This commit is contained in:
Sebastien Bourdeauducq 2014-11-03 18:04:01 +08:00
parent f54a2f93d2
commit e9e12adceb
2 changed files with 67 additions and 12 deletions

View File

@ -5,6 +5,7 @@ import types
import builtins import builtins
from fractions import Fraction from fractions import Fraction
from collections import OrderedDict from collections import OrderedDict
from functools import partial
from artiq.language import core as core_language from artiq.language import core as core_language
from artiq.language import units from artiq.language import units
@ -371,6 +372,52 @@ class HostObjectMapper:
return {encoding: obj for i, (encoding, obj) in self._d.items()} return {encoding: obj for i, (encoding, obj) in self._d.items()}
def get_attr_init(attribute_namespace, loc_node):
attr_init = []
for (_, attr), attr_info in attribute_namespace.items():
if hasattr(attr_info.obj, attr):
value = getattr(attr_info.obj, attr)
value = ast.copy_location(value_to_ast(value), loc_node)
target = ast.copy_location(ast.Name(attr_info.mangled_name,
ast.Store()),
loc_node)
assign = ast.copy_location(ast.Assign([target], value),
loc_node)
attr_init.append(assign)
return attr_init
def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
attr_writeback = []
for (_, attr), attr_info in attribute_namespace.items():
if attr_info.read_write:
# HACK/FIXME: since RPC of non-int is not supported yet, skip
# writeback of other types for now.
# This code breaks if an int is promoted to int64
if hasattr(attr_info.obj, attr):
val = getattr(attr_info.obj, attr)
if (not isinstance(val, int)
or isinstance(val, core_language.int64)):
continue
#
setter = partial(setattr, attr_info.obj, attr)
func = ast.copy_location(
ast.Name("syscall", ast.Load()), loc_node)
arg1 = ast.copy_location(ast.Str("rpc"), loc_node)
arg2 = ast.copy_location(
value_to_ast(rpc_mapper.encode(setter)), loc_node)
arg3 = ast.copy_location(
ast.Name(attr_info.mangled_name, ast.Load()), loc_node)
call = ast.copy_location(
ast.Call(func=func, args=[arg1, arg2, arg3],
keywords=[], starargs=None, kwargs=None),
loc_node)
expr = ast.copy_location(ast.Expr(call), loc_node)
attr_writeback.append(expr)
return attr_writeback
def inline(core, k_function, k_args, k_kwargs): def inline(core, k_function, k_args, k_kwargs):
if k_kwargs: if k_kwargs:
raise NotImplementedError( raise NotImplementedError(
@ -392,16 +439,8 @@ def inline(core, k_function, k_args, k_kwargs):
func=k_function, func=k_function,
args=k_args) args=k_args)
param_init = [] func_def.body[0:0] = get_attr_init(attribute_namespace, func_def)
for (_, attr), attr_info in attribute_namespace.items(): func_def.body += get_attr_writeback(attribute_namespace, mappers.rpc,
value = getattr(attr_info.obj, attr)
value = ast.copy_location(value_to_ast(value), func_def)
target = ast.copy_location(ast.Name(attr_info.mangled_name,
ast.Store()),
func_def) func_def)
assign = ast.copy_location(ast.Assign([target], value),
func_def)
param_init.append(assign)
func_def.body[0:0] = param_init
return func_def, mappers.rpc.get_map(), mappers.exception.get_map() return func_def, mappers.rpc.get_map(), mappers.exception.get_map()

View File

@ -36,6 +36,15 @@ class _Primes(AutoContext):
self.output_list.append(x) self.output_list.append(x)
class _Attributes(AutoContext):
def build(self):
self.input = 84
@kernel
def run(self):
self.result = self.input//2
class _PulseLogger(AutoContext): class _PulseLogger(AutoContext):
parameters = "output_list name" parameters = "output_list name"
@ -123,13 +132,20 @@ class _Exceptions(AutoContext):
self.trace.append(104) self.trace.append(104)
class SimCompareCase(unittest.TestCase): class ExecutionCase(unittest.TestCase):
def test_primes(self): def test_primes(self):
l_device, l_host = [], [] l_device, l_host = [], []
_run_on_device(_Primes, max=100, output_list=l_device) _run_on_device(_Primes, max=100, output_list=l_device)
_run_on_host(_Primes, max=100, output_list=l_host) _run_on_host(_Primes, max=100, output_list=l_host)
self.assertEqual(l_device, l_host) self.assertEqual(l_device, l_host)
def test_attributes(self):
with comm_serial.Comm() as comm:
coredev = core.Core(comm)
uut = _Attributes(core=coredev)
uut.run()
self.assertEqual(uut.result, 42)
def test_pulses(self): def test_pulses(self):
l_device, l_host = [], [] l_device, l_host = [], []
_run_on_device(_Pulses, output_list=l_device) _run_on_device(_Pulses, output_list=l_device)