fixed type var check
This commit is contained in:
parent
59628cfa38
commit
1c17ed003e
|
@ -57,7 +57,7 @@ class Visitor(ast.NodeVisitor):
|
|||
{keyword.arg: keyword.value.type for keyword
|
||||
in node.keywords}, node.type)
|
||||
node.func.type.unify(call)
|
||||
self.calls.append(node.func.type)
|
||||
self.calls.append(call)
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
self.visit(node.value)
|
||||
|
@ -158,3 +158,4 @@ class Visitor(ast.NodeVisitor):
|
|||
node.test.type.unify(TBool)
|
||||
for stmt in chain(node.body, node.orelse):
|
||||
self.visit(stmt)
|
||||
|
||||
|
|
|
@ -63,7 +63,7 @@ class TVar(Type):
|
|||
if self.range is not None:
|
||||
ty = self.find()
|
||||
# maybe we should replace this with explicit eq
|
||||
if ty not in self.range:
|
||||
if ty is not self and ty not in self.range:
|
||||
raise UnificationError(
|
||||
f'{self.id} cannot be substituted by {ty}')
|
||||
|
||||
|
@ -125,10 +125,10 @@ class TVar(Type):
|
|||
if k not in x.fields:
|
||||
raise UnificationError(
|
||||
f'Cannot unify {y} with {x}')
|
||||
if isinstance(v, TFunc):
|
||||
if isinstance(v, TFunc) and not v.instantiated:
|
||||
v = v.instantiate()
|
||||
u = x.fields[k]
|
||||
if isinstance(u, TFunc):
|
||||
if isinstance(u, TFunc) and not u.instantiated:
|
||||
u = u.instantiate()
|
||||
v.unify(u)
|
||||
if isinstance(x, TList):
|
||||
|
@ -168,8 +168,7 @@ class TCall(Type):
|
|||
self.fun = TVar()
|
||||
|
||||
def check(self):
|
||||
for arg in chain(self.posargs, self.kwargs.values()):
|
||||
arg.check()
|
||||
self.fun.find().check()
|
||||
|
||||
def find(self):
|
||||
if isinstance(self.fun.find(), TVar):
|
||||
|
@ -327,6 +326,9 @@ class TList(Type):
|
|||
def __init__(self, param: Type):
|
||||
self.param = param
|
||||
|
||||
def check(self):
|
||||
self.param.check()
|
||||
|
||||
def unify(self, other):
|
||||
other = other.find()
|
||||
if isinstance(other, TVar):
|
||||
|
@ -344,6 +346,10 @@ class TTuple(Type):
|
|||
def __init__(self, params: List[Type]):
|
||||
self.params = params
|
||||
|
||||
def check(self):
|
||||
for p in self.params:
|
||||
p.check()
|
||||
|
||||
def unify(self, other):
|
||||
other = other.find()
|
||||
if isinstance(other, TVar):
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
# from __future__ import annotations
|
||||
from __future__ import annotations
|
||||
import ast
|
||||
from ast_visitor import Visitor
|
||||
from nac3_types import *
|
||||
|
||||
|
||||
var = TVar([TInt, TBool])
|
||||
var2 = TVar()
|
||||
var2 = TVar([TInt])
|
||||
|
||||
foo = TObj('Foo', {
|
||||
'foo': TFunc([
|
||||
|
@ -26,10 +26,14 @@ for key, value in v.assignments.items():
|
|||
print('-----------')
|
||||
|
||||
src = """
|
||||
# a = Foo(1).foo(1, 2)
|
||||
# b = Foo(1).foo(1, True)
|
||||
# c = Foo(True).foo(True, 1)
|
||||
# d = Foo(True).foo(True, True)
|
||||
|
||||
a = Foo(1).foo(1, 2)
|
||||
b = Foo(1).foo(1, True)
|
||||
c = Foo(True).foo(True, 1)
|
||||
d = Foo(True).foo(True, True)
|
||||
c = y.foo(True, 1)
|
||||
y = Foo(True)
|
||||
"""
|
||||
|
||||
print(src)
|
||||
|
|
Loading…
Reference in New Issue