forked from M-Labs/artiq
transforms/remove_inter_assign: support names and dependencies
This commit is contained in:
parent
217fe8251b
commit
fba72cc0a2
|
@ -1,6 +1,6 @@
|
|||
import ast
|
||||
|
||||
from artiq.transforms.tools import is_replaceable
|
||||
from artiq.transforms.tools import is_ref_transparent
|
||||
|
||||
|
||||
class _SourceLister(ast.NodeVisitor):
|
||||
|
@ -22,7 +22,7 @@ class _DeadCodeRemover(ast.NodeTransformer):
|
|||
if (not isinstance(target, ast.Name)
|
||||
or target.id in self.kept_targets):
|
||||
new_targets.append(target)
|
||||
if not new_targets and is_replaceable(node.value):
|
||||
if not new_targets and is_ref_transparent(node.value)[0]:
|
||||
return None
|
||||
else:
|
||||
return node
|
||||
|
@ -30,7 +30,7 @@ class _DeadCodeRemover(ast.NodeTransformer):
|
|||
def visit_AugAssign(self, node):
|
||||
if (isinstance(node.target, ast.Name)
|
||||
and node.target.id not in self.kept_targets
|
||||
and is_replaceable(node.value)):
|
||||
and is_ref_transparent(node.value)[0]):
|
||||
return None
|
||||
else:
|
||||
return node
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import ast
|
||||
from copy import copy, deepcopy
|
||||
from collections import defaultdict
|
||||
|
||||
from artiq.transforms.tools import is_replaceable
|
||||
from artiq.transforms.tools import is_ref_transparent
|
||||
|
||||
|
||||
class _TargetLister(ast.NodeVisitor):
|
||||
|
@ -17,30 +18,42 @@ class _InterAssignRemover(ast.NodeTransformer):
|
|||
def __init__(self):
|
||||
self.replacements = dict()
|
||||
self.modified_names = set()
|
||||
# name -> set of names that depend on it
|
||||
# i.e. when x is modified, dependencies[x] is the set of names that
|
||||
# cannot be replaced anymore
|
||||
self.dependencies = defaultdict(set)
|
||||
|
||||
def invalidate(self, name):
|
||||
try:
|
||||
del self.replacements[name]
|
||||
except KeyError:
|
||||
pass
|
||||
for d in self.dependencies[name]:
|
||||
self.invalidate(d)
|
||||
del self.dependencies[name]
|
||||
|
||||
def visit_Name(self, node):
|
||||
if isinstance(node.ctx, ast.Load):
|
||||
try:
|
||||
return self.replacements[node.id]
|
||||
return deepcopy(self.replacements[node.id])
|
||||
except KeyError:
|
||||
return node
|
||||
else:
|
||||
self.modified_names.add(node.id)
|
||||
self.invalidate(node.id)
|
||||
return node
|
||||
|
||||
def visit_Assign(self, node):
|
||||
self.generic_visit(node)
|
||||
if is_replaceable(node.value):
|
||||
node.value = self.visit(node.value)
|
||||
node.targets = [self.visit(target) for target in node.targets]
|
||||
rt, depends_on = is_ref_transparent(node.value)
|
||||
if rt:
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
self.replacements[target.id] = node.value
|
||||
else:
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
try:
|
||||
del self.replacements[target.id]
|
||||
except KeyError:
|
||||
pass
|
||||
if target.id not in depends_on:
|
||||
self.replacements[target.id] = node.value
|
||||
for d in depends_on:
|
||||
self.dependencies[d].add(target.id)
|
||||
return node
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
|
@ -62,10 +75,7 @@ class _InterAssignRemover(ast.NodeTransformer):
|
|||
|
||||
def modified_names_pop(self, prev_modified_names):
|
||||
for name in self.modified_names:
|
||||
try:
|
||||
del self.replacements[name]
|
||||
except KeyError:
|
||||
pass
|
||||
self.invalidate(name)
|
||||
self.modified_names |= prev_modified_names
|
||||
|
||||
def visit_Try(self, node):
|
||||
|
@ -99,10 +109,7 @@ class _InterAssignRemover(ast.NodeTransformer):
|
|||
for n in node.body:
|
||||
tl.visit(n)
|
||||
for name in tl.targets:
|
||||
try:
|
||||
del self.replacements[name]
|
||||
except KeyError:
|
||||
pass
|
||||
self.invalidate(name)
|
||||
node.body = [self.visit(n) for n in node.body]
|
||||
|
||||
self.replacements = copy(prev_replacements)
|
||||
|
|
|
@ -84,18 +84,34 @@ def eval_constant(node):
|
|||
_replaceable_funcs = {
|
||||
"bool", "int", "float", "round",
|
||||
"int64", "round64", "Fraction",
|
||||
"time_to_cycles", "cycles_to_time",
|
||||
"Quantity"
|
||||
}
|
||||
|
||||
|
||||
def is_replaceable(expr):
|
||||
def _is_ref_transparent(dependencies, expr):
|
||||
if isinstance(expr, (ast.NameConstant, ast.Num, ast.Str)):
|
||||
return True
|
||||
elif isinstance(expr, ast.Name):
|
||||
dependencies.add(expr.id)
|
||||
return True
|
||||
elif isinstance(expr, ast.UnaryOp):
|
||||
return _is_ref_transparent(dependencies, expr.operand)
|
||||
elif isinstance(expr, ast.BinOp):
|
||||
return is_replaceable(expr.left) and is_replaceable(expr.right)
|
||||
return (_is_ref_transparent(dependencies, expr.left)
|
||||
and _is_ref_transparent(dependencies, expr.right))
|
||||
elif isinstance(expr, ast.BoolOp):
|
||||
return all(is_replaceable(v) for v in expr.values)
|
||||
return all(_is_ref_transparent(dependencies, v) for v in expr.values)
|
||||
elif isinstance(expr, ast.Call) and isinstance(expr.func, ast.Name):
|
||||
return expr.func.id in _replaceable_funcs
|
||||
return (expr.func.id in _replaceable_funcs and
|
||||
all(_is_ref_transparent(dependencies, arg) for arg in expr.args))
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_ref_transparent(expr):
|
||||
dependencies = set()
|
||||
if _is_ref_transparent(dependencies, expr):
|
||||
return True, dependencies
|
||||
else:
|
||||
return False, None
|
||||
|
|
Loading…
Reference in New Issue