diff --git a/hm-inference/nac3_types.py b/hm-inference/nac3_types.py index 8bd96bf..9ef519f 100644 --- a/hm-inference/nac3_types.py +++ b/hm-inference/nac3_types.py @@ -125,11 +125,7 @@ class TVar(Type): if k not in x.fields: raise UnificationError( f'Cannot unify {y} with {x}') - if isinstance(v, TFunc) and not v.instantiated: - v = v.instantiate() u = x.fields[k] - if isinstance(u, TFunc) and not u.instantiated: - u = u.instantiate() v.unify(u) if isinstance(x, TList): if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.LIST]: @@ -162,63 +158,69 @@ class FuncArg: class TCall(Type): def __init__(self, posargs: List[Type], kwargs: Dict[str, Type], ret: Type): - self.posargs = posargs - self.kwargs = kwargs - self.ret = ret - self.fun = TVar() + self.calls = [[posargs, kwargs, ret, None]] + self.parent = self + self.rank = 0 def check(self): - self.fun.find().check() + self.calls[0][3].check() def find(self): - if isinstance(self.fun.find(), TVar): - return self - return self.fun.find() + root = self + parent = self.parent + while root is not parent and isinstance(parent, TCall): + _, parent = root, root.parent = parent, parent.parent + if parent.calls[0][3] is None: + return parent + return parent.calls[0][3] def unify(self, other): - if not isinstance(self.fun.find(), TVar): - self.fun.unify(other) + y = self.find() + if y is not self: + y.unify(other) return - other = other.find() - if other is self: + x = other.find() + if x is y: return - if isinstance(other, TCall): - for a, b in zip(self.posargs, other.posargs): - a.unify(b) - for k, v in self.kwargs.items(): - if k in other.kwargs: - other.kwargs[k].unify(v) - else: - other.kwargs[k] = v - for k, v in other.kwargs.items(): - if k not in self.kwargs: - self.kwargs[k] = v - self.fun.unify(other.fun) - elif isinstance(other, TFunc): - all_args = set(arg.name for arg in other.args) - required = set(arg.name for arg in other.args if not - arg.is_optional) - other.ret.unify(self.ret) - for i, v in enumerate(self.posargs): - arg = other.args[i] - arg.typ.unify(v) - if arg.name in required: - required.remove(arg.name) - for k, v in self.kwargs.items(): - arg = next((arg for arg in other.args if arg.name == k), None) - if arg is None: - raise UnificationError(f'Unknown kwarg {k}') - if k not in all_args: - raise UnificationError(f'Duplicated kwarg {k}') - arg.typ.unify(v) - if k in required: - required.remove(k) - all_args.remove(k) - if len(required) > 0: - raise UnificationError(f'Missing arguments') - self.fun.unify(other) + if isinstance(x, TCall): + # standard union find + if x.rank < y.rank: + x, y = y, x + y.parent = x + if x.rank == y.rank: + x.rank += 1 + x.calls += y.calls + elif isinstance(x, TFunc): + fn = x + for i in range(len(y.calls)): + posargs, kwargs, ret, _ = y.calls[i] + c = y.calls[i] + c[3] = fn + if not x.instantiated: + fn = x.instantiate() + all_args = set(arg.name for arg in fn.args) + required = set(arg.name for arg in fn.args if not + arg.is_optional) + fn.ret.unify(ret) + for i, v in enumerate(posargs): + arg = fn.args[i] + arg.typ.unify(v) + if arg.name in required: + required.remove(arg.name) + for k, v in kwargs.items(): + arg = next((arg for arg in fn.args if arg.name == k), None) + if arg is None: + raise UnificationError(f'Unknown kwarg {k}') + if k not in all_args: + raise UnificationError(f'Duplicated kwarg {k}') + arg.typ.unify(v) + if k in required: + required.remove(k) + all_args.remove(k) + if len(required) > 0: + raise UnificationError(f'Missing arguments') elif isinstance(other, TVar): other.unify(self) else: diff --git a/hm-inference/test.py b/hm-inference/test.py index 9763537..b8d49a9 100644 --- a/hm-inference/test.py +++ b/hm-inference/test.py @@ -5,7 +5,7 @@ from nac3_types import * var = TVar([TInt, TBool]) -var2 = TVar([TInt]) +var2 = TVar([TInt, TBool]) foo = TObj('Foo', { 'foo': TFunc([ @@ -26,14 +26,13 @@ for key, value in v.assignments.items(): print('-----------') src = """ -# a = Foo(1).foo(1, 2) -# b = Foo(1).foo(1, True) -# c = Foo(True).foo(True, 1) -# d = Foo(True).foo(True, True) +a = f.foo(1, 2) +b = f.foo(1, True) +c = g.foo(True, 1) +d = g.foo(True, True) -a = Foo(1).foo(1, 2) -c = y.foo(True, 1) -y = Foo(True) +f = Foo(1) +g = Foo(1) """ print(src)