diff --git a/hm-inference/ast_visitor.py b/hm-inference/ast_visitor.py index b1b6cad..bd453b1 100644 --- a/hm-inference/ast_visitor.py +++ b/hm-inference/ast_visitor.py @@ -57,7 +57,7 @@ class Visitor(ast.NodeVisitor): {keyword.arg: keyword.value.type for keyword in node.keywords}, node.type) node.func.type.unify(call) - self.calls.append(node.func.type) + self.calls.append(call) def visit_Attribute(self, node): self.visit(node.value) @@ -158,3 +158,4 @@ class Visitor(ast.NodeVisitor): node.test.type.unify(TBool) for stmt in chain(node.body, node.orelse): self.visit(stmt) + diff --git a/hm-inference/nac3_types.py b/hm-inference/nac3_types.py index b176b78..8bd96bf 100644 --- a/hm-inference/nac3_types.py +++ b/hm-inference/nac3_types.py @@ -63,7 +63,7 @@ class TVar(Type): if self.range is not None: ty = self.find() # maybe we should replace this with explicit eq - if ty not in self.range: + if ty is not self and ty not in self.range: raise UnificationError( f'{self.id} cannot be substituted by {ty}') @@ -125,10 +125,10 @@ class TVar(Type): if k not in x.fields: raise UnificationError( f'Cannot unify {y} with {x}') - if isinstance(v, TFunc): + if isinstance(v, TFunc) and not v.instantiated: v = v.instantiate() u = x.fields[k] - if isinstance(u, TFunc): + if isinstance(u, TFunc) and not u.instantiated: u = u.instantiate() v.unify(u) if isinstance(x, TList): @@ -168,8 +168,7 @@ class TCall(Type): self.fun = TVar() def check(self): - for arg in chain(self.posargs, self.kwargs.values()): - arg.check() + self.fun.find().check() def find(self): if isinstance(self.fun.find(), TVar): @@ -327,6 +326,9 @@ class TList(Type): def __init__(self, param: Type): self.param = param + def check(self): + self.param.check() + def unify(self, other): other = other.find() if isinstance(other, TVar): @@ -344,6 +346,10 @@ class TTuple(Type): def __init__(self, params: List[Type]): self.params = params + def check(self): + for p in self.params: + p.check() + def unify(self, other): other = other.find() if isinstance(other, TVar): diff --git a/hm-inference/test.py b/hm-inference/test.py index 0e6a75b..9763537 100644 --- a/hm-inference/test.py +++ b/hm-inference/test.py @@ -1,11 +1,11 @@ -# from __future__ import annotations +from __future__ import annotations import ast from ast_visitor import Visitor from nac3_types import * var = TVar([TInt, TBool]) -var2 = TVar() +var2 = TVar([TInt]) foo = TObj('Foo', { 'foo': TFunc([ @@ -26,10 +26,14 @@ for key, value in v.assignments.items(): print('-----------') src = """ +# a = Foo(1).foo(1, 2) +# b = Foo(1).foo(1, True) +# c = Foo(True).foo(True, 1) +# d = Foo(True).foo(True, True) + a = Foo(1).foo(1, 2) -b = Foo(1).foo(1, True) -c = Foo(True).foo(True, 1) -d = Foo(True).foo(True, True) +c = y.foo(True, 1) +y = Foo(True) """ print(src)