diff --git a/artiq/compiler/ir.py b/artiq/compiler/ir.py index 8bc84e598..88ef3a151 100644 --- a/artiq/compiler/ir.py +++ b/artiq/compiler/ir.py @@ -135,6 +135,7 @@ class NamedValue(Value): def __init__(self, typ, name): super().__init__(typ) self.name, self.function = name, None + self.is_removed = False def set_name(self, new_name): if self.function is not None: @@ -235,7 +236,7 @@ class Instruction(User): self.drop_references() # Check this after drop_references in case this # is a self-referencing phi. - assert not any(self.uses) + assert all(use.is_removed for use in self.uses) def replace_with(self, value): self.replace_all_uses_with(value) @@ -370,7 +371,7 @@ class BasicBlock(NamedValue): self.remove_from_parent() # Check this after erasing instructions in case the block # loops into itself. - assert not any(self.uses) + assert all(use.is_removed for use in self.uses) def prepend(self, insn): assert isinstance(insn, Instruction) @@ -1360,14 +1361,6 @@ class LandingPad(Terminator): def cleanup(self): return self.operands[0] - def erase(self): - self.remove_from_parent() - # we should erase all clauses as well - for block in set(self.operands): - block.uses.remove(self) - block.erase() - assert not any(self.uses) - def clauses(self): return zip(self.operands[1:], self.types) diff --git a/artiq/compiler/transforms/dead_code_eliminator.py b/artiq/compiler/transforms/dead_code_eliminator.py index 608a46d55..f4862f2f9 100644 --- a/artiq/compiler/transforms/dead_code_eliminator.py +++ b/artiq/compiler/transforms/dead_code_eliminator.py @@ -15,13 +15,26 @@ class DeadCodeEliminator: self.process_function(func) def process_function(self, func): - modified = True - while modified: - modified = False - for block in list(func.basic_blocks): - if not any(block.predecessors()) and block != func.entry(): - self.remove_block(block) - modified = True + # defer removing those blocks, so our use checks will ignore deleted blocks + preserve = [func.entry()] + work_list = [func.entry()] + while any(work_list): + block = work_list.pop() + for succ in block.successors(): + if succ not in preserve: + preserve.append(succ) + work_list.append(succ) + + to_be_removed = [] + for block in func.basic_blocks: + if block not in preserve: + block.is_removed = True + to_be_removed.append(block) + for insn in block.instructions: + insn.is_removed = True + + for block in to_be_removed: + self.remove_block(block) modified = True while modified: @@ -42,6 +55,8 @@ class DeadCodeEliminator: def remove_block(self, block): # block.uses are updated while iterating for use in set(block.uses): + if use.is_removed: + continue if isinstance(use, ir.Phi): use.remove_incoming_block(block) if not any(use.operands): @@ -56,6 +71,8 @@ class DeadCodeEliminator: def remove_instruction(self, insn): for use in set(insn.uses): + if use.is_removed: + continue if isinstance(use, ir.Phi): use.remove_incoming_value(insn) if not any(use.operands):