fixed type var check

This commit is contained in:
pca006132 2021-07-09 16:06:06 +08:00
parent 59628cfa38
commit 1c17ed003e
3 changed files with 22 additions and 11 deletions

View File

@ -57,7 +57,7 @@ class Visitor(ast.NodeVisitor):
{keyword.arg: keyword.value.type for keyword {keyword.arg: keyword.value.type for keyword
in node.keywords}, node.type) in node.keywords}, node.type)
node.func.type.unify(call) node.func.type.unify(call)
self.calls.append(node.func.type) self.calls.append(call)
def visit_Attribute(self, node): def visit_Attribute(self, node):
self.visit(node.value) self.visit(node.value)
@ -158,3 +158,4 @@ class Visitor(ast.NodeVisitor):
node.test.type.unify(TBool) node.test.type.unify(TBool)
for stmt in chain(node.body, node.orelse): for stmt in chain(node.body, node.orelse):
self.visit(stmt) self.visit(stmt)

View File

@ -63,7 +63,7 @@ class TVar(Type):
if self.range is not None: if self.range is not None:
ty = self.find() ty = self.find()
# maybe we should replace this with explicit eq # maybe we should replace this with explicit eq
if ty not in self.range: if ty is not self and ty not in self.range:
raise UnificationError( raise UnificationError(
f'{self.id} cannot be substituted by {ty}') f'{self.id} cannot be substituted by {ty}')
@ -125,10 +125,10 @@ class TVar(Type):
if k not in x.fields: if k not in x.fields:
raise UnificationError( raise UnificationError(
f'Cannot unify {y} with {x}') f'Cannot unify {y} with {x}')
if isinstance(v, TFunc): if isinstance(v, TFunc) and not v.instantiated:
v = v.instantiate() v = v.instantiate()
u = x.fields[k] u = x.fields[k]
if isinstance(u, TFunc): if isinstance(u, TFunc) and not u.instantiated:
u = u.instantiate() u = u.instantiate()
v.unify(u) v.unify(u)
if isinstance(x, TList): if isinstance(x, TList):
@ -168,8 +168,7 @@ class TCall(Type):
self.fun = TVar() self.fun = TVar()
def check(self): def check(self):
for arg in chain(self.posargs, self.kwargs.values()): self.fun.find().check()
arg.check()
def find(self): def find(self):
if isinstance(self.fun.find(), TVar): if isinstance(self.fun.find(), TVar):
@ -327,6 +326,9 @@ class TList(Type):
def __init__(self, param: Type): def __init__(self, param: Type):
self.param = param self.param = param
def check(self):
self.param.check()
def unify(self, other): def unify(self, other):
other = other.find() other = other.find()
if isinstance(other, TVar): if isinstance(other, TVar):
@ -344,6 +346,10 @@ class TTuple(Type):
def __init__(self, params: List[Type]): def __init__(self, params: List[Type]):
self.params = params self.params = params
def check(self):
for p in self.params:
p.check()
def unify(self, other): def unify(self, other):
other = other.find() other = other.find()
if isinstance(other, TVar): if isinstance(other, TVar):

View File

@ -1,11 +1,11 @@
# from __future__ import annotations from __future__ import annotations
import ast import ast
from ast_visitor import Visitor from ast_visitor import Visitor
from nac3_types import * from nac3_types import *
var = TVar([TInt, TBool]) var = TVar([TInt, TBool])
var2 = TVar() var2 = TVar([TInt])
foo = TObj('Foo', { foo = TObj('Foo', {
'foo': TFunc([ 'foo': TFunc([
@ -26,10 +26,14 @@ for key, value in v.assignments.items():
print('-----------') print('-----------')
src = """ src = """
# a = Foo(1).foo(1, 2)
# b = Foo(1).foo(1, True)
# c = Foo(True).foo(True, 1)
# d = Foo(True).foo(True, True)
a = Foo(1).foo(1, 2) a = Foo(1).foo(1, 2)
b = Foo(1).foo(1, True) c = y.foo(True, 1)
c = Foo(True).foo(True, 1) y = Foo(True)
d = Foo(True).foo(True, True)
""" """
print(src) print(src)