forked from M-Labs/artiq
transforms/inline: support attributes on user variables/calls
This commit is contained in:
parent
e22301ea05
commit
2920ac85d2
|
@ -33,62 +33,70 @@ class _HostObjectMapper:
|
||||||
_UserVariable = namedtuple("_UserVariable", "name")
|
_UserVariable = namedtuple("_UserVariable", "name")
|
||||||
|
|
||||||
|
|
||||||
|
def _is_kernel_attr(value, attr):
|
||||||
|
return hasattr(value, "kernel_attr") and attr in value.kernel_attr.split()
|
||||||
|
|
||||||
|
|
||||||
class _ReferenceManager:
|
class _ReferenceManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# (id(obj), func_name, local_name) or (id(obj), kernel_attr_name)
|
|
||||||
# -> _UserVariable(name) / ast / constant_object
|
|
||||||
self.to_inlined = dict()
|
|
||||||
# inlined_name -> use_count
|
|
||||||
self.use_count = dict()
|
|
||||||
self.rpc_mapper = _HostObjectMapper()
|
self.rpc_mapper = _HostObjectMapper()
|
||||||
self.exception_mapper = _HostObjectMapper(core_language.first_user_eid)
|
self.exception_mapper = _HostObjectMapper(core_language.first_user_eid)
|
||||||
self.kernel_attr_init = []
|
self.kernel_attr_init = []
|
||||||
|
|
||||||
|
# (id(obj), func_name, ref_name) or (id(obj), kernel_attr_name)
|
||||||
|
# -> _UserVariable(name) / ast / constant_object
|
||||||
|
self._to_inlined = dict()
|
||||||
|
# inlined_name -> use_count
|
||||||
|
self._use_count = dict()
|
||||||
# reserved names
|
# reserved names
|
||||||
for kg in core_language.kernel_globals:
|
for kg in core_language.kernel_globals:
|
||||||
self.use_count[kg] = 1
|
self._use_count[kg] = 1
|
||||||
for name in ("int", "round", "int64", "round64", "float", "array",
|
for name in ("int", "round", "int64", "round64", "float", "array",
|
||||||
"range", "Fraction", "Quantity", "EncodedException"):
|
"range", "Fraction", "Quantity", "EncodedException"):
|
||||||
self.use_count[name] = 1
|
self._use_count[name] = 1
|
||||||
|
|
||||||
|
# node_or_value can be a AST node, used to inline function parameter values
|
||||||
|
# that can be simplified later through constant folding.
|
||||||
|
def register_replace(self, obj, func_name, ref_name, node_or_value):
|
||||||
|
self._to_inlined[(id(obj), func_name, ref_name)] = node_or_value
|
||||||
|
|
||||||
def new_name(self, base_name):
|
def new_name(self, base_name):
|
||||||
if base_name[-1].isdigit():
|
if base_name[-1].isdigit():
|
||||||
base_name += "_"
|
base_name += "_"
|
||||||
if base_name in self.use_count:
|
if base_name in self._use_count:
|
||||||
r = base_name + str(self.use_count[base_name])
|
r = base_name + str(self._use_count[base_name])
|
||||||
self.use_count[base_name] += 1
|
self._use_count[base_name] += 1
|
||||||
return r
|
return r
|
||||||
else:
|
else:
|
||||||
self.use_count[base_name] = 1
|
self._use_count[base_name] = 1
|
||||||
return base_name
|
return base_name
|
||||||
|
|
||||||
def get(self, obj, func_name, ref):
|
def resolve_name(self, obj, func_name, ref_name, store):
|
||||||
if isinstance(ref, ast.Name):
|
key = (id(obj), func_name, ref_name)
|
||||||
key = (id(obj), func_name, ref.id)
|
|
||||||
try:
|
try:
|
||||||
return self.to_inlined[key]
|
return self._to_inlined[key]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if isinstance(ref.ctx, ast.Store):
|
if store:
|
||||||
ival = _UserVariable(self.new_name(ref.id))
|
ival = _UserVariable(self.new_name(ref_name))
|
||||||
self.to_inlined[key] = ival
|
self._to_inlined[key] = ival
|
||||||
return ival
|
return ival
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
return inspect.getmodule(obj).__dict__[ref.id]
|
return inspect.getmodule(obj).__dict__[ref_name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return getattr(builtins, ref.id)
|
return getattr(builtins, ref_name)
|
||||||
elif isinstance(ref, ast.Attribute):
|
|
||||||
target = self.get(obj, func_name, ref.value)
|
def resolve_attr(self, value, attr):
|
||||||
if hasattr(target, "kernel_attr") and ref.attr in target.kernel_attr.split():
|
if _is_kernel_attr(value, attr):
|
||||||
key = (id(target), ref.attr)
|
key = (id(value), attr)
|
||||||
try:
|
try:
|
||||||
ival = self.to_inlined[key]
|
ival = self._to_inlined[key]
|
||||||
assert(isinstance(ival, _UserVariable))
|
assert(isinstance(ival, _UserVariable))
|
||||||
except KeyError:
|
except KeyError:
|
||||||
iname = self.new_name(ref.attr)
|
iname = self.new_name(attr)
|
||||||
ival = _UserVariable(iname)
|
ival = _UserVariable(iname)
|
||||||
self.to_inlined[key] = ival
|
self._to_inlined[key] = ival
|
||||||
a = value_to_ast(getattr(target, ref.attr))
|
a = value_to_ast(getattr(value, attr))
|
||||||
if a is None:
|
if a is None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Cannot represent initial value"
|
"Cannot represent initial value"
|
||||||
|
@ -97,7 +105,19 @@ class _ReferenceManager:
|
||||||
[ast.Name(iname, ast.Store())], a))
|
[ast.Name(iname, ast.Store())], a))
|
||||||
return ival
|
return ival
|
||||||
else:
|
else:
|
||||||
return getattr(target, ref.attr)
|
return getattr(value, attr)
|
||||||
|
|
||||||
|
def resolve_constant(self, obj, func_name, node):
|
||||||
|
if isinstance(node, ast.Name):
|
||||||
|
c = self.resolve_name(obj, func_name, node.id, False)
|
||||||
|
if isinstance(c, (_UserVariable, ast.AST)):
|
||||||
|
raise ValueError("Not a constant")
|
||||||
|
return c
|
||||||
|
elif isinstance(node, ast.Attribute):
|
||||||
|
value = self.resolve_constant(obj, func_name, node.value)
|
||||||
|
if _is_kernel_attr(value, node.attr):
|
||||||
|
raise ValueError("Not a constant")
|
||||||
|
return getattr(value, node.attr)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -156,9 +176,9 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
||||||
setattr(node, field, new_node)
|
setattr(node, field, new_node)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def visit_ref(self, node):
|
def visit_Name(self, node):
|
||||||
store = isinstance(node.ctx, ast.Store)
|
store = isinstance(node.ctx, ast.Store)
|
||||||
ival = self.rm.get(self.obj, self.func_name, node)
|
ival = self.rm.resolve_name(self.obj, self.func_name, node.id, store)
|
||||||
if isinstance(ival, _UserVariable):
|
if isinstance(ival, _UserVariable):
|
||||||
newnode = ast.Name(ival.name, node.ctx)
|
newnode = ast.Name(ival.name, node.ctx)
|
||||||
elif isinstance(ival, ast.AST):
|
elif isinstance(ival, ast.AST):
|
||||||
|
@ -175,11 +195,34 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
||||||
"Cannot represent inlined value")
|
"Cannot represent inlined value")
|
||||||
return ast.copy_location(newnode, node)
|
return ast.copy_location(newnode, node)
|
||||||
|
|
||||||
visit_Name = visit_ref
|
def _resolve_attribute(self, node):
|
||||||
visit_Attribute = visit_ref
|
if isinstance(node, ast.Name):
|
||||||
|
ival = self.rm.resolve_name(self.obj, self.func_name, node.id, False)
|
||||||
|
if isinstance(ival, _UserVariable):
|
||||||
|
return ast.copy_location(ast.Name(ival.name, ast.Load()), node)
|
||||||
|
else:
|
||||||
|
return ival
|
||||||
|
elif isinstance(node, ast.Attribute):
|
||||||
|
value = self._resolve_attribute(node.value)
|
||||||
|
if isinstance(value, ast.AST):
|
||||||
|
node.value = value
|
||||||
|
return node
|
||||||
|
else:
|
||||||
|
return self.rm.resolve_attr(value, node.attr)
|
||||||
|
else:
|
||||||
|
return self.visit(node)
|
||||||
|
|
||||||
|
def visit_Attribute(self, node):
|
||||||
|
ival = self._resolve_attribute(node)
|
||||||
|
if isinstance(ival, ast.AST):
|
||||||
|
return ival
|
||||||
|
elif isinstance(ival, _UserVariable):
|
||||||
|
return ast.copy_location(ast.Name(ival.name, ast.Load()), node)
|
||||||
|
else:
|
||||||
|
return value_to_ast(ival)
|
||||||
|
|
||||||
def visit_Call(self, node):
|
def visit_Call(self, node):
|
||||||
func = self.rm.get(self.obj, self.func_name, node.func)
|
func = self.rm.resolve_constant(self.obj, self.func_name, node.func)
|
||||||
new_args = [self.visit(arg) for arg in node.args]
|
new_args = [self.visit(arg) for arg in node.args]
|
||||||
|
|
||||||
if _is_embeddable(func):
|
if _is_embeddable(func):
|
||||||
|
@ -240,7 +283,7 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def _encode_exception(self, e):
|
def _encode_exception(self, e):
|
||||||
exception_class = self.rm.get(self.obj, self.func_name, e)
|
exception_class = self.rm.resolve_constant(self.obj, self.func_name, e)
|
||||||
if not inspect.isclass(exception_class):
|
if not inspect.isclass(exception_class):
|
||||||
raise NotImplementedError("Exception type must be a class")
|
raise NotImplementedError("Exception type must be a class")
|
||||||
if issubclass(exception_class, core_language.RuntimeException):
|
if issubclass(exception_class, core_language.RuntimeException):
|
||||||
|
@ -301,9 +344,10 @@ def _initialize_function_params(func_def, k_args, k_kwargs, rm):
|
||||||
for arg_ast, arg_value in zip(func_def.args.args, k_args):
|
for arg_ast, arg_value in zip(func_def.args.args, k_args):
|
||||||
arg_name = arg_ast.arg
|
arg_name = arg_ast.arg
|
||||||
if arg_name in rop:
|
if arg_name in rop:
|
||||||
rm.to_inlined[(id(obj), func_name, arg_name)] = arg_value
|
rm.register_replace(obj, func_name, arg_name, arg_value)
|
||||||
else:
|
else:
|
||||||
target = rm.get(obj, func_name, ast.Name(arg_name, ast.Store()))
|
uservar = rm.resolve_name(obj, func_name, arg_name, True)
|
||||||
|
target = ast.Name(uservar.name, ast.Store())
|
||||||
value = value_to_ast(arg_value)
|
value = value_to_ast(arg_value)
|
||||||
param_init.append(ast.Assign(targets=[target], value=value))
|
param_init.append(ast.Assign(targets=[target], value=value))
|
||||||
return param_init
|
return param_init
|
||||||
|
|
Loading…
Reference in New Issue