fixed type var check

This commit is contained in:
pca006132 2021-07-09 16:06:06 +08:00
parent 59628cfa38
commit 1c17ed003e
3 changed files with 22 additions and 11 deletions

View File

@ -57,7 +57,7 @@ class Visitor(ast.NodeVisitor):
{keyword.arg: keyword.value.type for keyword
in node.keywords}, node.type)
node.func.type.unify(call)
self.calls.append(node.func.type)
self.calls.append(call)
def visit_Attribute(self, node):
self.visit(node.value)
@ -158,3 +158,4 @@ class Visitor(ast.NodeVisitor):
node.test.type.unify(TBool)
for stmt in chain(node.body, node.orelse):
self.visit(stmt)

View File

@ -63,7 +63,7 @@ class TVar(Type):
if self.range is not None:
ty = self.find()
# maybe we should replace this with explicit eq
if ty not in self.range:
if ty is not self and ty not in self.range:
raise UnificationError(
f'{self.id} cannot be substituted by {ty}')
@ -125,10 +125,10 @@ class TVar(Type):
if k not in x.fields:
raise UnificationError(
f'Cannot unify {y} with {x}')
if isinstance(v, TFunc):
if isinstance(v, TFunc) and not v.instantiated:
v = v.instantiate()
u = x.fields[k]
if isinstance(u, TFunc):
if isinstance(u, TFunc) and not u.instantiated:
u = u.instantiate()
v.unify(u)
if isinstance(x, TList):
@ -168,8 +168,7 @@ class TCall(Type):
self.fun = TVar()
def check(self):
for arg in chain(self.posargs, self.kwargs.values()):
arg.check()
self.fun.find().check()
def find(self):
if isinstance(self.fun.find(), TVar):
@ -327,6 +326,9 @@ class TList(Type):
def __init__(self, param: Type):
self.param = param
def check(self):
self.param.check()
def unify(self, other):
other = other.find()
if isinstance(other, TVar):
@ -344,6 +346,10 @@ class TTuple(Type):
def __init__(self, params: List[Type]):
self.params = params
def check(self):
for p in self.params:
p.check()
def unify(self, other):
other = other.find()
if isinstance(other, TVar):

View File

@ -1,11 +1,11 @@
# from __future__ import annotations
from __future__ import annotations
import ast
from ast_visitor import Visitor
from nac3_types import *
var = TVar([TInt, TBool])
var2 = TVar()
var2 = TVar([TInt])
foo = TObj('Foo', {
'foo': TFunc([
@ -26,10 +26,14 @@ for key, value in v.assignments.items():
print('-----------')
src = """
# a = Foo(1).foo(1, 2)
# b = Foo(1).foo(1, True)
# c = Foo(True).foo(True, 1)
# d = Foo(True).foo(True, True)
a = Foo(1).foo(1, 2)
b = Foo(1).foo(1, True)
c = Foo(True).foo(True, 1)
d = Foo(True).foo(True, True)
c = y.foo(True, 1)
y = Foo(True)
"""
print(src)