diff --git a/artiq/coredevice/core.py b/artiq/coredevice/core.py index 2b362f604..25fa268fb 100644 --- a/artiq/coredevice/core.py +++ b/artiq/coredevice/core.py @@ -4,6 +4,7 @@ from artiq.transforms.inline import inline from artiq.transforms.lower_units import lower_units from artiq.transforms.remove_inter_assigns import remove_inter_assigns from artiq.transforms.fold_constants import fold_constants +from artiq.transforms.remove_dead_code import remove_dead_code from artiq.transforms.unroll_loops import unroll_loops from artiq.transforms.interleave import interleave from artiq.transforms.lower_time import lower_time @@ -43,7 +44,7 @@ class Core: def run(self, k_function, k_args, k_kwargs): # transform/simplify AST - _debug_unparse = _make_debug_unparse("fold_constants_2") + _debug_unparse = _make_debug_unparse("remove_dead_code") func_def, rpc_map, exception_map = inline( self, k_function, k_args, k_kwargs) @@ -75,6 +76,9 @@ class Core: fold_constants(func_def) _debug_unparse("fold_constants_2", func_def) + remove_dead_code(func_def) + _debug_unparse("remove_dead_code", func_def) + # compile to machine code and run binary = get_runtime_binary(self.runtime_env, func_def) self.core_com.load(binary) diff --git a/artiq/transforms/remove_dead_code.py b/artiq/transforms/remove_dead_code.py new file mode 100644 index 000000000..26f873ced --- /dev/null +++ b/artiq/transforms/remove_dead_code.py @@ -0,0 +1,57 @@ +import ast + +from artiq.transforms.tools import is_replaceable + + +class _SourceLister(ast.NodeVisitor): + def __init__(self): + self.sources = set() + + def visit_Name(self, node): + if isinstance(node.ctx, ast.Load): + self.sources.add(node.id) + + +class _DeadCodeRemover(ast.NodeTransformer): + def __init__(self, kept_targets): + self.kept_targets = kept_targets + + def visit_Assign(self, node): + new_targets = [] + for target in node.targets: + if not (isinstance(target, ast.Name) + and target.id not in self.kept_targets): + new_targets.append(target) + if not new_targets and is_replaceable(node.value): + return None + else: + return node + + def visit_AugAssign(self, node): + if (isinstance(node.target, ast.Name) + and node.target.id not in self.kept_targets + and is_replaceable(node.value)): + return None + else: + return node + + def visit_If(self, node): + if isinstance(node.test, ast.NameConstant): + if node.test.value: + return node.body + else: + return node.orelse + else: + return node + + def visit_While(self, node): + if isinstance(node.test, ast.NameConstant) and not node.test.value: + return node.orelse + else: + return node + + +def remove_dead_code(func_def): + sl = _SourceLister() + sl.visit(func_def) + _DeadCodeRemover(sl.sources).visit(func_def) diff --git a/artiq/transforms/remove_inter_assigns.py b/artiq/transforms/remove_inter_assigns.py index 4479e116c..fe80a4e4a 100644 --- a/artiq/transforms/remove_inter_assigns.py +++ b/artiq/transforms/remove_inter_assigns.py @@ -1,24 +1,7 @@ import ast from copy import copy, deepcopy - -_replaceable_funcs = { - "bool", "int", "float", "round", - "int64", "round64", "Fraction", -} - - -def _is_replaceable(value): - if isinstance(value, (ast.NameConstant, ast.Num, ast.Str)): - return True - elif isinstance(value, ast.BinOp): - return _is_replaceable(value.left) and _is_replaceable(value.right) - elif isinstance(value, ast.BoolOp): - return all(_is_replaceable(v) for v in value.values) - elif isinstance(value, ast.Call) and isinstance(value.func, ast.Name): - return value.func.id in _replaceable_funcs - else: - return False +from artiq.transforms.tools import is_replaceable class _TargetLister(ast.NodeVisitor): @@ -47,7 +30,7 @@ class _InterAssignRemover(ast.NodeTransformer): def visit_Assign(self, node): self.generic_visit(node) - if _is_replaceable(node.value): + if is_replaceable(node.value): for target in node.targets: if isinstance(target, ast.Name): self.replacements[target.id] = node.value diff --git a/artiq/transforms/tools.py b/artiq/transforms/tools.py index 2b3bbdbf3..e6a8baa7e 100644 --- a/artiq/transforms/tools.py +++ b/artiq/transforms/tools.py @@ -75,3 +75,23 @@ def eval_constant(node): raise NotConstant else: raise NotConstant + + +_replaceable_funcs = { + "bool", "int", "float", "round", + "int64", "round64", "Fraction", + "Quantity" +} + + +def is_replaceable(expr): + if isinstance(expr, (ast.NameConstant, ast.Num, ast.Str)): + return True + elif isinstance(expr, ast.BinOp): + return is_replaceable(expr.left) and is_replaceable(expr.right) + elif isinstance(expr, ast.BoolOp): + return all(is_replaceable(v) for v in expr.values) + elif isinstance(expr, ast.Call) and isinstance(expr.func, ast.Name): + return expr.func.id in _replaceable_funcs + else: + return False