From 4ac5ec8b048e4e1cab7d418b92949b54733836ed Mon Sep 17 00:00:00 2001 From: pca006132 Date: Mon, 21 Dec 2020 13:53:59 +0800 Subject: [PATCH] virtual class --- toy-impl/inference.py | 19 ++++++++++++++++--- toy-impl/parse_expr.py | 2 +- toy-impl/test_expr.py | 23 +++++++++++++++++++++++ 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/toy-impl/inference.py b/toy-impl/inference.py index 30157e9..1d4441d 100644 --- a/toy-impl/inference.py +++ b/toy-impl/inference.py @@ -48,8 +48,20 @@ def find_subst(ctx: dict[str, Type], else: return f"{a} can take values other than {b}" - # TODO: virtual type is not handled currently - # we need to access the class dictionary to handle this + if isinstance(b, VirtualClassType): + if isinstance(a, ClassType): + parents = [a] + elif isinstance(a, VirtualClassType): + parents = [a.base] + else: + raise CustomError("cannot substitute non-class-type into virtual class") + while len(parents) > 0: + current = parents.pop(0) + if current == b.base: + return sub + else: + parents += current.parents + if type(a) == type(b): if isinstance(a, ParametricType): if len(a.params) != len(b.params): @@ -62,9 +74,10 @@ def find_subst(ctx: dict[str, Type], if a.name == b.name: return sub elif isinstance(a, VirtualClassType): - return find_subst(a.base, b.base) + return find_subst(ctx, sub, a.base, b.base) else: raise Exception() + return f"{a} != {b}" def resolve_call(obj, diff --git a/toy-impl/parse_expr.py b/toy-impl/parse_expr.py index 7e01072..4fd3047 100644 --- a/toy-impl/parse_expr.py +++ b/toy-impl/parse_expr.py @@ -162,7 +162,7 @@ def parse_call(ctx: Context, obj = None f = None if isinstance(node.func, ast.Attribute): - obj = parse_expr(node.func.value) + obj = parse_expr(ctx, sym_table, node.func.value) f = node.func.attr elif isinstance(node.func, ast.Name): f = node.func.id diff --git a/toy-impl/test_expr.py b/toy-impl/test_expr.py index a72c39c..102f928 100644 --- a/toy-impl/test_expr.py +++ b/toy-impl/test_expr.py @@ -111,8 +111,21 @@ class Foo: def __ge__(self, other: Foo) -> bool: pass + def foo(self) -> self: + pass + +class Bar(Foo): + def foo(self) -> self: + pass + def find(ls: list[I], x: I) -> int32: pass + +def foobar(a: virtual[Foo]) -> virtual[Foo]: + pass + +def bar(a: virtual[Bar]) -> virtual[Bar]: + pass """ ctx, _ = parse_top_level(ctx, ast.parse(test_classes)) test_expr('Foo(1) + Foo(1)', {}) @@ -120,3 +133,13 @@ test_expr('Foo(1) + Foo(1) < Foo(2) + Foo(3) < Foo(4)', {}) test_expr('find([1, 2, 3], 1)', {}) test_expr('find([], 1)', {}) +test_expr('Foo(1).foo()', {}) +test_expr('foobar(1)', {}) +test_expr('foobar(Foo(1))', {}) +test_expr('foobar(Bar())', {}) +test_expr('foobar(foobar(Foo(1)))', {}) +test_expr('bar(foobar(Foo(1)))', {}) +test_expr('foobar(bar(Foo(1)))', {}) +test_expr('foobar(bar(Bar()))', {}) + +