2
0
mirror of https://github.com/m-labs/artiq.git synced 2025-01-14 21:08:55 +08:00
artiq/artiq/transforms/remove_inter_assigns.py

149 lines
4.9 KiB
Python

import ast
from copy import copy, deepcopy
from collections import defaultdict
from artiq.transforms.tools import is_ref_transparent, count_all_nodes
class _TargetLister(ast.NodeVisitor):
def __init__(self):
self.targets = set()
def visit_Name(self, node):
if isinstance(node.ctx, ast.Store):
self.targets.add(node.id)
class _InterAssignRemover(ast.NodeTransformer):
def __init__(self):
self.replacements = dict()
self.modified_names = set()
# name -> set of names that depend on it
# i.e. when x is modified, dependencies[x] is the set of names that
# cannot be replaced anymore
self.dependencies = defaultdict(set)
def invalidate(self, name):
try:
del self.replacements[name]
except KeyError:
pass
for d in self.dependencies[name]:
self.invalidate(d)
del self.dependencies[name]
def visit_Name(self, node):
if isinstance(node.ctx, ast.Load):
try:
return deepcopy(self.replacements[node.id])
except KeyError:
return node
else:
self.modified_names.add(node.id)
self.invalidate(node.id)
return node
def visit_Assign(self, node):
node.value = self.visit(node.value)
node.targets = [self.visit(target) for target in node.targets]
rt, depends_on = is_ref_transparent(node.value)
if rt and count_all_nodes(node.value) < 100:
for target in node.targets:
if isinstance(target, ast.Name):
if target.id not in depends_on:
self.replacements[target.id] = node.value
for d in depends_on:
self.dependencies[d].add(target.id)
return node
def visit_AugAssign(self, node):
left = deepcopy(node.target)
left.ctx = ast.Load()
newnode = ast.copy_location(
ast.Assign(
targets=[node.target],
value=ast.BinOp(left=left, op=node.op, right=node.value)
),
node
)
return self.visit_Assign(newnode)
def modified_names_push(self):
prev_modified_names = self.modified_names
self.modified_names = set()
return prev_modified_names
def modified_names_pop(self, prev_modified_names):
for name in self.modified_names:
self.invalidate(name)
self.modified_names |= prev_modified_names
def visit_Try(self, node):
prev_modified_names = self.modified_names_push()
node.body = [self.visit(stmt) for stmt in node.body]
self.modified_names_pop(prev_modified_names)
prev_modified_names = self.modified_names_push()
prev_replacements = self.replacements
for handler in node.handlers:
self.replacements = copy(prev_replacements)
handler.body = [self.visit(stmt) for stmt in handler.body]
self.replacements = copy(prev_replacements)
node.orelse = [self.visit(stmt) for stmt in node.orelse]
self.modified_names_pop(prev_modified_names)
prev_modified_names = self.modified_names_push()
node.finalbody = [self.visit(stmt) for stmt in node.finalbody]
self.modified_names_pop(prev_modified_names)
return node
def visit_If(self, node):
node.test = self.visit(node.test)
prev_modified_names = self.modified_names_push()
prev_replacements = self.replacements
self.replacements = copy(prev_replacements)
node.body = [self.visit(n) for n in node.body]
self.replacements = copy(prev_replacements)
node.orelse = [self.visit(n) for n in node.orelse]
self.replacements = prev_replacements
self.modified_names_pop(prev_modified_names)
return node
def visit_loop(self, node):
prev_modified_names = self.modified_names_push()
prev_replacements = self.replacements
self.replacements = copy(prev_replacements)
tl = _TargetLister()
for n in node.body:
tl.visit(n)
for name in tl.targets:
self.invalidate(name)
node.body = [self.visit(n) for n in node.body]
self.replacements = copy(prev_replacements)
node.orelse = [self.visit(n) for n in node.orelse]
self.replacements = prev_replacements
self.modified_names_pop(prev_modified_names)
def visit_For(self, node):
prev_modified_names = self.modified_names_push()
node.target = self.visit(node.target)
self.modified_names_pop(prev_modified_names)
node.iter = self.visit(node.iter)
self.visit_loop(node)
return node
def visit_While(self, node):
self.visit_loop(node)
node.test = self.visit(node.test)
return node
def remove_inter_assigns(func_def):
_InterAssignRemover().visit(func_def)