added simple lifetime check

This commit is contained in:
pca006132 2021-01-07 11:57:28 +08:00
parent 3457e594ec
commit 84d09f1fd1
5 changed files with 195 additions and 55 deletions

View File

@ -1,30 +1,65 @@
I = TypeVar('I', int32, float, Vec) T = TypeVar('T')
class Vec: class Foo:
v: list[int32] a: list[int32]
def __init__(self, v: list[int32]): b: list[int32]
self.v = v
def __add__(self, other: I) -> Vec: def choose(t: bool, a: T, b: T) -> T:
if type(other) == int32: if t:
return Vec([v + other for v in self.v])
elif type(other) == float:
return Vec([v + int32(other) for v in self.v])
else:
return Vec([self.v[i] + other.v[i] for i in range(len(self.v))])
def get(self, index: int32) -> int32:
return self.v.head()
T = TypeVar('T', int32, list[int32])
def add(a: int32, b: T) -> int32:
if type(b) == int32:
return a + b
else:
for x in b:
a = add(a, x)
return a return a
else:
return b
def set_list(ls: list[T], a: T):
# this should fail
l2 = ls
l2[-1] = a
def get_foo(a: Foo) -> list[int32]:
return a.a
def set_foo(a: Foo, b: Foo):
a.a[0] = b.a[0]
if True:
c = b
# this should fail
c.a = a.a
def set_foo2(a: Foo, b: Foo):
a.a[0] = b.a[0]
if True:
c = [Foo()]
c[0] = a
# this should fail
c[0].a = b.a
def set_foo3(a: Foo, b: Foo):
a.a[0] = b.a[0]
if True:
c = [Foo()]
c[0] = a
# this should fail
c[0].a = get_foo(b)
def set_foo4(a: Foo, b: Foo):
a.a[0] = b.a[0]
if True:
c = [Foo()]
d = c
d[0] = a
# this should fail
c[0].a = get_foo(b)
def set_foo5(a: Foo, b: Foo):
a.a[0] = b.a[0]
if True:
c = [Foo()]
d = c
e = d
f = e
f[0] = a
# this should fail
c[0].a = get_foo(b)

86
toy-impl/lifetime.py Normal file
View File

@ -0,0 +1,86 @@
import ast
from type_def import PrimitiveType
from parse_expr import parse_expr
class Lifetime:
low: int
original: int
def __init__(self, scope):
self.low = self.original = scope
self.parent = None
def fold(self):
while self.parent is not None:
self.low = self.parent.low
self.original = self.parent.original
self.parent = self.parent.parent
return self
def ok(self, other):
self.fold()
if other == None:
return False
other.fold()
return self.low >= other.original and \
(other.original != self.low or self.low != 1)
def __str__(self):
self.fold()
return f'({self.low}, {self.original})'
def assign_expr(
scope: int,
sym_table: dict[str, Lifetime],
expr: ast.expr):
if isinstance(expr, ast.Expression):
body = expr.body
else:
body = expr
if isinstance(body.type, PrimitiveType):
body.lifetime = None
elif isinstance(body, ast.Attribute):
body.lifetime = assign_expr(scope, sym_table, body.value)
elif isinstance(body, ast.Subscript):
body.lifetime = assign_expr(scope, sym_table, body.value)
elif isinstance(body, ast.Name):
if body.id in sym_table:
body.lifetime = sym_table[body.id]
else:
body.lifetime = Lifetime(scope)
sym_table[body.id] = body.lifetime
else:
body.lifetime = Lifetime(scope)
return body.lifetime
def assign_stmt(
scope: int,
sym_table: dict[str, Lifetime],
nodes):
for node in nodes:
if isinstance(node, ast.Assign):
b = assign_expr(scope, sym_table, node.value)
for target in node.targets:
a = assign_expr(scope, sym_table, target)
if a == None and b == None:
continue
if not a.ok(b):
print(ast.unparse(node))
print(f'{a} <- {b}')
assert False
a.low = min(a.low, b.low)
a.original = max(a.original, b.original)
b.parent = a
elif isinstance(node, ast.If) or isinstance(node, ast.While):
assign_stmt(scope + 1, sym_table, node.body)
assign_stmt(scope + 1, sym_table, node.orelse)
elif isinstance(node, ast.Return):
a = assign_expr(scope, sym_table, node.value)
if a != None and a.fold().original > 1:
print(ast.unparse(node))
print(a)
assert False

View File

@ -7,6 +7,7 @@ from parse_stmt import parse_stmts
from primitives import simplest_ctx from primitives import simplest_ctx
from top_level import parse_top_level from top_level import parse_top_level
from inheritance import class_fixup from inheritance import class_fixup
from lifetime import Lifetime, assign_stmt
if len(sys.argv) != 2: if len(sys.argv) != 2:
print('please pass the python script name as argument') print('please pass the python script name as argument')
@ -40,9 +41,20 @@ try:
if isinstance(ty, SelfType): if isinstance(ty, SelfType):
ty = ctx.types[c] ty = ctx.types[c]
sym_table[n.arg] = ty.subst(subst) sym_table[n.arg] = ty.subst(subst)
_, _, returned = parse_stmts(ctx, sym_table, sym_table, result, fn.body) try:
if result is not None and not returned: print()
raise CustomError('Function may have no return value', fn) print('checking:')
print(ast.unparse(fn))
print('typecheck...')
_, _, returned = parse_stmts(ctx, sym_table, sym_table, result, fn.body)
if result is not None and not returned:
raise CustomError('Function may have no return value', fn)
print('lifetime check...')
sym_table = {k: Lifetime(1) for k in sym_table}
assign_stmt(2, sym_table, fn.body)
print('OK!')
except AssertionError:
pass
except CustomError as e: except CustomError as e:
print('Error while type checking:') print('Error while type checking:')
print(e.msg) print(e.msg)

