virtual class

This commit is contained in:
pca006132 2020-12-21 13:53:59 +08:00 committed by pca006132
parent 53e82f5603
commit 4ac5ec8b04
3 changed files with 40 additions and 4 deletions

View File

@ -48,8 +48,20 @@ def find_subst(ctx: dict[str, Type],
else: else:
return f"{a} can take values other than {b}" return f"{a} can take values other than {b}"
# TODO: virtual type is not handled currently if isinstance(b, VirtualClassType):
# we need to access the class dictionary to handle this if isinstance(a, ClassType):
parents = [a]
elif isinstance(a, VirtualClassType):
parents = [a.base]
else:
raise CustomError("cannot substitute non-class-type into virtual class")
while len(parents) > 0:
current = parents.pop(0)
if current == b.base:
return sub
else:
parents += current.parents
if type(a) == type(b): if type(a) == type(b):
if isinstance(a, ParametricType): if isinstance(a, ParametricType):
if len(a.params) != len(b.params): if len(a.params) != len(b.params):
@ -62,9 +74,10 @@ def find_subst(ctx: dict[str, Type],
if a.name == b.name: if a.name == b.name:
return sub return sub
elif isinstance(a, VirtualClassType): elif isinstance(a, VirtualClassType):
return find_subst(a.base, b.base) return find_subst(ctx, sub, a.base, b.base)
else: else:
raise Exception() raise Exception()
return f"{a} != {b}" return f"{a} != {b}"
def resolve_call(obj, def resolve_call(obj,

View File

@ -162,7 +162,7 @@ def parse_call(ctx: Context,
obj = None obj = None
f = None f = None
if isinstance(node.func, ast.Attribute): if isinstance(node.func, ast.Attribute):
obj = parse_expr(node.func.value) obj = parse_expr(ctx, sym_table, node.func.value)
f = node.func.attr f = node.func.attr
elif isinstance(node.func, ast.Name): elif isinstance(node.func, ast.Name):
f = node.func.id f = node.func.id

View File

@ -111,8 +111,21 @@ class Foo:
def __ge__(self, other: Foo) -> bool: def __ge__(self, other: Foo) -> bool:
pass pass
def foo(self) -> self:
pass
class Bar(Foo):
def foo(self) -> self:
pass
def find(ls: list[I], x: I) -> int32: def find(ls: list[I], x: I) -> int32:
pass pass
def foobar(a: virtual[Foo]) -> virtual[Foo]:
pass
def bar(a: virtual[Bar]) -> virtual[Bar]:
pass
""" """
ctx, _ = parse_top_level(ctx, ast.parse(test_classes)) ctx, _ = parse_top_level(ctx, ast.parse(test_classes))
test_expr('Foo(1) + Foo(1)', {}) test_expr('Foo(1) + Foo(1)', {})
@ -120,3 +133,13 @@ test_expr('Foo(1) + Foo(1) < Foo(2) + Foo(3) < Foo(4)', {})
test_expr('find([1, 2, 3], 1)', {}) test_expr('find([1, 2, 3], 1)', {})
test_expr('find([], 1)', {}) test_expr('find([], 1)', {})
test_expr('Foo(1).foo()', {})
test_expr('foobar(1)', {})
test_expr('foobar(Foo(1))', {})
test_expr('foobar(Bar())', {})
test_expr('foobar(foobar(Foo(1)))', {})
test_expr('bar(foobar(Foo(1)))', {})
test_expr('foobar(bar(Foo(1)))', {})
test_expr('foobar(bar(Bar()))', {})