transforms/remove_inter_assign: support names and dependencies

This commit is contained in:
Sebastien Bourdeauducq 2014-11-03 11:35:54 +08:00
parent 217fe8251b
commit fba72cc0a2
3 changed files with 50 additions and 27 deletions

View File

@ -1,6 +1,6 @@
import ast
from artiq.transforms.tools import is_replaceable
from artiq.transforms.tools import is_ref_transparent
class _SourceLister(ast.NodeVisitor):
@ -22,7 +22,7 @@ class _DeadCodeRemover(ast.NodeTransformer):
if (not isinstance(target, ast.Name)
or target.id in self.kept_targets):
new_targets.append(target)
if not new_targets and is_replaceable(node.value):
if not new_targets and is_ref_transparent(node.value)[0]:
return None
else:
return node
@ -30,7 +30,7 @@ class _DeadCodeRemover(ast.NodeTransformer):
def visit_AugAssign(self, node):
if (isinstance(node.target, ast.Name)
and node.target.id not in self.kept_targets
and is_replaceable(node.value)):
and is_ref_transparent(node.value)[0]):
return None
else:
return node

View File

@ -1,7 +1,8 @@
import ast
from copy import copy, deepcopy
from collections import defaultdict
from artiq.transforms.tools import is_replaceable
from artiq.transforms.tools import is_ref_transparent
class _TargetLister(ast.NodeVisitor):
@ -17,30 +18,42 @@ class _InterAssignRemover(ast.NodeTransformer):
def __init__(self):
self.replacements = dict()
self.modified_names = set()
# name -> set of names that depend on it
# i.e. when x is modified, dependencies[x] is the set of names that
# cannot be replaced anymore
self.dependencies = defaultdict(set)
def invalidate(self, name):
try:
del self.replacements[name]
except KeyError:
pass
for d in self.dependencies[name]:
self.invalidate(d)
del self.dependencies[name]
def visit_Name(self, node):
if isinstance(node.ctx, ast.Load):
try:
return self.replacements[node.id]
return deepcopy(self.replacements[node.id])
except KeyError:
return node
else:
self.modified_names.add(node.id)
self.invalidate(node.id)
return node
def visit_Assign(self, node):
self.generic_visit(node)
if is_replaceable(node.value):
node.value = self.visit(node.value)
node.targets = [self.visit(target) for target in node.targets]
rt, depends_on = is_ref_transparent(node.value)
if rt:
for target in node.targets:
if isinstance(target, ast.Name):
if target.id not in depends_on:
self.replacements[target.id] = node.value
else:
for target in node.targets:
if isinstance(target, ast.Name):
try:
del self.replacements[target.id]
except KeyError:
pass
for d in depends_on:
self.dependencies[d].add(target.id)
return node
def visit_AugAssign(self, node):
@ -62,10 +75,7 @@ class _InterAssignRemover(ast.NodeTransformer):
def modified_names_pop(self, prev_modified_names):
for name in self.modified_names:
try:
del self.replacements[name]
except KeyError:
pass
self.invalidate(name)
self.modified_names |= prev_modified_names
def visit_Try(self, node):
@ -99,10 +109,7 @@ class _InterAssignRemover(ast.NodeTransformer):
for n in node.body:
tl.visit(n)
for name in tl.targets:
try:
del self.replacements[name]
except KeyError:
pass
self.invalidate(name)
node.body = [self.visit(n) for n in node.body]
self.replacements = copy(prev_replacements)

View File

@ -84,18 +84,34 @@ def eval_constant(node):
_replaceable_funcs = {
"bool", "int", "float", "round",
"int64", "round64", "Fraction",
"time_to_cycles", "cycles_to_time",
"Quantity"
}
def is_replaceable(expr):
def _is_ref_transparent(dependencies, expr):
if isinstance(expr, (ast.NameConstant, ast.Num, ast.Str)):
return True
elif isinstance(expr, ast.Name):
dependencies.add(expr.id)
return True
elif isinstance(expr, ast.UnaryOp):
return _is_ref_transparent(dependencies, expr.operand)
elif isinstance(expr, ast.BinOp):
return is_replaceable(expr.left) and is_replaceable(expr.right)
return (_is_ref_transparent(dependencies, expr.left)
and _is_ref_transparent(dependencies, expr.right))
elif isinstance(expr, ast.BoolOp):
return all(is_replaceable(v) for v in expr.values)
return all(_is_ref_transparent(dependencies, v) for v in expr.values)
elif isinstance(expr, ast.Call) and isinstance(expr.func, ast.Name):
return expr.func.id in _replaceable_funcs
return (expr.func.id in _replaceable_funcs and
all(_is_ref_transparent(dependencies, arg) for arg in expr.args))
else:
return False
def is_ref_transparent(expr):
dependencies = set()
if _is_ref_transparent(dependencies, expr):
return True, dependencies
else:
return False, None