forked from M-Labs/artiq
transforms/inline: support non-hashable host objects
This commit is contained in:
parent
82da734e89
commit
8aebab580f
|
@ -1,4 +1,4 @@
|
||||||
from collections import namedtuple, defaultdict
|
from collections import namedtuple
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
import inspect
|
import inspect
|
||||||
import textwrap
|
import textwrap
|
||||||
|
@ -10,6 +10,26 @@ from artiq.language import core as core_language
|
||||||
from artiq.language import units
|
from artiq.language import units
|
||||||
|
|
||||||
|
|
||||||
|
class _HostObjectMapper:
|
||||||
|
def __init__(self, first_encoding=0):
|
||||||
|
self._next_encoding = first_encoding
|
||||||
|
# id(object) -> (encoding, object)
|
||||||
|
# this format is required to support non-hashable host objects.
|
||||||
|
self._d = dict()
|
||||||
|
|
||||||
|
def encode(self, obj):
|
||||||
|
try:
|
||||||
|
return self._d[id(obj)][0]
|
||||||
|
except KeyError:
|
||||||
|
encoding = self._next_encoding
|
||||||
|
self._d[id(obj)] = (encoding, obj)
|
||||||
|
self._next_encoding += 1
|
||||||
|
return encoding
|
||||||
|
|
||||||
|
def get_map(self):
|
||||||
|
return {encoding: obj for i, (encoding, obj) in self._d.items()}
|
||||||
|
|
||||||
|
|
||||||
_UserVariable = namedtuple("_UserVariable", "name")
|
_UserVariable = namedtuple("_UserVariable", "name")
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,8 +40,9 @@ class _ReferenceManager:
|
||||||
self.to_inlined = dict()
|
self.to_inlined = dict()
|
||||||
# inlined_name -> use_count
|
# inlined_name -> use_count
|
||||||
self.use_count = dict()
|
self.use_count = dict()
|
||||||
self.rpc_map = defaultdict(lambda: len(self.rpc_map))
|
self.rpc_mapper = _HostObjectMapper()
|
||||||
self.exception_map = defaultdict(lambda: len(self.exception_map))
|
# exceptions 0-1023 are for runtime
|
||||||
|
self.exception_mapper = _HostObjectMapper(1024)
|
||||||
self.kernel_attr_init = []
|
self.kernel_attr_init = []
|
||||||
|
|
||||||
# reserved names
|
# reserved names
|
||||||
|
@ -83,13 +104,19 @@ class _ReferenceManager:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
_embeddable_calls = {
|
_embeddable_calls = (
|
||||||
core_language.delay, core_language.at, core_language.now,
|
core_language.delay, core_language.at, core_language.now,
|
||||||
core_language.syscall,
|
core_language.syscall,
|
||||||
range, int, float, round,
|
range, int, float, round,
|
||||||
core_language.int64, core_language.round64, core_language.array,
|
core_language.int64, core_language.round64, core_language.array,
|
||||||
Fraction, units.Quantity, core_language.EncodedException
|
Fraction, units.Quantity, core_language.EncodedException
|
||||||
}
|
)
|
||||||
|
|
||||||
|
def _is_embeddable(call):
|
||||||
|
for ec in _embeddable_calls:
|
||||||
|
if call is ec:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class _ReferenceReplacer(ast.NodeVisitor):
|
class _ReferenceReplacer(ast.NodeVisitor):
|
||||||
|
@ -156,7 +183,7 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
||||||
func = self.rm.get(self.obj, self.func_name, node.func)
|
func = self.rm.get(self.obj, self.func_name, node.func)
|
||||||
new_args = [self.visit(arg) for arg in node.args]
|
new_args = [self.visit(arg) for arg in node.args]
|
||||||
|
|
||||||
if func in _embeddable_calls:
|
if _is_embeddable(func):
|
||||||
new_func = ast.Name(func.__name__, ast.Load())
|
new_func = ast.Name(func.__name__, ast.Load())
|
||||||
return ast.copy_location(
|
return ast.copy_location(
|
||||||
ast.Call(func=new_func, args=new_args,
|
ast.Call(func=new_func, args=new_args,
|
||||||
|
@ -177,7 +204,7 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
||||||
body=inlined.body))
|
body=inlined.body))
|
||||||
return ast.copy_location(ast.Name(retval_name, ast.Load()), node)
|
return ast.copy_location(ast.Name(retval_name, ast.Load()), node)
|
||||||
else:
|
else:
|
||||||
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])]
|
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_mapper.encode(func))]
|
||||||
args += new_args
|
args += new_args
|
||||||
return ast.copy_location(
|
return ast.copy_location(
|
||||||
ast.Call(func=ast.Name("syscall", ast.Load()),
|
ast.Call(func=ast.Name("syscall", ast.Load()),
|
||||||
|
@ -216,7 +243,7 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
||||||
exception_class = self.rm.get(self.obj, self.func_name, node.exc)
|
exception_class = self.rm.get(self.obj, self.func_name, node.exc)
|
||||||
if not inspect.isclass(exception_class):
|
if not inspect.isclass(exception_class):
|
||||||
raise NotImplementedError("Exception must be a class")
|
raise NotImplementedError("Exception must be a class")
|
||||||
exception_id = self.rm.exception_map[exception_class]
|
exception_id = self.rm.exception_mapper.encode(exception_class)
|
||||||
node.exc = ast.copy_location(
|
node.exc = ast.copy_location(
|
||||||
ast.Call(func=ast.Name("EncodedException", ast.Load()),
|
ast.Call(func=ast.Name("EncodedException", ast.Load()),
|
||||||
args=[value_to_ast(exception_id)],
|
args=[value_to_ast(exception_id)],
|
||||||
|
@ -228,7 +255,7 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
||||||
exception_class = self.rm.get(self.obj, self.func_name, e)
|
exception_class = self.rm.get(self.obj, self.func_name, e)
|
||||||
if not inspect.isclass(exception_class):
|
if not inspect.isclass(exception_class):
|
||||||
raise NotImplementedError("Exception type must be a class")
|
raise NotImplementedError("Exception type must be a class")
|
||||||
exception_id = self.rm.exception_map[exception_class]
|
exception_id = self.rm.exception_mapper.encode(exception_class)
|
||||||
return ast.copy_location(
|
return ast.copy_location(
|
||||||
ast.Call(func=ast.Name("EncodedException", ast.Load()),
|
ast.Call(func=ast.Name("EncodedException", ast.Load()),
|
||||||
args=[value_to_ast(exception_id)],
|
args=[value_to_ast(exception_id)],
|
||||||
|
@ -302,9 +329,4 @@ def inline(core, k_function, k_args, k_kwargs, rm=None, retval_name=None):
|
||||||
if init_kernel_attr:
|
if init_kernel_attr:
|
||||||
func_def.body[0:0] = rm.kernel_attr_init
|
func_def.body[0:0] = rm.kernel_attr_init
|
||||||
|
|
||||||
r_rpc_map = dict((rpc_num, rpc_fun)
|
return func_def, rm.rpc_mapper.get_map(), rm.exception_mapper.get_map()
|
||||||
for rpc_fun, rpc_num in rm.rpc_map.items())
|
|
||||||
r_exception_map = dict((exception_num, exception_class)
|
|
||||||
for exception_class, exception_num
|
|
||||||
in rm.exception_map.items())
|
|
||||||
return func_def, r_rpc_map, r_exception_map
|
|
||||||
|
|
Loading…
Reference in New Issue