From b10d1bdd378a9a9f904b97c1434c36fcc619f112 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Wed, 7 Jul 2021 09:20:41 +0800 Subject: [PATCH] compiler: proper union find The find implementation was not very optimized, and the unify function did not consider tree height and may build some tall trees. --- artiq/compiler/types.py | 41 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/artiq/compiler/types.py b/artiq/compiler/types.py index e7b68a3a4..78364101b 100644 --- a/artiq/compiler/types.py +++ b/artiq/compiler/types.py @@ -55,40 +55,39 @@ class TVar(Type): def __init__(self): self.parent = self + self.rank = 0 def find(self): - if self.parent is self: + parent = self.parent + if parent is self: return self else: # The recursive find() invocation is turned into a loop # because paths resulting from unification of large arrays # can easily cause a stack overflow. root = self - while root.__class__ == TVar: - if root is root.parent: - break - else: - root = root.parent - - # path compression - iter = self - while iter.__class__ == TVar: - if iter is root: - break - else: - iter, iter.parent = iter.parent, root - - return root + while parent.__class__ == TVar and root is not parent: + _, parent = root, root.parent = parent, parent.parent + return root.parent def unify(self, other): if other is self: return - other = other.find() - - if self.parent is self: - self.parent = other + x = other.find() + y = self.find() + if x is y: + return + if y.__class__ == TVar: + if x.__class__ == TVar: + if x.rank < y.rank: + x, y = y, x + y.parent = x + if x.rank == y.rank: + x.rank += 1 + else: + y.parent = x else: - self.find().unify(other) + y.unify(x) def fold(self, accum, fn): if self.parent is self: