transform: add intermediate assignment removal

This commit is contained in:
Sebastien Bourdeauducq 2014-10-29 17:09:45 +08:00
parent f012151506
commit 70cc0d1766
2 changed files with 152 additions and 0 deletions

View File

@ -2,6 +2,7 @@ import os
from artiq.transforms.inline import inline from artiq.transforms.inline import inline
from artiq.transforms.lower_units import lower_units from artiq.transforms.lower_units import lower_units
from artiq.transforms.remove_inter_assigns import remove_inter_assigns
from artiq.transforms.fold_constants import fold_constants from artiq.transforms.fold_constants import fold_constants
from artiq.transforms.unroll_loops import unroll_loops from artiq.transforms.unroll_loops import unroll_loops
from artiq.transforms.interleave import interleave from artiq.transforms.interleave import interleave
@ -51,6 +52,9 @@ class Core:
lower_units(func_def, rpc_map) lower_units(func_def, rpc_map)
_debug_unparse("lower_units", func_def) _debug_unparse("lower_units", func_def)
remove_inter_assigns(func_def)
_debug_unparse("remove_inter_assigns_1", func_def)
fold_constants(func_def) fold_constants(func_def)
_debug_unparse("fold_constants_1", func_def) _debug_unparse("fold_constants_1", func_def)
@ -65,6 +69,9 @@ class Core:
self.runtime_env.ref_period) self.runtime_env.ref_period)
_debug_unparse("lower_time", func_def) _debug_unparse("lower_time", func_def)
remove_inter_assigns(func_def)
_debug_unparse("remove_inter_assigns_2", func_def)
fold_constants(func_def) fold_constants(func_def)
_debug_unparse("fold_constants_2", func_def) _debug_unparse("fold_constants_2", func_def)

View File

@ -0,0 +1,145 @@
import ast
from copy import copy, deepcopy
_replaceable_funcs = {
"bool", "int", "float", "round",
"int64", "round64", "Fraction",
}
def _is_replaceable(value):
if isinstance(value, (ast.NameConstant, ast.Num, ast.Str)):
return True
elif isinstance(value, ast.BinOp):
return _is_replaceable(value.left) and _is_replaceable(value.right)
elif isinstance(value, ast.BoolOp):
return all(_is_replaceable(v) for v in value.values)
elif isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
return value.func.id in _replaceable_funcs
else:
return False
class _TargetLister(ast.NodeVisitor):
def __init__(self):
self.targets = set()
def visit_Name(self, node):
if isinstance(node.ctx, ast.Store):
self.targets.add(node.id)
class _InterAssignRemover(ast.NodeTransformer):
def __init__(self):
self.replacements = dict()
self.modified_names = set()
def visit_Name(self, node):
if isinstance(node.ctx, ast.Load):
try:
return self.replacements[node.id]
except KeyError:
return node
else:
self.modified_names.add(node.id)
return node
def visit_Assign(self, node):
self.generic_visit(node)
if _is_replaceable(node.value):
for target in node.targets:
if isinstance(target, ast.Name):
self.replacements[target.id] = node.value
else:
for target in node.targets:
try:
del self.replacements[target.id]
except KeyError:
pass
return node
def visit_AugAssign(self, node):
left = deepcopy(node.target)
left.ctx = ast.Load()
newnode = ast.copy_location(
ast.Assign(
targets=[node.target],
value=ast.BinOp(left=left, op=node.op, right=node.value)
),
node
)
return self.visit_Assign(newnode)
def modified_names_push(self):
prev_modified_names = self.modified_names
self.modified_names = set()
return prev_modified_names
def modified_names_pop(self, prev_modified_names):
for name in self.modified_names:
try:
del self.replacements[name]
except KeyError:
pass
self.modified_names |= prev_modified_names
def visit_Try(self, node):
prev_modified_names = self.modified_names_push()
self.generic_visit(node)
self.modified_names_pop(prev_modified_names)
return node
def visit_If(self, node):
node.test = self.visit(node.test)
prev_modified_names = self.modified_names_push()
prev_replacements = self.replacements
self.replacements = copy(prev_replacements)
node.body = [self.visit(n) for n in node.body]
self.replacements = copy(prev_replacements)
node.orelse = [self.visit(n) for n in node.orelse]
self.replacements = prev_replacements
self.modified_names_pop(prev_modified_names)
return node
def visit_loop(self, node):
prev_modified_names = self.modified_names_push()
prev_replacements = self.replacements
self.replacements = copy(prev_replacements)
tl = _TargetLister()
for n in node.body:
tl.visit(n)
for name in tl.targets:
try:
del self.replacements[name]
except KeyError:
pass
node.body = [self.visit(n) for n in node.body]
self.replacements = copy(prev_replacements)
node.orelse = [self.visit(n) for n in node.orelse]
self.replacements = prev_replacements
self.modified_names_pop(prev_modified_names)
def visit_For(self, node):
prev_modified_names = self.modified_names_push()
node.target = self.visit(node.target)
self.modified_names_pop(prev_modified_names)
node.iter = self.visit(node.iter)
self.visit_loop(node)
return node
def visit_While(self, node):
self.visit_loop(node)
node.test = self.visit(node.test)
return node
def remove_inter_assigns(func_def):
_InterAssignRemover().visit(func_def)