forked from M-Labs/artiq
transforms/inline: offload some work to remove_inter_assigns/remove_dead_code
This commit is contained in:
parent
1c0c0b691e
commit
97329b7fc9
|
@ -6,7 +6,7 @@ import ast
|
|||
import builtins
|
||||
from copy import deepcopy
|
||||
|
||||
from artiq.transforms.tools import eval_ast, value_to_ast
|
||||
from artiq.transforms.tools import eval_ast, value_to_ast, NotASTRepresentable
|
||||
from artiq.language import core as core_language
|
||||
from artiq.language import units
|
||||
|
||||
|
@ -45,7 +45,7 @@ class _ReferenceManager:
|
|||
self.kernel_attr_init = []
|
||||
|
||||
# (id(obj), func_name, ref_name) or (id(obj), kernel_attr_name)
|
||||
# -> _UserVariable(name) / ast / constant_object
|
||||
# -> _UserVariable(name) / complex object
|
||||
self._to_inlined = dict()
|
||||
# inlined_name -> use_count
|
||||
self._use_count = dict()
|
||||
|
@ -57,10 +57,15 @@ class _ReferenceManager:
|
|||
"range"):
|
||||
self._use_count[name] = 1
|
||||
|
||||
# node_or_value can be a AST node, used to inline function parameter values
|
||||
# that can be simplified later through constant folding.
|
||||
def register_replace(self, obj, func_name, ref_name, node_or_value):
|
||||
self._to_inlined[(id(obj), func_name, ref_name)] = node_or_value
|
||||
# Complex objects in the namespace of functions can be used in two ways:
|
||||
# 1. Calling a method on them (which gets inlined or RPCd)
|
||||
# 2. Getting or setting attributes (which are turned into local variables)
|
||||
# They are needed to implement "self", which is the only supported use
|
||||
# case.
|
||||
def register_complex_object(self, obj, func_name, ref_name,
|
||||
complex_object):
|
||||
assert(not isinstance(complex_object, ast.AST))
|
||||
self._to_inlined[(id(obj), func_name, ref_name)] = complex_object
|
||||
|
||||
def new_name(self, base_name):
|
||||
if base_name[-1].isdigit():
|
||||
|
@ -112,7 +117,7 @@ class _ReferenceManager:
|
|||
def resolve_constant(self, obj, func_name, node):
|
||||
if isinstance(node, ast.Name):
|
||||
c = self.resolve_name(obj, func_name, node.id, False)
|
||||
if isinstance(c, (_UserVariable, ast.AST)):
|
||||
if isinstance(c, _UserVariable):
|
||||
raise ValueError("Not a constant")
|
||||
return c
|
||||
elif isinstance(node, ast.Attribute):
|
||||
|
@ -192,18 +197,10 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
|||
ival = self.rm.resolve_name(self.obj, self.func_name, node.id, store)
|
||||
if isinstance(ival, _UserVariable):
|
||||
newnode = ast.Name(ival.name, node.ctx)
|
||||
elif isinstance(ival, ast.AST):
|
||||
assert(not store)
|
||||
newnode = deepcopy(ival)
|
||||
else:
|
||||
if store:
|
||||
raise NotImplementedError(
|
||||
"Cannot turn object into user variable")
|
||||
else:
|
||||
newnode = value_to_ast(ival)
|
||||
if newnode is None:
|
||||
raise NotImplementedError(
|
||||
"Cannot represent inlined value")
|
||||
raise NotImplementedError("Cannot assign to this object")
|
||||
newnode = value_to_ast(ival)
|
||||
return ast.copy_location(newnode, node)
|
||||
|
||||
def _resolve_attribute(self, node):
|
||||
|
@ -325,40 +322,23 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
|||
return node
|
||||
|
||||
|
||||
class _ListReadOnlyParams(ast.NodeVisitor):
|
||||
def visit_FunctionDef(self, node):
|
||||
if hasattr(self, "read_only_params"):
|
||||
raise ValueError("More than one function definition")
|
||||
self.read_only_params = {arg.arg for arg in node.args.args}
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Name(self, node):
|
||||
if isinstance(node.ctx, ast.Store):
|
||||
try:
|
||||
self.read_only_params.remove(node.id)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
def _list_read_only_params(func_def):
|
||||
lrp = _ListReadOnlyParams()
|
||||
lrp.visit(func_def)
|
||||
return lrp.read_only_params
|
||||
|
||||
|
||||
def _initialize_function_params(func_def, k_args, k_kwargs, rm):
|
||||
obj = k_args[0]
|
||||
func_name = func_def.name
|
||||
param_init = []
|
||||
rop = _list_read_only_params(func_def)
|
||||
for arg_ast, arg_value in zip(func_def.args.args, k_args):
|
||||
arg_name = arg_ast.arg
|
||||
if arg_name in rop:
|
||||
rm.register_replace(obj, func_name, arg_name, arg_value)
|
||||
if isinstance(arg_value, ast.AST):
|
||||
value = arg_value
|
||||
else:
|
||||
uservar = rm.resolve_name(obj, func_name, arg_name, True)
|
||||
try:
|
||||
value = value_to_ast(arg_value)
|
||||
except NotASTRepresentable:
|
||||
value = None
|
||||
if value is None:
|
||||
rm.register_complex_object(obj, func_name, arg_ast.arg, arg_value)
|
||||
else:
|
||||
uservar = rm.resolve_name(obj, func_name, arg_ast.arg, True)
|
||||
target = ast.Name(uservar.name, ast.Store())
|
||||
value = value_to_ast(arg_value)
|
||||
param_init.append(ast.Assign(targets=[target], value=value))
|
||||
return param_init
|
||||
|
||||
|
|
|
@ -13,6 +13,10 @@ def eval_ast(expr, symdict=dict()):
|
|||
return eval(code, symdict)
|
||||
|
||||
|
||||
class NotASTRepresentable(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def value_to_ast(value):
|
||||
if isinstance(value, core_language.int64): # must be before int
|
||||
return ast.Call(
|
||||
|
@ -41,7 +45,7 @@ def value_to_ast(value):
|
|||
func=ast.Name("Quantity", ast.Load()),
|
||||
args=[value_to_ast(value.amount), ast.Str(value.unit)],
|
||||
keywords=[], starargs=None, kwargs=None)
|
||||
return None
|
||||
raise NotASTRepresentable(str(value))
|
||||
|
||||
|
||||
class NotConstant(Exception):
|
||||
|
|
Loading…
Reference in New Issue