forked from M-Labs/artiq
transforms: add dead code removal
This commit is contained in:
parent
be94a8b07c
commit
1e8c9837ac
@ -4,6 +4,7 @@ from artiq.transforms.inline import inline
|
||||
from artiq.transforms.lower_units import lower_units
|
||||
from artiq.transforms.remove_inter_assigns import remove_inter_assigns
|
||||
from artiq.transforms.fold_constants import fold_constants
|
||||
from artiq.transforms.remove_dead_code import remove_dead_code
|
||||
from artiq.transforms.unroll_loops import unroll_loops
|
||||
from artiq.transforms.interleave import interleave
|
||||
from artiq.transforms.lower_time import lower_time
|
||||
@ -43,7 +44,7 @@ class Core:
|
||||
|
||||
def run(self, k_function, k_args, k_kwargs):
|
||||
# transform/simplify AST
|
||||
_debug_unparse = _make_debug_unparse("fold_constants_2")
|
||||
_debug_unparse = _make_debug_unparse("remove_dead_code")
|
||||
|
||||
func_def, rpc_map, exception_map = inline(
|
||||
self, k_function, k_args, k_kwargs)
|
||||
@ -75,6 +76,9 @@ class Core:
|
||||
fold_constants(func_def)
|
||||
_debug_unparse("fold_constants_2", func_def)
|
||||
|
||||
remove_dead_code(func_def)
|
||||
_debug_unparse("remove_dead_code", func_def)
|
||||
|
||||
# compile to machine code and run
|
||||
binary = get_runtime_binary(self.runtime_env, func_def)
|
||||
self.core_com.load(binary)
|
||||
|
57
artiq/transforms/remove_dead_code.py
Normal file
57
artiq/transforms/remove_dead_code.py
Normal file
@ -0,0 +1,57 @@
|
||||
import ast
|
||||
|
||||
from artiq.transforms.tools import is_replaceable
|
||||
|
||||
|
||||
class _SourceLister(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
self.sources = set()
|
||||
|
||||
def visit_Name(self, node):
|
||||
if isinstance(node.ctx, ast.Load):
|
||||
self.sources.add(node.id)
|
||||
|
||||
|
||||
class _DeadCodeRemover(ast.NodeTransformer):
|
||||
def __init__(self, kept_targets):
|
||||
self.kept_targets = kept_targets
|
||||
|
||||
def visit_Assign(self, node):
|
||||
new_targets = []
|
||||
for target in node.targets:
|
||||
if not (isinstance(target, ast.Name)
|
||||
and target.id not in self.kept_targets):
|
||||
new_targets.append(target)
|
||||
if not new_targets and is_replaceable(node.value):
|
||||
return None
|
||||
else:
|
||||
return node
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
if (isinstance(node.target, ast.Name)
|
||||
and node.target.id not in self.kept_targets
|
||||
and is_replaceable(node.value)):
|
||||
return None
|
||||
else:
|
||||
return node
|
||||
|
||||
def visit_If(self, node):
|
||||
if isinstance(node.test, ast.NameConstant):
|
||||
if node.test.value:
|
||||
return node.body
|
||||
else:
|
||||
return node.orelse
|
||||
else:
|
||||
return node
|
||||
|
||||
def visit_While(self, node):
|
||||
if isinstance(node.test, ast.NameConstant) and not node.test.value:
|
||||
return node.orelse
|
||||
else:
|
||||
return node
|
||||
|
||||
|
||||
def remove_dead_code(func_def):
|
||||
sl = _SourceLister()
|
||||
sl.visit(func_def)
|
||||
_DeadCodeRemover(sl.sources).visit(func_def)
|
@ -1,24 +1,7 @@
|
||||
import ast
|
||||
from copy import copy, deepcopy
|
||||
|
||||
|
||||
_replaceable_funcs = {
|
||||
"bool", "int", "float", "round",
|
||||
"int64", "round64", "Fraction",
|
||||
}
|
||||
|
||||
|
||||
def _is_replaceable(value):
|
||||
if isinstance(value, (ast.NameConstant, ast.Num, ast.Str)):
|
||||
return True
|
||||
elif isinstance(value, ast.BinOp):
|
||||
return _is_replaceable(value.left) and _is_replaceable(value.right)
|
||||
elif isinstance(value, ast.BoolOp):
|
||||
return all(_is_replaceable(v) for v in value.values)
|
||||
elif isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
|
||||
return value.func.id in _replaceable_funcs
|
||||
else:
|
||||
return False
|
||||
from artiq.transforms.tools import is_replaceable
|
||||
|
||||
|
||||
class _TargetLister(ast.NodeVisitor):
|
||||
@ -47,7 +30,7 @@ class _InterAssignRemover(ast.NodeTransformer):
|
||||
|
||||
def visit_Assign(self, node):
|
||||
self.generic_visit(node)
|
||||
if _is_replaceable(node.value):
|
||||
if is_replaceable(node.value):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
self.replacements[target.id] = node.value
|
||||
|
@ -75,3 +75,23 @@ def eval_constant(node):
|
||||
raise NotConstant
|
||||
else:
|
||||
raise NotConstant
|
||||
|
||||
|
||||
_replaceable_funcs = {
|
||||
"bool", "int", "float", "round",
|
||||
"int64", "round64", "Fraction",
|
||||
"Quantity"
|
||||
}
|
||||
|
||||
|
||||
def is_replaceable(expr):
|
||||
if isinstance(expr, (ast.NameConstant, ast.Num, ast.Str)):
|
||||
return True
|
||||
elif isinstance(expr, ast.BinOp):
|
||||
return is_replaceable(expr.left) and is_replaceable(expr.right)
|
||||
elif isinstance(expr, ast.BoolOp):
|
||||
return all(is_replaceable(v) for v in expr.values)
|
||||
elif isinstance(expr, ast.Call) and isinstance(expr.func, ast.Name):
|
||||
return expr.func.id in _replaceable_funcs
|
||||
else:
|
||||
return False
|
||||
|
Loading…
Reference in New Issue
Block a user