fixed polymorphic methods

This commit is contained in:
pca006132 2021-07-09 17:11:34 +08:00
parent 1c17ed003e
commit e15e29d673
2 changed files with 60 additions and 59 deletions

View File

@ -125,11 +125,7 @@ class TVar(Type):
if k not in x.fields:
raise UnificationError(
f'Cannot unify {y} with {x}')
if isinstance(v, TFunc) and not v.instantiated:
v = v.instantiate()
u = x.fields[k]
if isinstance(u, TFunc) and not u.instantiated:
u = u.instantiate()
v.unify(u)
if isinstance(x, TList):
if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.LIST]:
@ -162,63 +158,69 @@ class FuncArg:
class TCall(Type):
def __init__(self, posargs: List[Type], kwargs: Dict[str, Type], ret: Type):
self.posargs = posargs
self.kwargs = kwargs
self.ret = ret
self.fun = TVar()
self.calls = [[posargs, kwargs, ret, None]]
self.parent = self
self.rank = 0
def check(self):
self.fun.find().check()
self.calls[0][3].check()
def find(self):
if isinstance(self.fun.find(), TVar):
return self
return self.fun.find()
root = self
parent = self.parent
while root is not parent and isinstance(parent, TCall):
_, parent = root, root.parent = parent, parent.parent
if parent.calls[0][3] is None:
return parent
return parent.calls[0][3]
def unify(self, other):
if not isinstance(self.fun.find(), TVar):
self.fun.unify(other)
y = self.find()
if y is not self:
y.unify(other)
return
other = other.find()
if other is self:
x = other.find()
if x is y:
return
if isinstance(other, TCall):
for a, b in zip(self.posargs, other.posargs):
a.unify(b)
for k, v in self.kwargs.items():
if k in other.kwargs:
other.kwargs[k].unify(v)
else:
other.kwargs[k] = v
for k, v in other.kwargs.items():
if k not in self.kwargs:
self.kwargs[k] = v
self.fun.unify(other.fun)
elif isinstance(other, TFunc):
all_args = set(arg.name for arg in other.args)
required = set(arg.name for arg in other.args if not
arg.is_optional)
other.ret.unify(self.ret)
for i, v in enumerate(self.posargs):
arg = other.args[i]
arg.typ.unify(v)
if arg.name in required:
required.remove(arg.name)
for k, v in self.kwargs.items():
arg = next((arg for arg in other.args if arg.name == k), None)
if arg is None:
raise UnificationError(f'Unknown kwarg {k}')
if k not in all_args:
raise UnificationError(f'Duplicated kwarg {k}')
arg.typ.unify(v)
if k in required:
required.remove(k)
all_args.remove(k)
if len(required) > 0:
raise UnificationError(f'Missing arguments')
self.fun.unify(other)
if isinstance(x, TCall):
# standard union find
if x.rank < y.rank:
x, y = y, x
y.parent = x
if x.rank == y.rank:
x.rank += 1
x.calls += y.calls
elif isinstance(x, TFunc):
fn = x
for i in range(len(y.calls)):
posargs, kwargs, ret, _ = y.calls[i]
c = y.calls[i]
c[3] = fn
if not x.instantiated:
fn = x.instantiate()
all_args = set(arg.name for arg in fn.args)
required = set(arg.name for arg in fn.args if not
arg.is_optional)
fn.ret.unify(ret)
for i, v in enumerate(posargs):
arg = fn.args[i]
arg.typ.unify(v)
if arg.name in required:
required.remove(arg.name)
for k, v in kwargs.items():
arg = next((arg for arg in fn.args if arg.name == k), None)
if arg is None:
raise UnificationError(f'Unknown kwarg {k}')
if k not in all_args:
raise UnificationError(f'Duplicated kwarg {k}')
arg.typ.unify(v)
if k in required:
required.remove(k)
all_args.remove(k)
if len(required) > 0:
raise UnificationError(f'Missing arguments')
elif isinstance(other, TVar):
other.unify(self)
else:

View File

@ -5,7 +5,7 @@ from nac3_types import *
var = TVar([TInt, TBool])
var2 = TVar([TInt])
var2 = TVar([TInt, TBool])
foo = TObj('Foo', {
'foo': TFunc([
@ -26,14 +26,13 @@ for key, value in v.assignments.items():
print('-----------')
src = """
# a = Foo(1).foo(1, 2)
# b = Foo(1).foo(1, True)
# c = Foo(True).foo(True, 1)
# d = Foo(True).foo(True, True)
a = f.foo(1, 2)
b = f.foo(1, True)
c = g.foo(True, 1)
d = g.foo(True, True)
a = Foo(1).foo(1, 2)
c = y.foo(True, 1)
y = Foo(True)
f = Foo(1)
g = Foo(1)
"""
print(src)