forked from M-Labs/artiq
1
0
Fork 0

transforms/remove_inter_assign: support names and dependencies

This commit is contained in:
Sebastien Bourdeauducq 2014-11-03 11:35:54 +08:00
parent 217fe8251b
commit fba72cc0a2
3 changed files with 50 additions and 27 deletions

View File

@ -1,6 +1,6 @@
import ast import ast
from artiq.transforms.tools import is_replaceable from artiq.transforms.tools import is_ref_transparent
class _SourceLister(ast.NodeVisitor): class _SourceLister(ast.NodeVisitor):
@ -22,7 +22,7 @@ class _DeadCodeRemover(ast.NodeTransformer):
if (not isinstance(target, ast.Name) if (not isinstance(target, ast.Name)
or target.id in self.kept_targets): or target.id in self.kept_targets):
new_targets.append(target) new_targets.append(target)
if not new_targets and is_replaceable(node.value): if not new_targets and is_ref_transparent(node.value)[0]:
return None return None
else: else:
return node return node
@ -30,7 +30,7 @@ class _DeadCodeRemover(ast.NodeTransformer):
def visit_AugAssign(self, node): def visit_AugAssign(self, node):
if (isinstance(node.target, ast.Name) if (isinstance(node.target, ast.Name)
and node.target.id not in self.kept_targets and node.target.id not in self.kept_targets
and is_replaceable(node.value)): and is_ref_transparent(node.value)[0]):
return None return None
else: else:
return node return node

View File

@ -1,7 +1,8 @@
import ast import ast
from copy import copy, deepcopy from copy import copy, deepcopy
from collections import defaultdict
from artiq.transforms.tools import is_replaceable from artiq.transforms.tools import is_ref_transparent
class _TargetLister(ast.NodeVisitor): class _TargetLister(ast.NodeVisitor):
@ -17,30 +18,42 @@ class _InterAssignRemover(ast.NodeTransformer):
def __init__(self): def __init__(self):
self.replacements = dict() self.replacements = dict()
self.modified_names = set() self.modified_names = set()
# name -> set of names that depend on it
# i.e. when x is modified, dependencies[x] is the set of names that
# cannot be replaced anymore
self.dependencies = defaultdict(set)
def invalidate(self, name):
try:
del self.replacements[name]
except KeyError:
pass
for d in self.dependencies[name]:
self.invalidate(d)
del self.dependencies[name]
def visit_Name(self, node): def visit_Name(self, node):
if isinstance(node.ctx, ast.Load): if isinstance(node.ctx, ast.Load):
try: try:
return self.replacements[node.id] return deepcopy(self.replacements[node.id])
except KeyError: except KeyError:
return node return node
else: else:
self.modified_names.add(node.id) self.modified_names.add(node.id)
self.invalidate(node.id)
return node return node
def visit_Assign(self, node): def visit_Assign(self, node):
self.generic_visit(node) node.value = self.visit(node.value)
if is_replaceable(node.value): node.targets = [self.visit(target) for target in node.targets]
rt, depends_on = is_ref_transparent(node.value)
if rt:
for target in node.targets: for target in node.targets:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
self.replacements[target.id] = node.value if target.id not in depends_on:
else: self.replacements[target.id] = node.value
for target in node.targets: for d in depends_on:
if isinstance(target, ast.Name): self.dependencies[d].add(target.id)
try:
del self.replacements[target.id]
except KeyError:
pass
return node return node
def visit_AugAssign(self, node): def visit_AugAssign(self, node):
@ -62,10 +75,7 @@ class _InterAssignRemover(ast.NodeTransformer):
def modified_names_pop(self, prev_modified_names): def modified_names_pop(self, prev_modified_names):
for name in self.modified_names: for name in self.modified_names:
try: self.invalidate(name)
del self.replacements[name]
except KeyError:
pass
self.modified_names |= prev_modified_names self.modified_names |= prev_modified_names
def visit_Try(self, node): def visit_Try(self, node):
@ -99,10 +109,7 @@ class _InterAssignRemover(ast.NodeTransformer):
for n in node.body: for n in node.body:
tl.visit(n) tl.visit(n)
for name in tl.targets: for name in tl.targets:
try: self.invalidate(name)
del self.replacements[name]
except KeyError:
pass
node.body = [self.visit(n) for n in node.body] node.body = [self.visit(n) for n in node.body]
self.replacements = copy(prev_replacements) self.replacements = copy(prev_replacements)

View File

@ -84,18 +84,34 @@ def eval_constant(node):
_replaceable_funcs = { _replaceable_funcs = {
"bool", "int", "float", "round", "bool", "int", "float", "round",
"int64", "round64", "Fraction", "int64", "round64", "Fraction",
"time_to_cycles", "cycles_to_time",
"Quantity" "Quantity"
} }
def is_replaceable(expr): def _is_ref_transparent(dependencies, expr):
if isinstance(expr, (ast.NameConstant, ast.Num, ast.Str)): if isinstance(expr, (ast.NameConstant, ast.Num, ast.Str)):
return True return True
elif isinstance(expr, ast.Name):
dependencies.add(expr.id)
return True
elif isinstance(expr, ast.UnaryOp):
return _is_ref_transparent(dependencies, expr.operand)
elif isinstance(expr, ast.BinOp): elif isinstance(expr, ast.BinOp):
return is_replaceable(expr.left) and is_replaceable(expr.right) return (_is_ref_transparent(dependencies, expr.left)
and _is_ref_transparent(dependencies, expr.right))
elif isinstance(expr, ast.BoolOp): elif isinstance(expr, ast.BoolOp):
return all(is_replaceable(v) for v in expr.values) return all(_is_ref_transparent(dependencies, v) for v in expr.values)
elif isinstance(expr, ast.Call) and isinstance(expr.func, ast.Name): elif isinstance(expr, ast.Call) and isinstance(expr.func, ast.Name):
return expr.func.id in _replaceable_funcs return (expr.func.id in _replaceable_funcs and
all(_is_ref_transparent(dependencies, arg) for arg in expr.args))
else: else:
return False return False
def is_ref_transparent(expr):
dependencies = set()
if _is_ref_transparent(dependencies, expr):
return True, dependencies
else:
return False, None