forked from M-Labs/artiq
compiler: proper union find
The find implementation was not very optimized, and the unify function did not consider tree height and may build some tall trees.
This commit is contained in:
parent
4ede58e44b
commit
b10d1bdd37
|
@ -55,40 +55,39 @@ class TVar(Type):
|
|||
|
||||
def __init__(self):
|
||||
self.parent = self
|
||||
self.rank = 0
|
||||
|
||||
def find(self):
|
||||
if self.parent is self:
|
||||
parent = self.parent
|
||||
if parent is self:
|
||||
return self
|
||||
else:
|
||||
# The recursive find() invocation is turned into a loop
|
||||
# because paths resulting from unification of large arrays
|
||||
# can easily cause a stack overflow.
|
||||
root = self
|
||||
while root.__class__ == TVar:
|
||||
if root is root.parent:
|
||||
break
|
||||
else:
|
||||
root = root.parent
|
||||
|
||||
# path compression
|
||||
iter = self
|
||||
while iter.__class__ == TVar:
|
||||
if iter is root:
|
||||
break
|
||||
else:
|
||||
iter, iter.parent = iter.parent, root
|
||||
|
||||
return root
|
||||
while parent.__class__ == TVar and root is not parent:
|
||||
_, parent = root, root.parent = parent, parent.parent
|
||||
return root.parent
|
||||
|
||||
def unify(self, other):
|
||||
if other is self:
|
||||
return
|
||||
other = other.find()
|
||||
|
||||
if self.parent is self:
|
||||
self.parent = other
|
||||
x = other.find()
|
||||
y = self.find()
|
||||
if x is y:
|
||||
return
|
||||
if y.__class__ == TVar:
|
||||
if x.__class__ == TVar:
|
||||
if x.rank < y.rank:
|
||||
x, y = y, x
|
||||
y.parent = x
|
||||
if x.rank == y.rank:
|
||||
x.rank += 1
|
||||
else:
|
||||
y.parent = x
|
||||
else:
|
||||
self.find().unify(other)
|
||||
y.unify(x)
|
||||
|
||||
def fold(self, accum, fn):
|
||||
if self.parent is self:
|
||||
|
|
Loading…
Reference in New Issue