forked from M-Labs/artiq
transforms/inline: support attributes on user variables/calls
This commit is contained in:
parent
e22301ea05
commit
2920ac85d2
|
@ -33,62 +33,70 @@ class _HostObjectMapper:
|
|||
_UserVariable = namedtuple("_UserVariable", "name")
|
||||
|
||||
|
||||
def _is_kernel_attr(value, attr):
|
||||
return hasattr(value, "kernel_attr") and attr in value.kernel_attr.split()
|
||||
|
||||
|
||||
class _ReferenceManager:
|
||||
def __init__(self):
|
||||
# (id(obj), func_name, local_name) or (id(obj), kernel_attr_name)
|
||||
# -> _UserVariable(name) / ast / constant_object
|
||||
self.to_inlined = dict()
|
||||
# inlined_name -> use_count
|
||||
self.use_count = dict()
|
||||
self.rpc_mapper = _HostObjectMapper()
|
||||
self.exception_mapper = _HostObjectMapper(core_language.first_user_eid)
|
||||
self.kernel_attr_init = []
|
||||
|
||||
# (id(obj), func_name, ref_name) or (id(obj), kernel_attr_name)
|
||||
# -> _UserVariable(name) / ast / constant_object
|
||||
self._to_inlined = dict()
|
||||
# inlined_name -> use_count
|
||||
self._use_count = dict()
|
||||
# reserved names
|
||||
for kg in core_language.kernel_globals:
|
||||
self.use_count[kg] = 1
|
||||
self._use_count[kg] = 1
|
||||
for name in ("int", "round", "int64", "round64", "float", "array",
|
||||
"range", "Fraction", "Quantity", "EncodedException"):
|
||||
self.use_count[name] = 1
|
||||
self._use_count[name] = 1
|
||||
|
||||
# node_or_value can be a AST node, used to inline function parameter values
|
||||
# that can be simplified later through constant folding.
|
||||
def register_replace(self, obj, func_name, ref_name, node_or_value):
|
||||
self._to_inlined[(id(obj), func_name, ref_name)] = node_or_value
|
||||
|
||||
def new_name(self, base_name):
|
||||
if base_name[-1].isdigit():
|
||||
base_name += "_"
|
||||
if base_name in self.use_count:
|
||||
r = base_name + str(self.use_count[base_name])
|
||||
self.use_count[base_name] += 1
|
||||
if base_name in self._use_count:
|
||||
r = base_name + str(self._use_count[base_name])
|
||||
self._use_count[base_name] += 1
|
||||
return r
|
||||
else:
|
||||
self.use_count[base_name] = 1
|
||||
self._use_count[base_name] = 1
|
||||
return base_name
|
||||
|
||||
def get(self, obj, func_name, ref):
|
||||
if isinstance(ref, ast.Name):
|
||||
key = (id(obj), func_name, ref.id)
|
||||
def resolve_name(self, obj, func_name, ref_name, store):
|
||||
key = (id(obj), func_name, ref_name)
|
||||
try:
|
||||
return self.to_inlined[key]
|
||||
return self._to_inlined[key]
|
||||
except KeyError:
|
||||
if isinstance(ref.ctx, ast.Store):
|
||||
ival = _UserVariable(self.new_name(ref.id))
|
||||
self.to_inlined[key] = ival
|
||||
if store:
|
||||
ival = _UserVariable(self.new_name(ref_name))
|
||||
self._to_inlined[key] = ival
|
||||
return ival
|
||||
else:
|
||||
try:
|
||||
return inspect.getmodule(obj).__dict__[ref.id]
|
||||
return inspect.getmodule(obj).__dict__[ref_name]
|
||||
except KeyError:
|
||||
return getattr(builtins, ref.id)
|
||||
elif isinstance(ref, ast.Attribute):
|
||||
target = self.get(obj, func_name, ref.value)
|
||||
if hasattr(target, "kernel_attr") and ref.attr in target.kernel_attr.split():
|
||||
key = (id(target), ref.attr)
|
||||
return getattr(builtins, ref_name)
|
||||
|
||||
def resolve_attr(self, value, attr):
|
||||
if _is_kernel_attr(value, attr):
|
||||
key = (id(value), attr)
|
||||
try:
|
||||
ival = self.to_inlined[key]
|
||||
ival = self._to_inlined[key]
|
||||
assert(isinstance(ival, _UserVariable))
|
||||
except KeyError:
|
||||
iname = self.new_name(ref.attr)
|
||||
iname = self.new_name(attr)
|
||||
ival = _UserVariable(iname)
|
||||
self.to_inlined[key] = ival
|
||||
a = value_to_ast(getattr(target, ref.attr))
|
||||
self._to_inlined[key] = ival
|
||||
a = value_to_ast(getattr(value, attr))
|
||||
if a is None:
|
||||
raise NotImplementedError(
|
||||
"Cannot represent initial value"
|
||||
|
@ -97,7 +105,19 @@ class _ReferenceManager:
|
|||
[ast.Name(iname, ast.Store())], a))
|
||||
return ival
|
||||
else:
|
||||
return getattr(target, ref.attr)
|
||||
return getattr(value, attr)
|
||||
|
||||
def resolve_constant(self, obj, func_name, node):
|
||||
if isinstance(node, ast.Name):
|
||||
c = self.resolve_name(obj, func_name, node.id, False)
|
||||
if isinstance(c, (_UserVariable, ast.AST)):
|
||||
raise ValueError("Not a constant")
|
||||
return c
|
||||
elif isinstance(node, ast.Attribute):
|
||||
value = self.resolve_constant(obj, func_name, node.value)
|
||||
if _is_kernel_attr(value, node.attr):
|
||||
raise ValueError("Not a constant")
|
||||
return getattr(value, node.attr)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -156,9 +176,9 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
|||
setattr(node, field, new_node)
|
||||
return node
|
||||
|
||||
def visit_ref(self, node):
|
||||
def visit_Name(self, node):
|
||||
store = isinstance(node.ctx, ast.Store)
|
||||
ival = self.rm.get(self.obj, self.func_name, node)
|
||||
ival = self.rm.resolve_name(self.obj, self.func_name, node.id, store)
|
||||
if isinstance(ival, _UserVariable):
|
||||
newnode = ast.Name(ival.name, node.ctx)
|
||||
elif isinstance(ival, ast.AST):
|
||||
|
@ -175,11 +195,34 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
|||
"Cannot represent inlined value")
|
||||
return ast.copy_location(newnode, node)
|
||||
|
||||
visit_Name = visit_ref
|
||||
visit_Attribute = visit_ref
|
||||
def _resolve_attribute(self, node):
|
||||
if isinstance(node, ast.Name):
|
||||
ival = self.rm.resolve_name(self.obj, self.func_name, node.id, False)
|
||||
if isinstance(ival, _UserVariable):
|
||||
return ast.copy_location(ast.Name(ival.name, ast.Load()), node)
|
||||
else:
|
||||
return ival
|
||||
elif isinstance(node, ast.Attribute):
|
||||
value = self._resolve_attribute(node.value)
|
||||
if isinstance(value, ast.AST):
|
||||
node.value = value
|
||||
return node
|
||||
else:
|
||||
return self.rm.resolve_attr(value, node.attr)
|
||||
else:
|
||||
return self.visit(node)
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
ival = self._resolve_attribute(node)
|
||||
if isinstance(ival, ast.AST):
|
||||
return ival
|
||||
elif isinstance(ival, _UserVariable):
|
||||
return ast.copy_location(ast.Name(ival.name, ast.Load()), node)
|
||||
else:
|
||||
return value_to_ast(ival)
|
||||
|
||||
def visit_Call(self, node):
|
||||
func = self.rm.get(self.obj, self.func_name, node.func)
|
||||
func = self.rm.resolve_constant(self.obj, self.func_name, node.func)
|
||||
new_args = [self.visit(arg) for arg in node.args]
|
||||
|
||||
if _is_embeddable(func):
|
||||
|
@ -240,7 +283,7 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
|||
return node
|
||||
|
||||
def _encode_exception(self, e):
|
||||
exception_class = self.rm.get(self.obj, self.func_name, e)
|
||||
exception_class = self.rm.resolve_constant(self.obj, self.func_name, e)
|
||||
if not inspect.isclass(exception_class):
|
||||
raise NotImplementedError("Exception type must be a class")
|
||||
if issubclass(exception_class, core_language.RuntimeException):
|
||||
|
@ -301,9 +344,10 @@ def _initialize_function_params(func_def, k_args, k_kwargs, rm):
|
|||
for arg_ast, arg_value in zip(func_def.args.args, k_args):
|
||||
arg_name = arg_ast.arg
|
||||
if arg_name in rop:
|
||||
rm.to_inlined[(id(obj), func_name, arg_name)] = arg_value
|
||||
rm.register_replace(obj, func_name, arg_name, arg_value)
|
||||
else:
|
||||
target = rm.get(obj, func_name, ast.Name(arg_name, ast.Store()))
|
||||
uservar = rm.resolve_name(obj, func_name, arg_name, True)
|
||||
target = ast.Name(uservar.name, ast.Store())
|
||||
value = value_to_ast(arg_value)
|
||||
param_init.append(ast.Assign(targets=[target], value=value))
|
||||
return param_init
|
||||
|
|
Loading…
Reference in New Issue