From 8aebab580f9bda107644f35bf632ce8eed2259a6 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Wed, 24 Sep 2014 17:16:40 +0800 Subject: [PATCH] transforms/inline: support non-hashable host objects --- artiq/transforms/inline.py | 52 +++++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/artiq/transforms/inline.py b/artiq/transforms/inline.py index 5e7591a0e..1f3a63d9c 100644 --- a/artiq/transforms/inline.py +++ b/artiq/transforms/inline.py @@ -1,4 +1,4 @@ -from collections import namedtuple, defaultdict +from collections import namedtuple from fractions import Fraction import inspect import textwrap @@ -10,6 +10,26 @@ from artiq.language import core as core_language from artiq.language import units +class _HostObjectMapper: + def __init__(self, first_encoding=0): + self._next_encoding = first_encoding + # id(object) -> (encoding, object) + # this format is required to support non-hashable host objects. + self._d = dict() + + def encode(self, obj): + try: + return self._d[id(obj)][0] + except KeyError: + encoding = self._next_encoding + self._d[id(obj)] = (encoding, obj) + self._next_encoding += 1 + return encoding + + def get_map(self): + return {encoding: obj for i, (encoding, obj) in self._d.items()} + + _UserVariable = namedtuple("_UserVariable", "name") @@ -20,8 +40,9 @@ class _ReferenceManager: self.to_inlined = dict() # inlined_name -> use_count self.use_count = dict() - self.rpc_map = defaultdict(lambda: len(self.rpc_map)) - self.exception_map = defaultdict(lambda: len(self.exception_map)) + self.rpc_mapper = _HostObjectMapper() + # exceptions 0-1023 are for runtime + self.exception_mapper = _HostObjectMapper(1024) self.kernel_attr_init = [] # reserved names @@ -83,13 +104,19 @@ class _ReferenceManager: raise NotImplementedError -_embeddable_calls = { +_embeddable_calls = ( core_language.delay, core_language.at, core_language.now, core_language.syscall, range, int, float, round, core_language.int64, core_language.round64, core_language.array, Fraction, units.Quantity, core_language.EncodedException -} +) + +def _is_embeddable(call): + for ec in _embeddable_calls: + if call is ec: + return True + return False class _ReferenceReplacer(ast.NodeVisitor): @@ -156,7 +183,7 @@ class _ReferenceReplacer(ast.NodeVisitor): func = self.rm.get(self.obj, self.func_name, node.func) new_args = [self.visit(arg) for arg in node.args] - if func in _embeddable_calls: + if _is_embeddable(func): new_func = ast.Name(func.__name__, ast.Load()) return ast.copy_location( ast.Call(func=new_func, args=new_args, @@ -177,7 +204,7 @@ class _ReferenceReplacer(ast.NodeVisitor): body=inlined.body)) return ast.copy_location(ast.Name(retval_name, ast.Load()), node) else: - args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])] + args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_mapper.encode(func))] args += new_args return ast.copy_location( ast.Call(func=ast.Name("syscall", ast.Load()), @@ -216,7 +243,7 @@ class _ReferenceReplacer(ast.NodeVisitor): exception_class = self.rm.get(self.obj, self.func_name, node.exc) if not inspect.isclass(exception_class): raise NotImplementedError("Exception must be a class") - exception_id = self.rm.exception_map[exception_class] + exception_id = self.rm.exception_mapper.encode(exception_class) node.exc = ast.copy_location( ast.Call(func=ast.Name("EncodedException", ast.Load()), args=[value_to_ast(exception_id)], @@ -228,7 +255,7 @@ class _ReferenceReplacer(ast.NodeVisitor): exception_class = self.rm.get(self.obj, self.func_name, e) if not inspect.isclass(exception_class): raise NotImplementedError("Exception type must be a class") - exception_id = self.rm.exception_map[exception_class] + exception_id = self.rm.exception_mapper.encode(exception_class) return ast.copy_location( ast.Call(func=ast.Name("EncodedException", ast.Load()), args=[value_to_ast(exception_id)], @@ -302,9 +329,4 @@ def inline(core, k_function, k_args, k_kwargs, rm=None, retval_name=None): if init_kernel_attr: func_def.body[0:0] = rm.kernel_attr_init - r_rpc_map = dict((rpc_num, rpc_fun) - for rpc_fun, rpc_num in rm.rpc_map.items()) - r_exception_map = dict((exception_num, exception_class) - for exception_class, exception_num - in rm.exception_map.items()) - return func_def, r_rpc_map, r_exception_map + return func_def, rm.rpc_mapper.get_map(), rm.exception_mapper.get_map()