fixed polymorphic methods
This commit is contained in:
parent
1c17ed003e
commit
e15e29d673
@ -125,11 +125,7 @@ class TVar(Type):
|
||||
if k not in x.fields:
|
||||
raise UnificationError(
|
||||
f'Cannot unify {y} with {x}')
|
||||
if isinstance(v, TFunc) and not v.instantiated:
|
||||
v = v.instantiate()
|
||||
u = x.fields[k]
|
||||
if isinstance(u, TFunc) and not u.instantiated:
|
||||
u = u.instantiate()
|
||||
v.unify(u)
|
||||
if isinstance(x, TList):
|
||||
if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.LIST]:
|
||||
@ -162,63 +158,69 @@ class FuncArg:
|
||||
|
||||
class TCall(Type):
|
||||
def __init__(self, posargs: List[Type], kwargs: Dict[str, Type], ret: Type):
|
||||
self.posargs = posargs
|
||||
self.kwargs = kwargs
|
||||
self.ret = ret
|
||||
self.fun = TVar()
|
||||
self.calls = [[posargs, kwargs, ret, None]]
|
||||
self.parent = self
|
||||
self.rank = 0
|
||||
|
||||
def check(self):
|
||||
self.fun.find().check()
|
||||
self.calls[0][3].check()
|
||||
|
||||
def find(self):
|
||||
if isinstance(self.fun.find(), TVar):
|
||||
return self
|
||||
return self.fun.find()
|
||||
root = self
|
||||
parent = self.parent
|
||||
while root is not parent and isinstance(parent, TCall):
|
||||
_, parent = root, root.parent = parent, parent.parent
|
||||
if parent.calls[0][3] is None:
|
||||
return parent
|
||||
return parent.calls[0][3]
|
||||
|
||||
def unify(self, other):
|
||||
if not isinstance(self.fun.find(), TVar):
|
||||
self.fun.unify(other)
|
||||
y = self.find()
|
||||
if y is not self:
|
||||
y.unify(other)
|
||||
return
|
||||
|
||||
other = other.find()
|
||||
if other is self:
|
||||
x = other.find()
|
||||
if x is y:
|
||||
return
|
||||
|
||||
if isinstance(other, TCall):
|
||||
for a, b in zip(self.posargs, other.posargs):
|
||||
a.unify(b)
|
||||
for k, v in self.kwargs.items():
|
||||
if k in other.kwargs:
|
||||
other.kwargs[k].unify(v)
|
||||
else:
|
||||
other.kwargs[k] = v
|
||||
for k, v in other.kwargs.items():
|
||||
if k not in self.kwargs:
|
||||
self.kwargs[k] = v
|
||||
self.fun.unify(other.fun)
|
||||
elif isinstance(other, TFunc):
|
||||
all_args = set(arg.name for arg in other.args)
|
||||
required = set(arg.name for arg in other.args if not
|
||||
arg.is_optional)
|
||||
other.ret.unify(self.ret)
|
||||
for i, v in enumerate(self.posargs):
|
||||
arg = other.args[i]
|
||||
arg.typ.unify(v)
|
||||
if arg.name in required:
|
||||
required.remove(arg.name)
|
||||
for k, v in self.kwargs.items():
|
||||
arg = next((arg for arg in other.args if arg.name == k), None)
|
||||
if arg is None:
|
||||
raise UnificationError(f'Unknown kwarg {k}')
|
||||
if k not in all_args:
|
||||
raise UnificationError(f'Duplicated kwarg {k}')
|
||||
arg.typ.unify(v)
|
||||
if k in required:
|
||||
required.remove(k)
|
||||
all_args.remove(k)
|
||||
if len(required) > 0:
|
||||
raise UnificationError(f'Missing arguments')
|
||||
self.fun.unify(other)
|
||||
if isinstance(x, TCall):
|
||||
# standard union find
|
||||
if x.rank < y.rank:
|
||||
x, y = y, x
|
||||
y.parent = x
|
||||
if x.rank == y.rank:
|
||||
x.rank += 1
|
||||
x.calls += y.calls
|
||||
elif isinstance(x, TFunc):
|
||||
fn = x
|
||||
for i in range(len(y.calls)):
|
||||
posargs, kwargs, ret, _ = y.calls[i]
|
||||
c = y.calls[i]
|
||||
c[3] = fn
|
||||
if not x.instantiated:
|
||||
fn = x.instantiate()
|
||||
all_args = set(arg.name for arg in fn.args)
|
||||
required = set(arg.name for arg in fn.args if not
|
||||
arg.is_optional)
|
||||
fn.ret.unify(ret)
|
||||
for i, v in enumerate(posargs):
|
||||
arg = fn.args[i]
|
||||
arg.typ.unify(v)
|
||||
if arg.name in required:
|
||||
required.remove(arg.name)
|
||||
for k, v in kwargs.items():
|
||||
arg = next((arg for arg in fn.args if arg.name == k), None)
|
||||
if arg is None:
|
||||
raise UnificationError(f'Unknown kwarg {k}')
|
||||
if k not in all_args:
|
||||
raise UnificationError(f'Duplicated kwarg {k}')
|
||||
arg.typ.unify(v)
|
||||
if k in required:
|
||||
required.remove(k)
|
||||
all_args.remove(k)
|
||||
if len(required) > 0:
|
||||
raise UnificationError(f'Missing arguments')
|
||||
elif isinstance(other, TVar):
|
||||
other.unify(self)
|
||||
else:
|
||||
|
@ -5,7 +5,7 @@ from nac3_types import *
|
||||
|
||||
|
||||
var = TVar([TInt, TBool])
|
||||
var2 = TVar([TInt])
|
||||
var2 = TVar([TInt, TBool])
|
||||
|
||||
foo = TObj('Foo', {
|
||||
'foo': TFunc([
|
||||
@ -26,14 +26,13 @@ for key, value in v.assignments.items():
|
||||
print('-----------')
|
||||
|
||||
src = """
|
||||
# a = Foo(1).foo(1, 2)
|
||||
# b = Foo(1).foo(1, True)
|
||||
# c = Foo(True).foo(True, 1)
|
||||
# d = Foo(True).foo(True, True)
|
||||
a = f.foo(1, 2)
|
||||
b = f.foo(1, True)
|
||||
c = g.foo(True, 1)
|
||||
d = g.foo(True, True)
|
||||
|
||||
a = Foo(1).foo(1, 2)
|
||||
c = y.foo(True, 1)
|
||||
y = Foo(True)
|
||||
f = Foo(1)
|
||||
g = Foo(1)
|
||||
"""
|
||||
|
||||
print(src)
|
||||
|
Loading…
Reference in New Issue
Block a user