forked from M-Labs/artiq
transforms/inline: offload some work to remove_inter_assigns/remove_dead_code
This commit is contained in:
parent
1c0c0b691e
commit
97329b7fc9
|
@ -6,7 +6,7 @@ import ast
|
||||||
import builtins
|
import builtins
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
from artiq.transforms.tools import eval_ast, value_to_ast
|
from artiq.transforms.tools import eval_ast, value_to_ast, NotASTRepresentable
|
||||||
from artiq.language import core as core_language
|
from artiq.language import core as core_language
|
||||||
from artiq.language import units
|
from artiq.language import units
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ class _ReferenceManager:
|
||||||
self.kernel_attr_init = []
|
self.kernel_attr_init = []
|
||||||
|
|
||||||
# (id(obj), func_name, ref_name) or (id(obj), kernel_attr_name)
|
# (id(obj), func_name, ref_name) or (id(obj), kernel_attr_name)
|
||||||
# -> _UserVariable(name) / ast / constant_object
|
# -> _UserVariable(name) / complex object
|
||||||
self._to_inlined = dict()
|
self._to_inlined = dict()
|
||||||
# inlined_name -> use_count
|
# inlined_name -> use_count
|
||||||
self._use_count = dict()
|
self._use_count = dict()
|
||||||
|
@ -57,10 +57,15 @@ class _ReferenceManager:
|
||||||
"range"):
|
"range"):
|
||||||
self._use_count[name] = 1
|
self._use_count[name] = 1
|
||||||
|
|
||||||
# node_or_value can be a AST node, used to inline function parameter values
|
# Complex objects in the namespace of functions can be used in two ways:
|
||||||
# that can be simplified later through constant folding.
|
# 1. Calling a method on them (which gets inlined or RPCd)
|
||||||
def register_replace(self, obj, func_name, ref_name, node_or_value):
|
# 2. Getting or setting attributes (which are turned into local variables)
|
||||||
self._to_inlined[(id(obj), func_name, ref_name)] = node_or_value
|
# They are needed to implement "self", which is the only supported use
|
||||||
|
# case.
|
||||||
|
def register_complex_object(self, obj, func_name, ref_name,
|
||||||
|
complex_object):
|
||||||
|
assert(not isinstance(complex_object, ast.AST))
|
||||||
|
self._to_inlined[(id(obj), func_name, ref_name)] = complex_object
|
||||||
|
|
||||||
def new_name(self, base_name):
|
def new_name(self, base_name):
|
||||||
if base_name[-1].isdigit():
|
if base_name[-1].isdigit():
|
||||||
|
@ -112,7 +117,7 @@ class _ReferenceManager:
|
||||||
def resolve_constant(self, obj, func_name, node):
|
def resolve_constant(self, obj, func_name, node):
|
||||||
if isinstance(node, ast.Name):
|
if isinstance(node, ast.Name):
|
||||||
c = self.resolve_name(obj, func_name, node.id, False)
|
c = self.resolve_name(obj, func_name, node.id, False)
|
||||||
if isinstance(c, (_UserVariable, ast.AST)):
|
if isinstance(c, _UserVariable):
|
||||||
raise ValueError("Not a constant")
|
raise ValueError("Not a constant")
|
||||||
return c
|
return c
|
||||||
elif isinstance(node, ast.Attribute):
|
elif isinstance(node, ast.Attribute):
|
||||||
|
@ -192,18 +197,10 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
||||||
ival = self.rm.resolve_name(self.obj, self.func_name, node.id, store)
|
ival = self.rm.resolve_name(self.obj, self.func_name, node.id, store)
|
||||||
if isinstance(ival, _UserVariable):
|
if isinstance(ival, _UserVariable):
|
||||||
newnode = ast.Name(ival.name, node.ctx)
|
newnode = ast.Name(ival.name, node.ctx)
|
||||||
elif isinstance(ival, ast.AST):
|
|
||||||
assert(not store)
|
|
||||||
newnode = deepcopy(ival)
|
|
||||||
else:
|
else:
|
||||||
if store:
|
if store:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError("Cannot assign to this object")
|
||||||
"Cannot turn object into user variable")
|
newnode = value_to_ast(ival)
|
||||||
else:
|
|
||||||
newnode = value_to_ast(ival)
|
|
||||||
if newnode is None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Cannot represent inlined value")
|
|
||||||
return ast.copy_location(newnode, node)
|
return ast.copy_location(newnode, node)
|
||||||
|
|
||||||
def _resolve_attribute(self, node):
|
def _resolve_attribute(self, node):
|
||||||
|
@ -325,40 +322,23 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
class _ListReadOnlyParams(ast.NodeVisitor):
|
|
||||||
def visit_FunctionDef(self, node):
|
|
||||||
if hasattr(self, "read_only_params"):
|
|
||||||
raise ValueError("More than one function definition")
|
|
||||||
self.read_only_params = {arg.arg for arg in node.args.args}
|
|
||||||
self.generic_visit(node)
|
|
||||||
|
|
||||||
def visit_Name(self, node):
|
|
||||||
if isinstance(node.ctx, ast.Store):
|
|
||||||
try:
|
|
||||||
self.read_only_params.remove(node.id)
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def _list_read_only_params(func_def):
|
|
||||||
lrp = _ListReadOnlyParams()
|
|
||||||
lrp.visit(func_def)
|
|
||||||
return lrp.read_only_params
|
|
||||||
|
|
||||||
|
|
||||||
def _initialize_function_params(func_def, k_args, k_kwargs, rm):
|
def _initialize_function_params(func_def, k_args, k_kwargs, rm):
|
||||||
obj = k_args[0]
|
obj = k_args[0]
|
||||||
func_name = func_def.name
|
func_name = func_def.name
|
||||||
param_init = []
|
param_init = []
|
||||||
rop = _list_read_only_params(func_def)
|
|
||||||
for arg_ast, arg_value in zip(func_def.args.args, k_args):
|
for arg_ast, arg_value in zip(func_def.args.args, k_args):
|
||||||
arg_name = arg_ast.arg
|
if isinstance(arg_value, ast.AST):
|
||||||
if arg_name in rop:
|
value = arg_value
|
||||||
rm.register_replace(obj, func_name, arg_name, arg_value)
|
|
||||||
else:
|
else:
|
||||||
uservar = rm.resolve_name(obj, func_name, arg_name, True)
|
try:
|
||||||
|
value = value_to_ast(arg_value)
|
||||||
|
except NotASTRepresentable:
|
||||||
|
value = None
|
||||||
|
if value is None:
|
||||||
|
rm.register_complex_object(obj, func_name, arg_ast.arg, arg_value)
|
||||||
|
else:
|
||||||
|
uservar = rm.resolve_name(obj, func_name, arg_ast.arg, True)
|
||||||
target = ast.Name(uservar.name, ast.Store())
|
target = ast.Name(uservar.name, ast.Store())
|
||||||
value = value_to_ast(arg_value)
|
|
||||||
param_init.append(ast.Assign(targets=[target], value=value))
|
param_init.append(ast.Assign(targets=[target], value=value))
|
||||||
return param_init
|
return param_init
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,10 @@ def eval_ast(expr, symdict=dict()):
|
||||||
return eval(code, symdict)
|
return eval(code, symdict)
|
||||||
|
|
||||||
|
|
||||||
|
class NotASTRepresentable(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def value_to_ast(value):
|
def value_to_ast(value):
|
||||||
if isinstance(value, core_language.int64): # must be before int
|
if isinstance(value, core_language.int64): # must be before int
|
||||||
return ast.Call(
|
return ast.Call(
|
||||||
|
@ -41,7 +45,7 @@ def value_to_ast(value):
|
||||||
func=ast.Name("Quantity", ast.Load()),
|
func=ast.Name("Quantity", ast.Load()),
|
||||||
args=[value_to_ast(value.amount), ast.Str(value.unit)],
|
args=[value_to_ast(value.amount), ast.Str(value.unit)],
|
||||||
keywords=[], starargs=None, kwargs=None)
|
keywords=[], starargs=None, kwargs=None)
|
||||||
return None
|
raise NotASTRepresentable(str(value))
|
||||||
|
|
||||||
|
|
||||||
class NotConstant(Exception):
|
class NotConstant(Exception):
|
||||||
|
|
Loading…
Reference in New Issue