diff --git a/artiq/coredevice/core.py b/artiq/coredevice/core.py index 0cdd1bd43..2b362f604 100644 --- a/artiq/coredevice/core.py +++ b/artiq/coredevice/core.py @@ -2,6 +2,7 @@ import os from artiq.transforms.inline import inline from artiq.transforms.lower_units import lower_units +from artiq.transforms.remove_inter_assigns import remove_inter_assigns from artiq.transforms.fold_constants import fold_constants from artiq.transforms.unroll_loops import unroll_loops from artiq.transforms.interleave import interleave @@ -51,6 +52,9 @@ class Core: lower_units(func_def, rpc_map) _debug_unparse("lower_units", func_def) + remove_inter_assigns(func_def) + _debug_unparse("remove_inter_assigns_1", func_def) + fold_constants(func_def) _debug_unparse("fold_constants_1", func_def) @@ -65,6 +69,9 @@ class Core: self.runtime_env.ref_period) _debug_unparse("lower_time", func_def) + remove_inter_assigns(func_def) + _debug_unparse("remove_inter_assigns_2", func_def) + fold_constants(func_def) _debug_unparse("fold_constants_2", func_def) diff --git a/artiq/transforms/remove_inter_assigns.py b/artiq/transforms/remove_inter_assigns.py new file mode 100644 index 000000000..4479e116c --- /dev/null +++ b/artiq/transforms/remove_inter_assigns.py @@ -0,0 +1,145 @@ +import ast +from copy import copy, deepcopy + + +_replaceable_funcs = { + "bool", "int", "float", "round", + "int64", "round64", "Fraction", +} + + +def _is_replaceable(value): + if isinstance(value, (ast.NameConstant, ast.Num, ast.Str)): + return True + elif isinstance(value, ast.BinOp): + return _is_replaceable(value.left) and _is_replaceable(value.right) + elif isinstance(value, ast.BoolOp): + return all(_is_replaceable(v) for v in value.values) + elif isinstance(value, ast.Call) and isinstance(value.func, ast.Name): + return value.func.id in _replaceable_funcs + else: + return False + + +class _TargetLister(ast.NodeVisitor): + def __init__(self): + self.targets = set() + + def visit_Name(self, node): + if isinstance(node.ctx, ast.Store): + self.targets.add(node.id) + + +class _InterAssignRemover(ast.NodeTransformer): + def __init__(self): + self.replacements = dict() + self.modified_names = set() + + def visit_Name(self, node): + if isinstance(node.ctx, ast.Load): + try: + return self.replacements[node.id] + except KeyError: + return node + else: + self.modified_names.add(node.id) + return node + + def visit_Assign(self, node): + self.generic_visit(node) + if _is_replaceable(node.value): + for target in node.targets: + if isinstance(target, ast.Name): + self.replacements[target.id] = node.value + else: + for target in node.targets: + try: + del self.replacements[target.id] + except KeyError: + pass + return node + + def visit_AugAssign(self, node): + left = deepcopy(node.target) + left.ctx = ast.Load() + newnode = ast.copy_location( + ast.Assign( + targets=[node.target], + value=ast.BinOp(left=left, op=node.op, right=node.value) + ), + node + ) + return self.visit_Assign(newnode) + + def modified_names_push(self): + prev_modified_names = self.modified_names + self.modified_names = set() + return prev_modified_names + + def modified_names_pop(self, prev_modified_names): + for name in self.modified_names: + try: + del self.replacements[name] + except KeyError: + pass + self.modified_names |= prev_modified_names + + def visit_Try(self, node): + prev_modified_names = self.modified_names_push() + self.generic_visit(node) + self.modified_names_pop(prev_modified_names) + return node + + def visit_If(self, node): + node.test = self.visit(node.test) + + prev_modified_names = self.modified_names_push() + + prev_replacements = self.replacements + self.replacements = copy(prev_replacements) + node.body = [self.visit(n) for n in node.body] + self.replacements = copy(prev_replacements) + node.orelse = [self.visit(n) for n in node.orelse] + self.replacements = prev_replacements + + self.modified_names_pop(prev_modified_names) + + return node + + def visit_loop(self, node): + prev_modified_names = self.modified_names_push() + prev_replacements = self.replacements + + self.replacements = copy(prev_replacements) + tl = _TargetLister() + for n in node.body: + tl.visit(n) + for name in tl.targets: + try: + del self.replacements[name] + except KeyError: + pass + node.body = [self.visit(n) for n in node.body] + + self.replacements = copy(prev_replacements) + node.orelse = [self.visit(n) for n in node.orelse] + + self.replacements = prev_replacements + self.modified_names_pop(prev_modified_names) + + def visit_For(self, node): + prev_modified_names = self.modified_names_push() + node.target = self.visit(node.target) + self.modified_names_pop(prev_modified_names) + node.iter = self.visit(node.iter) + self.visit_loop(node) + return node + + def visit_While(self, node): + self.visit_loop(node) + node.test = self.visit(node.test) + return node + + +def remove_inter_assigns(func_def): + _InterAssignRemover().visit(func_def)