forked from M-Labs/artiq
150 lines
4.9 KiB
Python
150 lines
4.9 KiB
Python
import ast
|
|
from copy import copy, deepcopy
|
|
from collections import defaultdict
|
|
|
|
from artiq.transforms.tools import is_ref_transparent, count_all_nodes
|
|
|
|
|
|
class _TargetLister(ast.NodeVisitor):
|
|
def __init__(self):
|
|
self.targets = set()
|
|
|
|
def visit_Name(self, node):
|
|
if isinstance(node.ctx, ast.Store):
|
|
self.targets.add(node.id)
|
|
|
|
|
|
class _InterAssignRemover(ast.NodeTransformer):
|
|
def __init__(self):
|
|
self.replacements = dict()
|
|
self.modified_names = set()
|
|
# name -> set of names that depend on it
|
|
# i.e. when x is modified, dependencies[x] is the set of names that
|
|
# cannot be replaced anymore
|
|
self.dependencies = defaultdict(set)
|
|
|
|
def invalidate(self, name):
|
|
try:
|
|
del self.replacements[name]
|
|
except KeyError:
|
|
pass
|
|
for d in self.dependencies[name]:
|
|
self.invalidate(d)
|
|
del self.dependencies[name]
|
|
|
|
def visit_Name(self, node):
|
|
if isinstance(node.ctx, ast.Load):
|
|
try:
|
|
return deepcopy(self.replacements[node.id])
|
|
except KeyError:
|
|
return node
|
|
else:
|
|
self.modified_names.add(node.id)
|
|
self.invalidate(node.id)
|
|
return node
|
|
|
|
def visit_Assign(self, node):
|
|
node.value = self.visit(node.value)
|
|
node.targets = [self.visit(target) for target in node.targets]
|
|
rt, depends_on = is_ref_transparent(node.value)
|
|
if rt and count_all_nodes(node.value) < 100:
|
|
for target in node.targets:
|
|
if isinstance(target, ast.Name):
|
|
if target.id not in depends_on:
|
|
self.replacements[target.id] = node.value
|
|
for d in depends_on:
|
|
self.dependencies[d].add(target.id)
|
|
return node
|
|
|
|
def visit_AugAssign(self, node):
|
|
left = deepcopy(node.target)
|
|
left.ctx = ast.Load()
|
|
newnode = ast.copy_location(
|
|
ast.Assign(
|
|
targets=[node.target],
|
|
value=ast.BinOp(left=left, op=node.op, right=node.value)
|
|
),
|
|
node
|
|
)
|
|
return self.visit_Assign(newnode)
|
|
|
|
def modified_names_push(self):
|
|
prev_modified_names = self.modified_names
|
|
self.modified_names = set()
|
|
return prev_modified_names
|
|
|
|
def modified_names_pop(self, prev_modified_names):
|
|
for name in self.modified_names:
|
|
self.invalidate(name)
|
|
self.modified_names |= prev_modified_names
|
|
|
|
def visit_Try(self, node):
|
|
prev_modified_names = self.modified_names_push()
|
|
node.body = [self.visit(stmt) for stmt in node.body]
|
|
self.modified_names_pop(prev_modified_names)
|
|
|
|
prev_modified_names = self.modified_names_push()
|
|
prev_replacements = self.replacements
|
|
for handler in node.handlers:
|
|
self.replacements = copy(prev_replacements)
|
|
handler.body = [self.visit(stmt) for stmt in handler.body]
|
|
self.replacements = copy(prev_replacements)
|
|
node.orelse = [self.visit(stmt) for stmt in node.orelse]
|
|
self.modified_names_pop(prev_modified_names)
|
|
|
|
prev_modified_names = self.modified_names_push()
|
|
node.finalbody = [self.visit(stmt) for stmt in node.finalbody]
|
|
self.modified_names_pop(prev_modified_names)
|
|
return node
|
|
|
|
def visit_If(self, node):
|
|
node.test = self.visit(node.test)
|
|
|
|
prev_modified_names = self.modified_names_push()
|
|
|
|
prev_replacements = self.replacements
|
|
self.replacements = copy(prev_replacements)
|
|
node.body = [self.visit(n) for n in node.body]
|
|
self.replacements = copy(prev_replacements)
|
|
node.orelse = [self.visit(n) for n in node.orelse]
|
|
self.replacements = prev_replacements
|
|
|
|
self.modified_names_pop(prev_modified_names)
|
|
|
|
return node
|
|
|
|
def visit_loop(self, node):
|
|
prev_modified_names = self.modified_names_push()
|
|
prev_replacements = self.replacements
|
|
|
|
self.replacements = copy(prev_replacements)
|
|
tl = _TargetLister()
|
|
for n in node.body:
|
|
tl.visit(n)
|
|
for name in tl.targets:
|
|
self.invalidate(name)
|
|
node.body = [self.visit(n) for n in node.body]
|
|
|
|
self.replacements = copy(prev_replacements)
|
|
node.orelse = [self.visit(n) for n in node.orelse]
|
|
|
|
self.replacements = prev_replacements
|
|
self.modified_names_pop(prev_modified_names)
|
|
|
|
def visit_For(self, node):
|
|
prev_modified_names = self.modified_names_push()
|
|
node.target = self.visit(node.target)
|
|
self.modified_names_pop(prev_modified_names)
|
|
node.iter = self.visit(node.iter)
|
|
self.visit_loop(node)
|
|
return node
|
|
|
|
def visit_While(self, node):
|
|
self.visit_loop(node)
|
|
node.test = self.visit(node.test)
|
|
return node
|
|
|
|
|
|
def remove_inter_assigns(func_def):
    """Run the inter-assignment substitution pass over ``func_def``.

    The tree is walked with a fresh ``_InterAssignRemover``; the value
    returned by the visitor is discarded, so callers observe the result
    through the in-place changes made to ``func_def``. Nothing is returned.

    NOTE(review): ``func_def`` is presumably an ``ast.FunctionDef`` -- confirm
    against the callers.
    """
    remover = _InterAssignRemover()
    remover.visit(func_def)
|