diff --git a/toy-impl/inheritance.py b/toy-impl/inheritance.py index 400d697..347c91a 100644 --- a/toy-impl/inheritance.py +++ b/toy-impl/inheritance.py @@ -14,11 +14,20 @@ def class_fixup(c: ClassType): if m in c.methods: old = p.methods[m] new = c.methods[m] - if old[0] != new[0] or old[1] != new[1]: + for a, b in zip(old[0], new[0]): + if a != b: + raise CustomError(f'{m} is different in {c.name} and {p.name}') + if old[1] != new[1]: # actually, we should check for equality *modulo variable renaming* - raise CustomError(f'incorrect method signature for {m} in {c.name}') + raise CustomError(f'{m} is different in {c.name} and {p.name}') else: c.methods[m] = p.methods[m] + for f in p.fields: + if f in c.fields: + if p.fields[f] != c.fields[f]: + raise CustomError(f'{f} is different in {c.name} and {p.name}') + else: + c.fields[f] = p.fields[f] c.checking = False c.checked = True diff --git a/toy-impl/main.py b/toy-impl/main.py index 1662f61..dee3b1a 100644 --- a/toy-impl/main.py +++ b/toy-impl/main.py @@ -1,10 +1,11 @@ import ast import sys from helper import CustomError -from type_def import SelfType +from type_def import SelfType, ClassType from parse_stmt import parse_stmts from primitives import simplest_ctx from top_level import parse_top_level +from inheritance import class_fixup if len(sys.argv) != 2: print('please pass the python script name as argument') @@ -17,6 +18,10 @@ tree = ast.parse(source, filename=sys.argv[1]) try: ctx, fns = parse_top_level(simplest_ctx, tree) + for c in ctx.types.values(): + if isinstance(c, ClassType): + class_fixup(c) + for c, name, fn in fns: if c is None: params, result, _ = ctx.functions[name] diff --git a/toy-impl/parse_stmt.py b/toy-impl/parse_stmt.py index 0ca2f7a..075b083 100644 --- a/toy-impl/parse_stmt.py +++ b/toy-impl/parse_stmt.py @@ -199,6 +199,11 @@ def parse_return_stmt(ctx: Context, raise CustomError('no return value is allowed', node) return {}, {}, True ty = parse_expr(ctx, sym_table, node.value) + if isinstance(node.value, ast.Name) and \ + node.value.id == 'self' and \ + 'self' in sym_table and \ + isinstance(return_ty, SelfType): + return {}, {}, True if ty != return_ty: if isinstance(return_ty, TypeVariable): if len(return_ty.constraints) == 1 and \ diff --git a/toy-impl/type_def.py b/toy-impl/type_def.py index a8d8d2e..024000b 100644 --- a/toy-impl/type_def.py +++ b/toy-impl/type_def.py @@ -94,6 +94,9 @@ class SelfType(Type): def __str__(self): return 'self' + def __eq__(self, other): + return type(self) == type(other) + class VirtualClassType(Type): base: ClassType