forked from M-Labs/artiq
1
0
Fork 0

transforms: add dead code removal

This commit is contained in:
Sebastien Bourdeauducq 2014-10-29 20:23:58 +08:00
parent be94a8b07c
commit 1e8c9837ac
4 changed files with 84 additions and 20 deletions

View File

@ -4,6 +4,7 @@ from artiq.transforms.inline import inline
from artiq.transforms.lower_units import lower_units from artiq.transforms.lower_units import lower_units
from artiq.transforms.remove_inter_assigns import remove_inter_assigns from artiq.transforms.remove_inter_assigns import remove_inter_assigns
from artiq.transforms.fold_constants import fold_constants from artiq.transforms.fold_constants import fold_constants
from artiq.transforms.remove_dead_code import remove_dead_code
from artiq.transforms.unroll_loops import unroll_loops from artiq.transforms.unroll_loops import unroll_loops
from artiq.transforms.interleave import interleave from artiq.transforms.interleave import interleave
from artiq.transforms.lower_time import lower_time from artiq.transforms.lower_time import lower_time
@ -43,7 +44,7 @@ class Core:
def run(self, k_function, k_args, k_kwargs): def run(self, k_function, k_args, k_kwargs):
# transform/simplify AST # transform/simplify AST
_debug_unparse = _make_debug_unparse("fold_constants_2") _debug_unparse = _make_debug_unparse("remove_dead_code")
func_def, rpc_map, exception_map = inline( func_def, rpc_map, exception_map = inline(
self, k_function, k_args, k_kwargs) self, k_function, k_args, k_kwargs)
@ -75,6 +76,9 @@ class Core:
fold_constants(func_def) fold_constants(func_def)
_debug_unparse("fold_constants_2", func_def) _debug_unparse("fold_constants_2", func_def)
remove_dead_code(func_def)
_debug_unparse("remove_dead_code", func_def)
# compile to machine code and run # compile to machine code and run
binary = get_runtime_binary(self.runtime_env, func_def) binary = get_runtime_binary(self.runtime_env, func_def)
self.core_com.load(binary) self.core_com.load(binary)

View File

@ -0,0 +1,57 @@
import ast
from artiq.transforms.tools import is_replaceable
class _SourceLister(ast.NodeVisitor):
def __init__(self):
self.sources = set()
def visit_Name(self, node):
if isinstance(node.ctx, ast.Load):
self.sources.add(node.id)
class _DeadCodeRemover(ast.NodeTransformer):
def __init__(self, kept_targets):
self.kept_targets = kept_targets
def visit_Assign(self, node):
new_targets = []
for target in node.targets:
if not (isinstance(target, ast.Name)
and target.id not in self.kept_targets):
new_targets.append(target)
if not new_targets and is_replaceable(node.value):
return None
else:
return node
def visit_AugAssign(self, node):
if (isinstance(node.target, ast.Name)
and node.target.id not in self.kept_targets
and is_replaceable(node.value)):
return None
else:
return node
def visit_If(self, node):
if isinstance(node.test, ast.NameConstant):
if node.test.value:
return node.body
else:
return node.orelse
else:
return node
def visit_While(self, node):
if isinstance(node.test, ast.NameConstant) and not node.test.value:
return node.orelse
else:
return node
def remove_dead_code(func_def):
sl = _SourceLister()
sl.visit(func_def)
_DeadCodeRemover(sl.sources).visit(func_def)

View File

@ -1,24 +1,7 @@
import ast import ast
from copy import copy, deepcopy from copy import copy, deepcopy
from artiq.transforms.tools import is_replaceable
_replaceable_funcs = {
"bool", "int", "float", "round",
"int64", "round64", "Fraction",
}
def _is_replaceable(value):
if isinstance(value, (ast.NameConstant, ast.Num, ast.Str)):
return True
elif isinstance(value, ast.BinOp):
return _is_replaceable(value.left) and _is_replaceable(value.right)
elif isinstance(value, ast.BoolOp):
return all(_is_replaceable(v) for v in value.values)
elif isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
return value.func.id in _replaceable_funcs
else:
return False
class _TargetLister(ast.NodeVisitor): class _TargetLister(ast.NodeVisitor):
@ -47,7 +30,7 @@ class _InterAssignRemover(ast.NodeTransformer):
def visit_Assign(self, node): def visit_Assign(self, node):
self.generic_visit(node) self.generic_visit(node)
if _is_replaceable(node.value): if is_replaceable(node.value):
for target in node.targets: for target in node.targets:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
self.replacements[target.id] = node.value self.replacements[target.id] = node.value

View File

@ -75,3 +75,23 @@ def eval_constant(node):
raise NotConstant raise NotConstant
else: else:
raise NotConstant raise NotConstant
_replaceable_funcs = {
"bool", "int", "float", "round",
"int64", "round64", "Fraction",
"Quantity"
}
def is_replaceable(expr):
if isinstance(expr, (ast.NameConstant, ast.Num, ast.Str)):
return True
elif isinstance(expr, ast.BinOp):
return is_replaceable(expr.left) and is_replaceable(expr.right)
elif isinstance(expr, ast.BoolOp):
return all(is_replaceable(v) for v in expr.values)
elif isinstance(expr, ast.Call) and isinstance(expr.func, ast.Name):
return expr.func.id in _replaceable_funcs
else:
return False