2
0
mirror of https://github.com/m-labs/artiq.git synced 2024-12-25 03:08:27 +08:00

transforms/inline: object attribute writeback

This commit is contained in:
Sebastien Bourdeauducq 2014-11-03 18:04:01 +08:00
parent f54a2f93d2
commit e9e12adceb
2 changed files with 67 additions and 12 deletions

View File

@ -5,6 +5,7 @@ import types
import builtins
from fractions import Fraction
from collections import OrderedDict
from functools import partial
from artiq.language import core as core_language
from artiq.language import units
@ -371,6 +372,52 @@ class HostObjectMapper:
return {encoding: obj for i, (encoding, obj) in self._d.items()}
def get_attr_init(attribute_namespace, loc_node):
attr_init = []
for (_, attr), attr_info in attribute_namespace.items():
if hasattr(attr_info.obj, attr):
value = getattr(attr_info.obj, attr)
value = ast.copy_location(value_to_ast(value), loc_node)
target = ast.copy_location(ast.Name(attr_info.mangled_name,
ast.Store()),
loc_node)
assign = ast.copy_location(ast.Assign([target], value),
loc_node)
attr_init.append(assign)
return attr_init
def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
attr_writeback = []
for (_, attr), attr_info in attribute_namespace.items():
if attr_info.read_write:
# HACK/FIXME: since RPC of non-int is not supported yet, skip
# writeback of other types for now.
# This code breaks if an int is promoted to int64
if hasattr(attr_info.obj, attr):
val = getattr(attr_info.obj, attr)
if (not isinstance(val, int)
or isinstance(val, core_language.int64)):
continue
#
setter = partial(setattr, attr_info.obj, attr)
func = ast.copy_location(
ast.Name("syscall", ast.Load()), loc_node)
arg1 = ast.copy_location(ast.Str("rpc"), loc_node)
arg2 = ast.copy_location(
value_to_ast(rpc_mapper.encode(setter)), loc_node)
arg3 = ast.copy_location(
ast.Name(attr_info.mangled_name, ast.Load()), loc_node)
call = ast.copy_location(
ast.Call(func=func, args=[arg1, arg2, arg3],
keywords=[], starargs=None, kwargs=None),
loc_node)
expr = ast.copy_location(ast.Expr(call), loc_node)
attr_writeback.append(expr)
return attr_writeback
def inline(core, k_function, k_args, k_kwargs):
if k_kwargs:
raise NotImplementedError(
@ -392,16 +439,8 @@ def inline(core, k_function, k_args, k_kwargs):
func=k_function,
args=k_args)
param_init = []
for (_, attr), attr_info in attribute_namespace.items():
value = getattr(attr_info.obj, attr)
value = ast.copy_location(value_to_ast(value), func_def)
target = ast.copy_location(ast.Name(attr_info.mangled_name,
ast.Store()),
func_def)
assign = ast.copy_location(ast.Assign([target], value),
func_def)
param_init.append(assign)
func_def.body[0:0] = param_init
func_def.body[0:0] = get_attr_init(attribute_namespace, func_def)
func_def.body += get_attr_writeback(attribute_namespace, mappers.rpc,
func_def)
return func_def, mappers.rpc.get_map(), mappers.exception.get_map()

View File

@ -36,6 +36,15 @@ class _Primes(AutoContext):
self.output_list.append(x)
class _Attributes(AutoContext):
def build(self):
self.input = 84
@kernel
def run(self):
self.result = self.input//2
class _PulseLogger(AutoContext):
parameters = "output_list name"
@ -123,13 +132,20 @@ class _Exceptions(AutoContext):
self.trace.append(104)
class SimCompareCase(unittest.TestCase):
class ExecutionCase(unittest.TestCase):
def test_primes(self):
l_device, l_host = [], []
_run_on_device(_Primes, max=100, output_list=l_device)
_run_on_host(_Primes, max=100, output_list=l_host)
self.assertEqual(l_device, l_host)
def test_attributes(self):
with comm_serial.Comm() as comm:
coredev = core.Core(comm)
uut = _Attributes(core=coredev)
uut.run()
self.assertEqual(uut.result, 42)
def test_pulses(self):
l_device, l_host = [], []
_run_on_device(_Pulses, output_list=l_device)