fixed polymorphic methods
This commit is contained in:
parent
1c17ed003e
commit
e15e29d673
@ -125,11 +125,7 @@ class TVar(Type):
|
|||||||
if k not in x.fields:
|
if k not in x.fields:
|
||||||
raise UnificationError(
|
raise UnificationError(
|
||||||
f'Cannot unify {y} with {x}')
|
f'Cannot unify {y} with {x}')
|
||||||
if isinstance(v, TFunc) and not v.instantiated:
|
|
||||||
v = v.instantiate()
|
|
||||||
u = x.fields[k]
|
u = x.fields[k]
|
||||||
if isinstance(u, TFunc) and not u.instantiated:
|
|
||||||
u = u.instantiate()
|
|
||||||
v.unify(u)
|
v.unify(u)
|
||||||
if isinstance(x, TList):
|
if isinstance(x, TList):
|
||||||
if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.LIST]:
|
if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.LIST]:
|
||||||
@ -162,52 +158,59 @@ class FuncArg:
|
|||||||
|
|
||||||
class TCall(Type):
|
class TCall(Type):
|
||||||
def __init__(self, posargs: List[Type], kwargs: Dict[str, Type], ret: Type):
|
def __init__(self, posargs: List[Type], kwargs: Dict[str, Type], ret: Type):
|
||||||
self.posargs = posargs
|
self.calls = [[posargs, kwargs, ret, None]]
|
||||||
self.kwargs = kwargs
|
self.parent = self
|
||||||
self.ret = ret
|
self.rank = 0
|
||||||
self.fun = TVar()
|
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
self.fun.find().check()
|
self.calls[0][3].check()
|
||||||
|
|
||||||
def find(self):
|
def find(self):
|
||||||
if isinstance(self.fun.find(), TVar):
|
root = self
|
||||||
return self
|
parent = self.parent
|
||||||
return self.fun.find()
|
while root is not parent and isinstance(parent, TCall):
|
||||||
|
_, parent = root, root.parent = parent, parent.parent
|
||||||
|
if parent.calls[0][3] is None:
|
||||||
|
return parent
|
||||||
|
return parent.calls[0][3]
|
||||||
|
|
||||||
def unify(self, other):
|
def unify(self, other):
|
||||||
if not isinstance(self.fun.find(), TVar):
|
y = self.find()
|
||||||
self.fun.unify(other)
|
if y is not self:
|
||||||
|
y.unify(other)
|
||||||
return
|
return
|
||||||
|
|
||||||
other = other.find()
|
x = other.find()
|
||||||
if other is self:
|
if x is y:
|
||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(other, TCall):
|
if isinstance(x, TCall):
|
||||||
for a, b in zip(self.posargs, other.posargs):
|
# standard union find
|
||||||
a.unify(b)
|
if x.rank < y.rank:
|
||||||
for k, v in self.kwargs.items():
|
x, y = y, x
|
||||||
if k in other.kwargs:
|
y.parent = x
|
||||||
other.kwargs[k].unify(v)
|
if x.rank == y.rank:
|
||||||
else:
|
x.rank += 1
|
||||||
other.kwargs[k] = v
|
x.calls += y.calls
|
||||||
for k, v in other.kwargs.items():
|
elif isinstance(x, TFunc):
|
||||||
if k not in self.kwargs:
|
fn = x
|
||||||
self.kwargs[k] = v
|
for i in range(len(y.calls)):
|
||||||
self.fun.unify(other.fun)
|
posargs, kwargs, ret, _ = y.calls[i]
|
||||||
elif isinstance(other, TFunc):
|
c = y.calls[i]
|
||||||
all_args = set(arg.name for arg in other.args)
|
c[3] = fn
|
||||||
required = set(arg.name for arg in other.args if not
|
if not x.instantiated:
|
||||||
|
fn = x.instantiate()
|
||||||
|
all_args = set(arg.name for arg in fn.args)
|
||||||
|
required = set(arg.name for arg in fn.args if not
|
||||||
arg.is_optional)
|
arg.is_optional)
|
||||||
other.ret.unify(self.ret)
|
fn.ret.unify(ret)
|
||||||
for i, v in enumerate(self.posargs):
|
for i, v in enumerate(posargs):
|
||||||
arg = other.args[i]
|
arg = fn.args[i]
|
||||||
arg.typ.unify(v)
|
arg.typ.unify(v)
|
||||||
if arg.name in required:
|
if arg.name in required:
|
||||||
required.remove(arg.name)
|
required.remove(arg.name)
|
||||||
for k, v in self.kwargs.items():
|
for k, v in kwargs.items():
|
||||||
arg = next((arg for arg in other.args if arg.name == k), None)
|
arg = next((arg for arg in fn.args if arg.name == k), None)
|
||||||
if arg is None:
|
if arg is None:
|
||||||
raise UnificationError(f'Unknown kwarg {k}')
|
raise UnificationError(f'Unknown kwarg {k}')
|
||||||
if k not in all_args:
|
if k not in all_args:
|
||||||
@ -218,7 +221,6 @@ class TCall(Type):
|
|||||||
all_args.remove(k)
|
all_args.remove(k)
|
||||||
if len(required) > 0:
|
if len(required) > 0:
|
||||||
raise UnificationError(f'Missing arguments')
|
raise UnificationError(f'Missing arguments')
|
||||||
self.fun.unify(other)
|
|
||||||
elif isinstance(other, TVar):
|
elif isinstance(other, TVar):
|
||||||
other.unify(self)
|
other.unify(self)
|
||||||
else:
|
else:
|
||||||
|
@ -5,7 +5,7 @@ from nac3_types import *
|
|||||||
|
|
||||||
|
|
||||||
var = TVar([TInt, TBool])
|
var = TVar([TInt, TBool])
|
||||||
var2 = TVar([TInt])
|
var2 = TVar([TInt, TBool])
|
||||||
|
|
||||||
foo = TObj('Foo', {
|
foo = TObj('Foo', {
|
||||||
'foo': TFunc([
|
'foo': TFunc([
|
||||||
@ -26,14 +26,13 @@ for key, value in v.assignments.items():
|
|||||||
print('-----------')
|
print('-----------')
|
||||||
|
|
||||||
src = """
|
src = """
|
||||||
# a = Foo(1).foo(1, 2)
|
a = f.foo(1, 2)
|
||||||
# b = Foo(1).foo(1, True)
|
b = f.foo(1, True)
|
||||||
# c = Foo(True).foo(True, 1)
|
c = g.foo(True, 1)
|
||||||
# d = Foo(True).foo(True, True)
|
d = g.foo(True, True)
|
||||||
|
|
||||||
a = Foo(1).foo(1, 2)
|
f = Foo(1)
|
||||||
c = y.foo(True, 1)
|
g = Foo(1)
|
||||||
y = Foo(True)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
print(src)
|
print(src)
|
||||||
|
Loading…
Reference in New Issue
Block a user