forked from M-Labs/artiq
1
0
Fork 0

transforms/inline: support non-hashable host objects

This commit is contained in:
Sebastien Bourdeauducq 2014-09-24 17:16:40 +08:00
parent 82da734e89
commit 8aebab580f
1 changed files with 37 additions and 15 deletions

View File

@ -1,4 +1,4 @@
from collections import namedtuple, defaultdict from collections import namedtuple
from fractions import Fraction from fractions import Fraction
import inspect import inspect
import textwrap import textwrap
@ -10,6 +10,26 @@ from artiq.language import core as core_language
from artiq.language import units from artiq.language import units
class _HostObjectMapper:
def __init__(self, first_encoding=0):
self._next_encoding = first_encoding
# id(object) -> (encoding, object)
# this format is required to support non-hashable host objects.
self._d = dict()
def encode(self, obj):
try:
return self._d[id(obj)][0]
except KeyError:
encoding = self._next_encoding
self._d[id(obj)] = (encoding, obj)
self._next_encoding += 1
return encoding
def get_map(self):
return {encoding: obj for i, (encoding, obj) in self._d.items()}
_UserVariable = namedtuple("_UserVariable", "name") _UserVariable = namedtuple("_UserVariable", "name")
@ -20,8 +40,9 @@ class _ReferenceManager:
self.to_inlined = dict() self.to_inlined = dict()
# inlined_name -> use_count # inlined_name -> use_count
self.use_count = dict() self.use_count = dict()
self.rpc_map = defaultdict(lambda: len(self.rpc_map)) self.rpc_mapper = _HostObjectMapper()
self.exception_map = defaultdict(lambda: len(self.exception_map)) # exceptions 0-1023 are for runtime
self.exception_mapper = _HostObjectMapper(1024)
self.kernel_attr_init = [] self.kernel_attr_init = []
# reserved names # reserved names
@ -83,13 +104,19 @@ class _ReferenceManager:
raise NotImplementedError raise NotImplementedError
_embeddable_calls = { _embeddable_calls = (
core_language.delay, core_language.at, core_language.now, core_language.delay, core_language.at, core_language.now,
core_language.syscall, core_language.syscall,
range, int, float, round, range, int, float, round,
core_language.int64, core_language.round64, core_language.array, core_language.int64, core_language.round64, core_language.array,
Fraction, units.Quantity, core_language.EncodedException Fraction, units.Quantity, core_language.EncodedException
} )
def _is_embeddable(call):
for ec in _embeddable_calls:
if call is ec:
return True
return False
class _ReferenceReplacer(ast.NodeVisitor): class _ReferenceReplacer(ast.NodeVisitor):
@ -156,7 +183,7 @@ class _ReferenceReplacer(ast.NodeVisitor):
func = self.rm.get(self.obj, self.func_name, node.func) func = self.rm.get(self.obj, self.func_name, node.func)
new_args = [self.visit(arg) for arg in node.args] new_args = [self.visit(arg) for arg in node.args]
if func in _embeddable_calls: if _is_embeddable(func):
new_func = ast.Name(func.__name__, ast.Load()) new_func = ast.Name(func.__name__, ast.Load())
return ast.copy_location( return ast.copy_location(
ast.Call(func=new_func, args=new_args, ast.Call(func=new_func, args=new_args,
@ -177,7 +204,7 @@ class _ReferenceReplacer(ast.NodeVisitor):
body=inlined.body)) body=inlined.body))
return ast.copy_location(ast.Name(retval_name, ast.Load()), node) return ast.copy_location(ast.Name(retval_name, ast.Load()), node)
else: else:
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])] args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_mapper.encode(func))]
args += new_args args += new_args
return ast.copy_location( return ast.copy_location(
ast.Call(func=ast.Name("syscall", ast.Load()), ast.Call(func=ast.Name("syscall", ast.Load()),
@ -216,7 +243,7 @@ class _ReferenceReplacer(ast.NodeVisitor):
exception_class = self.rm.get(self.obj, self.func_name, node.exc) exception_class = self.rm.get(self.obj, self.func_name, node.exc)
if not inspect.isclass(exception_class): if not inspect.isclass(exception_class):
raise NotImplementedError("Exception must be a class") raise NotImplementedError("Exception must be a class")
exception_id = self.rm.exception_map[exception_class] exception_id = self.rm.exception_mapper.encode(exception_class)
node.exc = ast.copy_location( node.exc = ast.copy_location(
ast.Call(func=ast.Name("EncodedException", ast.Load()), ast.Call(func=ast.Name("EncodedException", ast.Load()),
args=[value_to_ast(exception_id)], args=[value_to_ast(exception_id)],
@ -228,7 +255,7 @@ class _ReferenceReplacer(ast.NodeVisitor):
exception_class = self.rm.get(self.obj, self.func_name, e) exception_class = self.rm.get(self.obj, self.func_name, e)
if not inspect.isclass(exception_class): if not inspect.isclass(exception_class):
raise NotImplementedError("Exception type must be a class") raise NotImplementedError("Exception type must be a class")
exception_id = self.rm.exception_map[exception_class] exception_id = self.rm.exception_mapper.encode(exception_class)
return ast.copy_location( return ast.copy_location(
ast.Call(func=ast.Name("EncodedException", ast.Load()), ast.Call(func=ast.Name("EncodedException", ast.Load()),
args=[value_to_ast(exception_id)], args=[value_to_ast(exception_id)],
@ -302,9 +329,4 @@ def inline(core, k_function, k_args, k_kwargs, rm=None, retval_name=None):
if init_kernel_attr: if init_kernel_attr:
func_def.body[0:0] = rm.kernel_attr_init func_def.body[0:0] = rm.kernel_attr_init
r_rpc_map = dict((rpc_num, rpc_fun) return func_def, rm.rpc_mapper.get_map(), rm.exception_mapper.get_map()
for rpc_fun, rpc_num in rm.rpc_map.items())
r_exception_map = dict((exception_num, exception_class)
for exception_class, exception_num
in rm.exception_map.items())
return func_def, r_rpc_map, r_exception_map