forked from M-Labs/artiq
1
0
Fork 0

transforms/inline: offload some work to remove_inter_assigns/remove_dead_code

This commit is contained in:
Sebastien Bourdeauducq 2014-10-30 19:13:01 +08:00
parent 1c0c0b691e
commit 97329b7fc9
2 changed files with 29 additions and 45 deletions

View File

@ -6,7 +6,7 @@ import ast
import builtins import builtins
from copy import deepcopy from copy import deepcopy
from artiq.transforms.tools import eval_ast, value_to_ast from artiq.transforms.tools import eval_ast, value_to_ast, NotASTRepresentable
from artiq.language import core as core_language from artiq.language import core as core_language
from artiq.language import units from artiq.language import units
@ -45,7 +45,7 @@ class _ReferenceManager:
self.kernel_attr_init = [] self.kernel_attr_init = []
# (id(obj), func_name, ref_name) or (id(obj), kernel_attr_name) # (id(obj), func_name, ref_name) or (id(obj), kernel_attr_name)
# -> _UserVariable(name) / ast / constant_object # -> _UserVariable(name) / complex object
self._to_inlined = dict() self._to_inlined = dict()
# inlined_name -> use_count # inlined_name -> use_count
self._use_count = dict() self._use_count = dict()
@ -57,10 +57,15 @@ class _ReferenceManager:
"range"): "range"):
self._use_count[name] = 1 self._use_count[name] = 1
# node_or_value can be a AST node, used to inline function parameter values # Complex objects in the namespace of functions can be used in two ways:
# that can be simplified later through constant folding. # 1. Calling a method on them (which gets inlined or RPCd)
def register_replace(self, obj, func_name, ref_name, node_or_value): # 2. Getting or setting attributes (which are turned into local variables)
self._to_inlined[(id(obj), func_name, ref_name)] = node_or_value # They are needed to implement "self", which is the only supported use
# case.
def register_complex_object(self, obj, func_name, ref_name,
complex_object):
assert(not isinstance(complex_object, ast.AST))
self._to_inlined[(id(obj), func_name, ref_name)] = complex_object
def new_name(self, base_name): def new_name(self, base_name):
if base_name[-1].isdigit(): if base_name[-1].isdigit():
@ -112,7 +117,7 @@ class _ReferenceManager:
def resolve_constant(self, obj, func_name, node): def resolve_constant(self, obj, func_name, node):
if isinstance(node, ast.Name): if isinstance(node, ast.Name):
c = self.resolve_name(obj, func_name, node.id, False) c = self.resolve_name(obj, func_name, node.id, False)
if isinstance(c, (_UserVariable, ast.AST)): if isinstance(c, _UserVariable):
raise ValueError("Not a constant") raise ValueError("Not a constant")
return c return c
elif isinstance(node, ast.Attribute): elif isinstance(node, ast.Attribute):
@ -192,18 +197,10 @@ class _ReferenceReplacer(ast.NodeVisitor):
ival = self.rm.resolve_name(self.obj, self.func_name, node.id, store) ival = self.rm.resolve_name(self.obj, self.func_name, node.id, store)
if isinstance(ival, _UserVariable): if isinstance(ival, _UserVariable):
newnode = ast.Name(ival.name, node.ctx) newnode = ast.Name(ival.name, node.ctx)
elif isinstance(ival, ast.AST):
assert(not store)
newnode = deepcopy(ival)
else: else:
if store: if store:
raise NotImplementedError( raise NotImplementedError("Cannot assign to this object")
"Cannot turn object into user variable")
else:
newnode = value_to_ast(ival) newnode = value_to_ast(ival)
if newnode is None:
raise NotImplementedError(
"Cannot represent inlined value")
return ast.copy_location(newnode, node) return ast.copy_location(newnode, node)
def _resolve_attribute(self, node): def _resolve_attribute(self, node):
@ -325,40 +322,23 @@ class _ReferenceReplacer(ast.NodeVisitor):
return node return node
class _ListReadOnlyParams(ast.NodeVisitor):
def visit_FunctionDef(self, node):
if hasattr(self, "read_only_params"):
raise ValueError("More than one function definition")
self.read_only_params = {arg.arg for arg in node.args.args}
self.generic_visit(node)
def visit_Name(self, node):
if isinstance(node.ctx, ast.Store):
try:
self.read_only_params.remove(node.id)
except KeyError:
pass
def _list_read_only_params(func_def):
lrp = _ListReadOnlyParams()
lrp.visit(func_def)
return lrp.read_only_params
def _initialize_function_params(func_def, k_args, k_kwargs, rm): def _initialize_function_params(func_def, k_args, k_kwargs, rm):
obj = k_args[0] obj = k_args[0]
func_name = func_def.name func_name = func_def.name
param_init = [] param_init = []
rop = _list_read_only_params(func_def)
for arg_ast, arg_value in zip(func_def.args.args, k_args): for arg_ast, arg_value in zip(func_def.args.args, k_args):
arg_name = arg_ast.arg if isinstance(arg_value, ast.AST):
if arg_name in rop: value = arg_value
rm.register_replace(obj, func_name, arg_name, arg_value)
else: else:
uservar = rm.resolve_name(obj, func_name, arg_name, True) try:
target = ast.Name(uservar.name, ast.Store())
value = value_to_ast(arg_value) value = value_to_ast(arg_value)
except NotASTRepresentable:
value = None
if value is None:
rm.register_complex_object(obj, func_name, arg_ast.arg, arg_value)
else:
uservar = rm.resolve_name(obj, func_name, arg_ast.arg, True)
target = ast.Name(uservar.name, ast.Store())
param_init.append(ast.Assign(targets=[target], value=value)) param_init.append(ast.Assign(targets=[target], value=value))
return param_init return param_init

View File

@ -13,6 +13,10 @@ def eval_ast(expr, symdict=dict()):
return eval(code, symdict) return eval(code, symdict)
class NotASTRepresentable(Exception):
pass
def value_to_ast(value): def value_to_ast(value):
if isinstance(value, core_language.int64): # must be before int if isinstance(value, core_language.int64): # must be before int
return ast.Call( return ast.Call(
@ -41,7 +45,7 @@ def value_to_ast(value):
func=ast.Name("Quantity", ast.Load()), func=ast.Name("Quantity", ast.Load()),
args=[value_to_ast(value.amount), ast.Str(value.unit)], args=[value_to_ast(value.amount), ast.Str(value.unit)],
keywords=[], starargs=None, kwargs=None) keywords=[], starargs=None, kwargs=None)
return None raise NotASTRepresentable(str(value))
class NotConstant(Exception): class NotConstant(Exception):