forked from M-Labs/artiq
compiler: proper union find
The find implementation was not very optimized, and the unify function did not consider tree height and may build some tall trees.
This commit is contained in:
parent
4ede58e44b
commit
b10d1bdd37
|
@ -55,40 +55,39 @@ class TVar(Type):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.parent = self
|
self.parent = self
|
||||||
|
self.rank = 0
|
||||||
|
|
||||||
def find(self):
|
def find(self):
|
||||||
if self.parent is self:
|
parent = self.parent
|
||||||
|
if parent is self:
|
||||||
return self
|
return self
|
||||||
else:
|
else:
|
||||||
# The recursive find() invocation is turned into a loop
|
# The recursive find() invocation is turned into a loop
|
||||||
# because paths resulting from unification of large arrays
|
# because paths resulting from unification of large arrays
|
||||||
# can easily cause a stack overflow.
|
# can easily cause a stack overflow.
|
||||||
root = self
|
root = self
|
||||||
while root.__class__ == TVar:
|
while parent.__class__ == TVar and root is not parent:
|
||||||
if root is root.parent:
|
_, parent = root, root.parent = parent, parent.parent
|
||||||
break
|
return root.parent
|
||||||
else:
|
|
||||||
root = root.parent
|
|
||||||
|
|
||||||
# path compression
|
|
||||||
iter = self
|
|
||||||
while iter.__class__ == TVar:
|
|
||||||
if iter is root:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
iter, iter.parent = iter.parent, root
|
|
||||||
|
|
||||||
return root
|
|
||||||
|
|
||||||
def unify(self, other):
|
def unify(self, other):
|
||||||
if other is self:
|
if other is self:
|
||||||
return
|
return
|
||||||
other = other.find()
|
x = other.find()
|
||||||
|
y = self.find()
|
||||||
if self.parent is self:
|
if x is y:
|
||||||
self.parent = other
|
return
|
||||||
|
if y.__class__ == TVar:
|
||||||
|
if x.__class__ == TVar:
|
||||||
|
if x.rank < y.rank:
|
||||||
|
x, y = y, x
|
||||||
|
y.parent = x
|
||||||
|
if x.rank == y.rank:
|
||||||
|
x.rank += 1
|
||||||
else:
|
else:
|
||||||
self.find().unify(other)
|
y.parent = x
|
||||||
|
else:
|
||||||
|
y.unify(x)
|
||||||
|
|
||||||
def fold(self, accum, fn):
|
def fold(self, accum, fn):
|
||||||
if self.parent is self:
|
if self.parent is self:
|
||||||
|
|
Loading…
Reference in New Issue