View File

@ -17,33 +17,35 @@ def parse_expr(ctx: Context,
else: else:
body = expr body = expr
if isinstance(body, ast.Constant): if isinstance(body, ast.Constant):
return parse_constant(ctx, sym_table, body) result = parse_constant(ctx, sym_table, body)
if isinstance(body, ast.UnaryOp): elif isinstance(body, ast.UnaryOp):
return parse_unary_ops(ctx, sym_table, body) result = parse_unary_ops(ctx, sym_table, body)
if isinstance(body, ast.BinOp): elif isinstance(body, ast.BinOp):
return parse_bin_ops(ctx, sym_table, body) result = parse_bin_ops(ctx, sym_table, body)
if isinstance(body, ast.Name): elif isinstance(body, ast.Name):
return parse_name(ctx, sym_table, body) result = parse_name(ctx, sym_table, body)
if isinstance(body, ast.List): elif isinstance(body, ast.List):
return parse_list(ctx, sym_table, body) result = parse_list(ctx, sym_table, body)
if isinstance(body, ast.Tuple): elif isinstance(body, ast.Tuple):
return parse_tuple(ctx, sym_table, body) result = parse_tuple(ctx, sym_table, body)
if isinstance(body, ast.Attribute): elif isinstance(body, ast.Attribute):
return parse_attribute(ctx, sym_table, body) result = parse_attribute(ctx, sym_table, body)
if isinstance(body, ast.BoolOp): elif isinstance(body, ast.BoolOp):
return parse_bool_ops(ctx, sym_table, body) result = parse_bool_ops(ctx, sym_table, body)
if isinstance(body, ast.Compare): elif isinstance(body, ast.Compare):
return parse_compare(ctx, sym_table, body) result = parse_compare(ctx, sym_table, body)
if isinstance(body, ast.Call): elif isinstance(body, ast.Call):
return parse_call(ctx, sym_table, body) result = parse_call(ctx, sym_table, body)
if isinstance(body, ast.Subscript): elif isinstance(body, ast.Subscript):
return parse_subscript(ctx, sym_table, body) result = parse_subscript(ctx, sym_table, body)
if isinstance(body, ast.IfExp): elif isinstance(body, ast.IfExp):
return parse_if_expr(ctx, sym_table, body) result = parse_if_expr(ctx, sym_table, body)
if isinstance(body, ast.ListComp): elif isinstance(body, ast.ListComp):
return parse_list_comprehension(ctx, sym_table, body) result = parse_list_comprehension(ctx, sym_table, body)
raise CustomError(f'{body} is not yet supported', body) else:
raise CustomError(f'{body} is not yet supported', body)
body.type = result
return result
def get_unary_op(op): def get_unary_op(op):
if isinstance(op, ast.UAdd): if isinstance(op, ast.UAdd):
@ -238,6 +240,7 @@ def parse_simple_binding(name, ty):
if isinstance(name, ast.Name): if isinstance(name, ast.Name):
if name.id == '_': if name.id == '_':
return {} return {}
name.type = ty
return {name.id: ty} return {name.id: ty}
elif isinstance(name, ast.Tuple): elif isinstance(name, ast.Tuple):
if not isinstance(ty, TupleType): if not isinstance(ty, TupleType):

View File

@ -50,15 +50,18 @@ def get_target_type(ctx: Context,
i = parse_expr(ctx, sym_table, target.slice) i = parse_expr(ctx, sym_table, target.slice)
if i != ctx.types['int32']: if i != ctx.types['int32']:
raise CustomError(f'index must be int32', target.slice) raise CustomError(f'index must be int32', target.slice)
target.type = t.params[0]
return t.params[0] return t.params[0]
elif isinstance(target, ast.Attribute): elif isinstance(target, ast.Attribute):
t = get_target_type(ctx, sym_table, used_sym_table, target.value) t = get_target_type(ctx, sym_table, used_sym_table, target.value)
if target.attr not in t.fields: if target.attr not in t.fields:
raise CustomError(f'{t} has no field {target.attr}', target) raise CustomError(f'{t} has no field {target.attr}', target)
target.type = t.fields[target.attr]
return t.fields[target.attr] return t.fields[target.attr]
elif isinstance(target, ast.Name): elif isinstance(target, ast.Name):
if target.id not in sym_table: if target.id not in sym_table:
raise CustomError(f'unbounded {target.id}', target) raise CustomError(f'unbounded {target.id}', target)
target.type = sym_table[target.id]
return sym_table[target.id] return sym_table[target.id]
else: else:
raise CustomError(f'assignment to {target} is not supported', target) raise CustomError(f'assignment to {target} is not supported', target)
@ -76,6 +79,7 @@ def parse_stmt_binding(ctx: Context,
f'but is now {ty}', target) f'but is now {ty}', target)
if target.id == '_': if target.id == '_':
return {} return {}
target.type = ty
return {target.id: ty} return {target.id: ty}
elif isinstance(target, ast.Tuple): elif isinstance(target, ast.Tuple):
if not isinstance(ty, TupleType): if not isinstance(ty, TupleType):