forked from M-Labs/artiq
transforms/inline: support non-hashable host objects
This commit is contained in:
parent
82da734e89
commit
8aebab580f
@ -1,4 +1,4 @@
|
||||
from collections import namedtuple, defaultdict
|
||||
from collections import namedtuple
|
||||
from fractions import Fraction
|
||||
import inspect
|
||||
import textwrap
|
||||
@ -10,6 +10,26 @@ from artiq.language import core as core_language
|
||||
from artiq.language import units
|
||||
|
||||
|
||||
class _HostObjectMapper:
|
||||
def __init__(self, first_encoding=0):
|
||||
self._next_encoding = first_encoding
|
||||
# id(object) -> (encoding, object)
|
||||
# this format is required to support non-hashable host objects.
|
||||
self._d = dict()
|
||||
|
||||
def encode(self, obj):
|
||||
try:
|
||||
return self._d[id(obj)][0]
|
||||
except KeyError:
|
||||
encoding = self._next_encoding
|
||||
self._d[id(obj)] = (encoding, obj)
|
||||
self._next_encoding += 1
|
||||
return encoding
|
||||
|
||||
def get_map(self):
|
||||
return {encoding: obj for i, (encoding, obj) in self._d.items()}
|
||||
|
||||
|
||||
_UserVariable = namedtuple("_UserVariable", "name")
|
||||
|
||||
|
||||
@ -20,8 +40,9 @@ class _ReferenceManager:
|
||||
self.to_inlined = dict()
|
||||
# inlined_name -> use_count
|
||||
self.use_count = dict()
|
||||
self.rpc_map = defaultdict(lambda: len(self.rpc_map))
|
||||
self.exception_map = defaultdict(lambda: len(self.exception_map))
|
||||
self.rpc_mapper = _HostObjectMapper()
|
||||
# exceptions 0-1023 are for runtime
|
||||
self.exception_mapper = _HostObjectMapper(1024)
|
||||
self.kernel_attr_init = []
|
||||
|
||||
# reserved names
|
||||
@ -83,13 +104,19 @@ class _ReferenceManager:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
_embeddable_calls = {
|
||||
_embeddable_calls = (
|
||||
core_language.delay, core_language.at, core_language.now,
|
||||
core_language.syscall,
|
||||
range, int, float, round,
|
||||
core_language.int64, core_language.round64, core_language.array,
|
||||
Fraction, units.Quantity, core_language.EncodedException
|
||||
}
|
||||
)
|
||||
|
||||
def _is_embeddable(call):
|
||||
for ec in _embeddable_calls:
|
||||
if call is ec:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class _ReferenceReplacer(ast.NodeVisitor):
|
||||
@ -156,7 +183,7 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
||||
func = self.rm.get(self.obj, self.func_name, node.func)
|
||||
new_args = [self.visit(arg) for arg in node.args]
|
||||
|
||||
if func in _embeddable_calls:
|
||||
if _is_embeddable(func):
|
||||
new_func = ast.Name(func.__name__, ast.Load())
|
||||
return ast.copy_location(
|
||||
ast.Call(func=new_func, args=new_args,
|
||||
@ -177,7 +204,7 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
||||
body=inlined.body))
|
||||
return ast.copy_location(ast.Name(retval_name, ast.Load()), node)
|
||||
else:
|
||||
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])]
|
||||
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_mapper.encode(func))]
|
||||
args += new_args
|
||||
return ast.copy_location(
|
||||
ast.Call(func=ast.Name("syscall", ast.Load()),
|
||||
@ -216,7 +243,7 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
||||
exception_class = self.rm.get(self.obj, self.func_name, node.exc)
|
||||
if not inspect.isclass(exception_class):
|
||||
raise NotImplementedError("Exception must be a class")
|
||||
exception_id = self.rm.exception_map[exception_class]
|
||||
exception_id = self.rm.exception_mapper.encode(exception_class)
|
||||
node.exc = ast.copy_location(
|
||||
ast.Call(func=ast.Name("EncodedException", ast.Load()),
|
||||
args=[value_to_ast(exception_id)],
|
||||
@ -228,7 +255,7 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
||||
exception_class = self.rm.get(self.obj, self.func_name, e)
|
||||
if not inspect.isclass(exception_class):
|
||||
raise NotImplementedError("Exception type must be a class")
|
||||
exception_id = self.rm.exception_map[exception_class]
|
||||
exception_id = self.rm.exception_mapper.encode(exception_class)
|
||||
return ast.copy_location(
|
||||
ast.Call(func=ast.Name("EncodedException", ast.Load()),
|
||||
args=[value_to_ast(exception_id)],
|
||||
@ -302,9 +329,4 @@ def inline(core, k_function, k_args, k_kwargs, rm=None, retval_name=None):
|
||||
if init_kernel_attr:
|
||||
func_def.body[0:0] = rm.kernel_attr_init
|
||||
|
||||
r_rpc_map = dict((rpc_num, rpc_fun)
|
||||
for rpc_fun, rpc_num in rm.rpc_map.items())
|
||||
r_exception_map = dict((exception_num, exception_class)
|
||||
for exception_class, exception_num
|
||||
in rm.exception_map.items())
|
||||
return func_def, r_rpc_map, r_exception_map
|
||||
return func_def, rm.rpc_mapper.get_map(), rm.exception_mapper.get_map()
|
||||
|
Loading…
Reference in New Issue
Block a user