diff --git a/artiq/transforms/remove_dead_code.py b/artiq/transforms/remove_dead_code.py index b176e6a4d..9a58c851d 100644 --- a/artiq/transforms/remove_dead_code.py +++ b/artiq/transforms/remove_dead_code.py @@ -1,6 +1,6 @@ import ast -from artiq.transforms.tools import is_replaceable +from artiq.transforms.tools import is_ref_transparent class _SourceLister(ast.NodeVisitor): @@ -22,7 +22,7 @@ class _DeadCodeRemover(ast.NodeTransformer): if (not isinstance(target, ast.Name) or target.id in self.kept_targets): new_targets.append(target) - if not new_targets and is_replaceable(node.value): + if not new_targets and is_ref_transparent(node.value)[0]: return None else: return node @@ -30,7 +30,7 @@ class _DeadCodeRemover(ast.NodeTransformer): def visit_AugAssign(self, node): if (isinstance(node.target, ast.Name) and node.target.id not in self.kept_targets - and is_replaceable(node.value)): + and is_ref_transparent(node.value)[0]): return None else: return node diff --git a/artiq/transforms/remove_inter_assigns.py b/artiq/transforms/remove_inter_assigns.py index 2e77339f8..bcc6e2fcd 100644 --- a/artiq/transforms/remove_inter_assigns.py +++ b/artiq/transforms/remove_inter_assigns.py @@ -1,7 +1,8 @@ import ast from copy import copy, deepcopy +from collections import defaultdict -from artiq.transforms.tools import is_replaceable +from artiq.transforms.tools import is_ref_transparent class _TargetLister(ast.NodeVisitor): @@ -17,30 +18,42 @@ class _InterAssignRemover(ast.NodeTransformer): def __init__(self): self.replacements = dict() self.modified_names = set() + # name -> set of names that depend on it + # i.e. when x is modified, dependencies[x] is the set of names that + # cannot be replaced anymore + self.dependencies = defaultdict(set) + + def invalidate(self, name): + try: + del self.replacements[name] + except KeyError: + pass + for d in self.dependencies[name]: + self.invalidate(d) + del self.dependencies[name] def visit_Name(self, node): if isinstance(node.ctx, ast.Load): try: - return self.replacements[node.id] + return deepcopy(self.replacements[node.id]) except KeyError: return node else: self.modified_names.add(node.id) + self.invalidate(node.id) return node def visit_Assign(self, node): - self.generic_visit(node) - if is_replaceable(node.value): + node.value = self.visit(node.value) + node.targets = [self.visit(target) for target in node.targets] + rt, depends_on = is_ref_transparent(node.value) + if rt: for target in node.targets: if isinstance(target, ast.Name): - self.replacements[target.id] = node.value - else: - for target in node.targets: - if isinstance(target, ast.Name): - try: - del self.replacements[target.id] - except KeyError: - pass + if target.id not in depends_on: + self.replacements[target.id] = node.value + for d in depends_on: + self.dependencies[d].add(target.id) return node def visit_AugAssign(self, node): @@ -62,10 +75,7 @@ class _InterAssignRemover(ast.NodeTransformer): def modified_names_pop(self, prev_modified_names): for name in self.modified_names: - try: - del self.replacements[name] - except KeyError: - pass + self.invalidate(name) self.modified_names |= prev_modified_names def visit_Try(self, node): @@ -99,10 +109,7 @@ class _InterAssignRemover(ast.NodeTransformer): for n in node.body: tl.visit(n) for name in tl.targets: - try: - del self.replacements[name] - except KeyError: - pass + self.invalidate(name) node.body = [self.visit(n) for n in node.body] self.replacements = copy(prev_replacements) diff --git a/artiq/transforms/tools.py b/artiq/transforms/tools.py index 8a7b4d08e..45a4bd2c0 100644 --- a/artiq/transforms/tools.py +++ b/artiq/transforms/tools.py @@ -84,18 +84,34 @@ def eval_constant(node): _replaceable_funcs = { "bool", "int", "float", "round", "int64", "round64", "Fraction", + "time_to_cycles", "cycles_to_time", "Quantity" } -def is_replaceable(expr): +def _is_ref_transparent(dependencies, expr): if isinstance(expr, (ast.NameConstant, ast.Num, ast.Str)): return True + elif isinstance(expr, ast.Name): + dependencies.add(expr.id) + return True + elif isinstance(expr, ast.UnaryOp): + return _is_ref_transparent(dependencies, expr.operand) elif isinstance(expr, ast.BinOp): - return is_replaceable(expr.left) and is_replaceable(expr.right) + return (_is_ref_transparent(dependencies, expr.left) + and _is_ref_transparent(dependencies, expr.right)) elif isinstance(expr, ast.BoolOp): - return all(is_replaceable(v) for v in expr.values) + return all(_is_ref_transparent(dependencies, v) for v in expr.values) elif isinstance(expr, ast.Call) and isinstance(expr.func, ast.Name): - return expr.func.id in _replaceable_funcs + return (expr.func.id in _replaceable_funcs and + all(_is_ref_transparent(dependencies, arg) for arg in expr.args)) else: return False + + +def is_ref_transparent(expr): + dependencies = set() + if _is_ref_transparent(dependencies, expr): + return True, dependencies + else: + return False, None