fixed type var check
This commit is contained in:
parent
59628cfa38
commit
1c17ed003e
@ -57,7 +57,7 @@ class Visitor(ast.NodeVisitor):
|
|||||||
{keyword.arg: keyword.value.type for keyword
|
{keyword.arg: keyword.value.type for keyword
|
||||||
in node.keywords}, node.type)
|
in node.keywords}, node.type)
|
||||||
node.func.type.unify(call)
|
node.func.type.unify(call)
|
||||||
self.calls.append(node.func.type)
|
self.calls.append(call)
|
||||||
|
|
||||||
def visit_Attribute(self, node):
|
def visit_Attribute(self, node):
|
||||||
self.visit(node.value)
|
self.visit(node.value)
|
||||||
@ -158,3 +158,4 @@ class Visitor(ast.NodeVisitor):
|
|||||||
node.test.type.unify(TBool)
|
node.test.type.unify(TBool)
|
||||||
for stmt in chain(node.body, node.orelse):
|
for stmt in chain(node.body, node.orelse):
|
||||||
self.visit(stmt)
|
self.visit(stmt)
|
||||||
|
|
||||||
|
@ -63,7 +63,7 @@ class TVar(Type):
|
|||||||
if self.range is not None:
|
if self.range is not None:
|
||||||
ty = self.find()
|
ty = self.find()
|
||||||
# maybe we should replace this with explicit eq
|
# maybe we should replace this with explicit eq
|
||||||
if ty not in self.range:
|
if ty is not self and ty not in self.range:
|
||||||
raise UnificationError(
|
raise UnificationError(
|
||||||
f'{self.id} cannot be substituted by {ty}')
|
f'{self.id} cannot be substituted by {ty}')
|
||||||
|
|
||||||
@ -125,10 +125,10 @@ class TVar(Type):
|
|||||||
if k not in x.fields:
|
if k not in x.fields:
|
||||||
raise UnificationError(
|
raise UnificationError(
|
||||||
f'Cannot unify {y} with {x}')
|
f'Cannot unify {y} with {x}')
|
||||||
if isinstance(v, TFunc):
|
if isinstance(v, TFunc) and not v.instantiated:
|
||||||
v = v.instantiate()
|
v = v.instantiate()
|
||||||
u = x.fields[k]
|
u = x.fields[k]
|
||||||
if isinstance(u, TFunc):
|
if isinstance(u, TFunc) and not u.instantiated:
|
||||||
u = u.instantiate()
|
u = u.instantiate()
|
||||||
v.unify(u)
|
v.unify(u)
|
||||||
if isinstance(x, TList):
|
if isinstance(x, TList):
|
||||||
@ -168,8 +168,7 @@ class TCall(Type):
|
|||||||
self.fun = TVar()
|
self.fun = TVar()
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
for arg in chain(self.posargs, self.kwargs.values()):
|
self.fun.find().check()
|
||||||
arg.check()
|
|
||||||
|
|
||||||
def find(self):
|
def find(self):
|
||||||
if isinstance(self.fun.find(), TVar):
|
if isinstance(self.fun.find(), TVar):
|
||||||
@ -327,6 +326,9 @@ class TList(Type):
|
|||||||
def __init__(self, param: Type):
|
def __init__(self, param: Type):
|
||||||
self.param = param
|
self.param = param
|
||||||
|
|
||||||
|
def check(self):
|
||||||
|
self.param.check()
|
||||||
|
|
||||||
def unify(self, other):
|
def unify(self, other):
|
||||||
other = other.find()
|
other = other.find()
|
||||||
if isinstance(other, TVar):
|
if isinstance(other, TVar):
|
||||||
@ -344,6 +346,10 @@ class TTuple(Type):
|
|||||||
def __init__(self, params: List[Type]):
|
def __init__(self, params: List[Type]):
|
||||||
self.params = params
|
self.params = params
|
||||||
|
|
||||||
|
def check(self):
|
||||||
|
for p in self.params:
|
||||||
|
p.check()
|
||||||
|
|
||||||
def unify(self, other):
|
def unify(self, other):
|
||||||
other = other.find()
|
other = other.find()
|
||||||
if isinstance(other, TVar):
|
if isinstance(other, TVar):
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
# from __future__ import annotations
|
from __future__ import annotations
|
||||||
import ast
|
import ast
|
||||||
from ast_visitor import Visitor
|
from ast_visitor import Visitor
|
||||||
from nac3_types import *
|
from nac3_types import *
|
||||||
|
|
||||||
|
|
||||||
var = TVar([TInt, TBool])
|
var = TVar([TInt, TBool])
|
||||||
var2 = TVar()
|
var2 = TVar([TInt])
|
||||||
|
|
||||||
foo = TObj('Foo', {
|
foo = TObj('Foo', {
|
||||||
'foo': TFunc([
|
'foo': TFunc([
|
||||||
@ -26,10 +26,14 @@ for key, value in v.assignments.items():
|
|||||||
print('-----------')
|
print('-----------')
|
||||||
|
|
||||||
src = """
|
src = """
|
||||||
|
# a = Foo(1).foo(1, 2)
|
||||||
|
# b = Foo(1).foo(1, True)
|
||||||
|
# c = Foo(True).foo(True, 1)
|
||||||
|
# d = Foo(True).foo(True, True)
|
||||||
|
|
||||||
a = Foo(1).foo(1, 2)
|
a = Foo(1).foo(1, 2)
|
||||||
b = Foo(1).foo(1, True)
|
c = y.foo(True, 1)
|
||||||
c = Foo(True).foo(True, 1)
|
y = Foo(True)
|
||||||
d = Foo(True).foo(True, True)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
print(src)
|
print(src)
|
||||||
|
Loading…
Reference in New Issue
Block a user