forked from M-Labs/artiq
1
0
Fork 0

compiler: proper union find

The find implementation was not very optimized, and the unify function
did not consider tree height and may build some tall trees.
This commit is contained in:
pca006132 2021-07-07 09:20:41 +08:00 committed by Sébastien Bourdeauducq
parent 4ede58e44b
commit b10d1bdd37
1 changed files with 20 additions and 21 deletions

View File

@ -55,40 +55,39 @@ class TVar(Type):
def __init__(self):
self.parent = self
self.rank = 0
def find(self):
if self.parent is self:
parent = self.parent
if parent is self:
return self
else:
# The recursive find() invocation is turned into a loop
# because paths resulting from unification of large arrays
# can easily cause a stack overflow.
root = self
while root.__class__ == TVar:
if root is root.parent:
break
else:
root = root.parent
# path compression
iter = self
while iter.__class__ == TVar:
if iter is root:
break
else:
iter, iter.parent = iter.parent, root
return root
while parent.__class__ == TVar and root is not parent:
_, parent = root, root.parent = parent, parent.parent
return root.parent
def unify(self, other):
if other is self:
return
other = other.find()
if self.parent is self:
self.parent = other
x = other.find()
y = self.find()
if x is y:
return
if y.__class__ == TVar:
if x.__class__ == TVar:
if x.rank < y.rank:
x, y = y, x
y.parent = x
if x.rank == y.rank:
x.rank += 1
else:
self.find().unify(other)
y.parent = x
else:
y.unify(x)
def fold(self, accum, fn):
if self.parent is self: