compiler: proper union find

The find implementation was not very optimized, and the unify function
did not consider tree height and may build some tall trees.
This commit is contained in:
pca006132 2021-07-07 09:20:41 +08:00 committed by Sébastien Bourdeauducq
parent 4ede58e44b
commit b10d1bdd37
1 changed files with 20 additions and 21 deletions

View File

@ -55,40 +55,39 @@ class TVar(Type):
def __init__(self): def __init__(self):
self.parent = self self.parent = self
self.rank = 0
def find(self): def find(self):
if self.parent is self: parent = self.parent
if parent is self:
return self return self
else: else:
# The recursive find() invocation is turned into a loop # The recursive find() invocation is turned into a loop
# because paths resulting from unification of large arrays # because paths resulting from unification of large arrays
# can easily cause a stack overflow. # can easily cause a stack overflow.
root = self root = self
while root.__class__ == TVar: while parent.__class__ == TVar and root is not parent:
if root is root.parent: _, parent = root, root.parent = parent, parent.parent
break return root.parent
else:
root = root.parent
# path compression
iter = self
while iter.__class__ == TVar:
if iter is root:
break
else:
iter, iter.parent = iter.parent, root
return root
def unify(self, other): def unify(self, other):
if other is self: if other is self:
return return
other = other.find() x = other.find()
y = self.find()
if self.parent is self: if x is y:
self.parent = other return
if y.__class__ == TVar:
if x.__class__ == TVar:
if x.rank < y.rank:
x, y = y, x
y.parent = x
if x.rank == y.rank:
x.rank += 1
else: else:
self.find().unify(other) y.parent = x
else:
y.unify(x)
def fold(self, accum, fn): def fold(self, accum, fn):
if self.parent is self: if self.parent is self: