fixed inheritance check

This commit is contained in:
pca006132 2020-12-23 13:43:34 +08:00 committed by pca006132
parent b2ec75e157
commit 461d403cce
4 changed files with 25 additions and 3 deletions

View File

@ -14,11 +14,20 @@ def class_fixup(c: ClassType):
if m in c.methods: if m in c.methods:
old = p.methods[m] old = p.methods[m]
new = c.methods[m] new = c.methods[m]
if old[0] != new[0] or old[1] != new[1]: for a, b in zip(old[0], new[0]):
if a != b:
raise CustomError(f'{m} is different in {c.name} and {p.name}')
if old[1] != new[1]:
# actually, we should check for equality *modulo variable renaming* # actually, we should check for equality *modulo variable renaming*
raise CustomError(f'incorrect method signature for {m} in {c.name}') raise CustomError(f'{m} is different in {c.name} and {p.name}')
else: else:
c.methods[m] = p.methods[m] c.methods[m] = p.methods[m]
for f in p.fields:
if f in c.fields:
if p.fields[f] != c.fields[f]:
raise CustomError(f'{f} is different in {c.name} and {p.name}')
else:
c.fields[f] = p.fields[f]
c.checking = False c.checking = False
c.checked = True c.checked = True

View File

@ -1,10 +1,11 @@
import ast import ast
import sys import sys
from helper import CustomError from helper import CustomError
from type_def import SelfType from type_def import SelfType, ClassType
from parse_stmt import parse_stmts from parse_stmt import parse_stmts
from primitives import simplest_ctx from primitives import simplest_ctx
from top_level import parse_top_level from top_level import parse_top_level
from inheritance import class_fixup
if len(sys.argv) != 2: if len(sys.argv) != 2:
print('please pass the python script name as argument') print('please pass the python script name as argument')
@ -17,6 +18,10 @@ tree = ast.parse(source, filename=sys.argv[1])
try: try:
ctx, fns = parse_top_level(simplest_ctx, tree) ctx, fns = parse_top_level(simplest_ctx, tree)
for c in ctx.types.values():
if isinstance(c, ClassType):
class_fixup(c)
for c, name, fn in fns: for c, name, fn in fns:
if c is None: if c is None:
params, result, _ = ctx.functions[name] params, result, _ = ctx.functions[name]

View File

@ -199,6 +199,11 @@ def parse_return_stmt(ctx: Context,
raise CustomError('no return value is allowed', node) raise CustomError('no return value is allowed', node)
return {}, {}, True return {}, {}, True
ty = parse_expr(ctx, sym_table, node.value) ty = parse_expr(ctx, sym_table, node.value)
if isinstance(node.value, ast.Name) and \
node.value.id == 'self' and \
'self' in sym_table and \
isinstance(return_ty, SelfType):
return {}, {}, True
if ty != return_ty: if ty != return_ty:
if isinstance(return_ty, TypeVariable): if isinstance(return_ty, TypeVariable):
if len(return_ty.constraints) == 1 and \ if len(return_ty.constraints) == 1 and \

View File

@ -94,6 +94,9 @@ class SelfType(Type):
def __str__(self): def __str__(self):
return 'self' return 'self'
def __eq__(self, other):
return type(self) == type(other)
class VirtualClassType(Type): class VirtualClassType(Type):
base: ClassType base: ClassType