fixed polymorphic methods

This commit is contained in:
pca006132 2021-07-09 17:11:34 +08:00
parent 1c17ed003e
commit e15e29d673
2 changed files with 60 additions and 59 deletions

View File

@ -125,11 +125,7 @@ class TVar(Type):
if k not in x.fields: if k not in x.fields:
raise UnificationError( raise UnificationError(
f'Cannot unify {y} with {x}') f'Cannot unify {y} with {x}')
if isinstance(v, TFunc) and not v.instantiated:
v = v.instantiate()
u = x.fields[k] u = x.fields[k]
if isinstance(u, TFunc) and not u.instantiated:
u = u.instantiate()
v.unify(u) v.unify(u)
if isinstance(x, TList): if isinstance(x, TList):
if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.LIST]: if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.LIST]:
@ -162,52 +158,59 @@ class FuncArg:
class TCall(Type): class TCall(Type):
def __init__(self, posargs: List[Type], kwargs: Dict[str, Type], ret: Type): def __init__(self, posargs: List[Type], kwargs: Dict[str, Type], ret: Type):
self.posargs = posargs self.calls = [[posargs, kwargs, ret, None]]
self.kwargs = kwargs self.parent = self
self.ret = ret self.rank = 0
self.fun = TVar()
def check(self): def check(self):
self.fun.find().check() self.calls[0][3].check()
def find(self): def find(self):
if isinstance(self.fun.find(), TVar): root = self
return self parent = self.parent
return self.fun.find() while root is not parent and isinstance(parent, TCall):
_, parent = root, root.parent = parent, parent.parent
if parent.calls[0][3] is None:
return parent
return parent.calls[0][3]
def unify(self, other): def unify(self, other):
if not isinstance(self.fun.find(), TVar): y = self.find()
self.fun.unify(other) if y is not self:
y.unify(other)
return return
other = other.find() x = other.find()
if other is self: if x is y:
return return
if isinstance(other, TCall): if isinstance(x, TCall):
for a, b in zip(self.posargs, other.posargs): # standard union find
a.unify(b) if x.rank < y.rank:
for k, v in self.kwargs.items(): x, y = y, x
if k in other.kwargs: y.parent = x
other.kwargs[k].unify(v) if x.rank == y.rank:
else: x.rank += 1
other.kwargs[k] = v x.calls += y.calls
for k, v in other.kwargs.items(): elif isinstance(x, TFunc):
if k not in self.kwargs: fn = x
self.kwargs[k] = v for i in range(len(y.calls)):
self.fun.unify(other.fun) posargs, kwargs, ret, _ = y.calls[i]
elif isinstance(other, TFunc): c = y.calls[i]
all_args = set(arg.name for arg in other.args) c[3] = fn
required = set(arg.name for arg in other.args if not if not x.instantiated:
fn = x.instantiate()
all_args = set(arg.name for arg in fn.args)
required = set(arg.name for arg in fn.args if not
arg.is_optional) arg.is_optional)
other.ret.unify(self.ret) fn.ret.unify(ret)
for i, v in enumerate(self.posargs): for i, v in enumerate(posargs):
arg = other.args[i] arg = fn.args[i]
arg.typ.unify(v) arg.typ.unify(v)
if arg.name in required: if arg.name in required:
required.remove(arg.name) required.remove(arg.name)
for k, v in self.kwargs.items(): for k, v in kwargs.items():
arg = next((arg for arg in other.args if arg.name == k), None) arg = next((arg for arg in fn.args if arg.name == k), None)
if arg is None: if arg is None:
raise UnificationError(f'Unknown kwarg {k}') raise UnificationError(f'Unknown kwarg {k}')
if k not in all_args: if k not in all_args:
@ -218,7 +221,6 @@ class TCall(Type):
all_args.remove(k) all_args.remove(k)
if len(required) > 0: if len(required) > 0:
raise UnificationError(f'Missing arguments') raise UnificationError(f'Missing arguments')
self.fun.unify(other)
elif isinstance(other, TVar): elif isinstance(other, TVar):
other.unify(self) other.unify(self)
else: else:

View File

@ -5,7 +5,7 @@ from nac3_types import *
var = TVar([TInt, TBool]) var = TVar([TInt, TBool])
var2 = TVar([TInt]) var2 = TVar([TInt, TBool])
foo = TObj('Foo', { foo = TObj('Foo', {
'foo': TFunc([ 'foo': TFunc([
@ -26,14 +26,13 @@ for key, value in v.assignments.items():
print('-----------') print('-----------')
src = """ src = """
# a = Foo(1).foo(1, 2) a = f.foo(1, 2)
# b = Foo(1).foo(1, True) b = f.foo(1, True)
# c = Foo(True).foo(True, 1) c = g.foo(True, 1)
# d = Foo(True).foo(True, True) d = g.foo(True, True)
a = Foo(1).foo(1, 2) f = Foo(1)
c = y.foo(True, 1) g = Foo(1)
y = Foo(True)
""" """
print(src) print(src)