Ensure type comparisons see through type variables.

This commit is contained in:
whitequark 2015-07-16 14:59:05 +03:00
parent c1e7a82e97
commit 6cda67c0c6
1 changed files with 11 additions and 4 deletions

View File

@ -26,6 +26,13 @@ class UnificationError(Exception):
def __init__(self, typea, typeb): def __init__(self, typea, typeb):
self.typea, self.typeb = typea, typeb self.typea, self.typeb = typea, typeb
def _map_find(elts):
if isinstance(elts, list):
return [x.find() for x in elts]
elif isinstance(elts, dict):
return {k: elts[k].find() for k in elts}
else:
assert False
class Type(object): class Type(object):
pass pass
@ -113,7 +120,7 @@ class TMono(Type):
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, TMono) and \ return isinstance(other, TMono) and \
self.name == other.name and \ self.name == other.name and \
self.params == other.params _map_find(self.params) == _map_find(other.params)
def __ne__(self, other): def __ne__(self, other):
return not (self == other) return not (self == other)
@ -152,7 +159,7 @@ class TTuple(Type):
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, TTuple) and \ return isinstance(other, TTuple) and \
self.elts == other.elts _map_find(self.elts) == _map_find(other.elts)
def __ne__(self, other): def __ne__(self, other):
return not (self == other) return not (self == other)
@ -207,8 +214,8 @@ class TFunction(Type):
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, TFunction) and \ return isinstance(other, TFunction) and \
self.args == other.args and \ _map_find(self.args) == _map_find(other.args) and \
self.optargs == other.optargs _map_find(self.optargs) == _map_find(other.optargs)
def __ne__(self, other): def __ne__(self, other):
return not (self == other) return not (self == other)