diff --git a/artiq/compiler/types.py b/artiq/compiler/types.py index 5e69c399f..4b6dd9a55 100644 --- a/artiq/compiler/types.py +++ b/artiq/compiler/types.py @@ -26,6 +26,13 @@ class UnificationError(Exception): def __init__(self, typea, typeb): self.typea, self.typeb = typea, typeb +def _map_find(elts): + if isinstance(elts, list): + return [x.find() for x in elts] + elif isinstance(elts, dict): + return {k: elts[k].find() for k in elts} + else: + assert False class Type(object): pass @@ -113,7 +120,7 @@ class TMono(Type): def __eq__(self, other): return isinstance(other, TMono) and \ self.name == other.name and \ - self.params == other.params + _map_find(self.params) == _map_find(other.params) def __ne__(self, other): return not (self == other) @@ -152,7 +159,7 @@ class TTuple(Type): def __eq__(self, other): return isinstance(other, TTuple) and \ - self.elts == other.elts + _map_find(self.elts) == _map_find(other.elts) def __ne__(self, other): return not (self == other) @@ -207,8 +214,8 @@ class TFunction(Type): def __eq__(self, other): return isinstance(other, TFunction) and \ - self.args == other.args and \ - self.optargs == other.optargs + _map_find(self.args) == _map_find(other.args) and \ + _map_find(self.optargs) == _map_find(other.optargs) def __ne__(self, other): return not (self == other)