From 84d09f1fd12aee31845f227ce4712c0febf51c3d Mon Sep 17 00:00:00 2001 From: pca006132 Date: Thu, 7 Jan 2021 11:57:28 +0800 Subject: [PATCH] added simple lifetime check --- toy-impl/examples/a.py | 85 +++++++++++++++++++++++++++++------------ toy-impl/lifetime.py | 86 ++++++++++++++++++++++++++++++++++++++++++ toy-impl/main.py | 18 +++++++-- toy-impl/parse_expr.py | 57 +++++++++++++++------------- toy-impl/parse_stmt.py | 4 ++ 5 files changed, 195 insertions(+), 55 deletions(-) create mode 100644 toy-impl/lifetime.py diff --git a/toy-impl/examples/a.py b/toy-impl/examples/a.py index 3b65f29..a5f6775 100644 --- a/toy-impl/examples/a.py +++ b/toy-impl/examples/a.py @@ -1,30 +1,65 @@ -I = TypeVar('I', int32, float, Vec) +T = TypeVar('T') -class Vec: - v: list[int32] - def __init__(self, v: list[int32]): - self.v = v +class Foo: + a: list[int32] + b: list[int32] - def __add__(self, other: I) -> Vec: - if type(other) == int32: - return Vec([v + other for v in self.v]) - elif type(other) == float: - return Vec([v + int32(other) for v in self.v]) - else: - return Vec([self.v[i] + other.v[i] for i in range(len(self.v))]) - - def get(self, index: int32) -> int32: - return self.v.head() - - -T = TypeVar('T', int32, list[int32]) - -def add(a: int32, b: T) -> int32: - if type(b) == int32: - return a + b - else: - for x in b: - a = add(a, x) +def choose(t: bool, a: T, b: T) -> T: + if t: return a + else: + return b + +def set_list(ls: list[T], a: T): + # this should fail + l2 = ls + l2[-1] = a + +def get_foo(a: Foo) -> list[int32]: + return a.a + +def set_foo(a: Foo, b: Foo): + a.a[0] = b.a[0] + if True: + c = b + # this should fail + c.a = a.a + +def set_foo2(a: Foo, b: Foo): + a.a[0] = b.a[0] + if True: + c = [Foo()] + c[0] = a + # this should fail + c[0].a = b.a + +def set_foo3(a: Foo, b: Foo): + a.a[0] = b.a[0] + if True: + c = [Foo()] + c[0] = a + # this should fail + c[0].a = get_foo(b) + +def set_foo4(a: Foo, b: Foo): + a.a[0] = b.a[0] + if True: + c = [Foo()] + d = c + d[0] = a + # this should fail + c[0].a = get_foo(b) + +def set_foo5(a: Foo, b: Foo): + a.a[0] = b.a[0] + if True: + c = [Foo()] + d = c + e = d + f = e + f[0] = a + # this should fail + c[0].a = get_foo(b) + diff --git a/toy-impl/lifetime.py b/toy-impl/lifetime.py new file mode 100644 index 0000000..b4e9c9f --- /dev/null +++ b/toy-impl/lifetime.py @@ -0,0 +1,86 @@ +import ast +from type_def import PrimitiveType +from parse_expr import parse_expr + +class Lifetime: + low: int + original: int + + def __init__(self, scope): + self.low = self.original = scope + self.parent = None + + def fold(self): + while self.parent is not None: + self.low = self.parent.low + self.original = self.parent.original + self.parent = self.parent.parent + return self + + def ok(self, other): + self.fold() + if other == None: + return False + other.fold() + return self.low >= other.original and \ + (other.original != self.low or self.low != 1) + + def __str__(self): + self.fold() + return f'({self.low}, {self.original})' + + +def assign_expr( + scope: int, + sym_table: dict[str, Lifetime], + expr: ast.expr): + if isinstance(expr, ast.Expression): + body = expr.body + else: + body = expr + if isinstance(body.type, PrimitiveType): + body.lifetime = None + elif isinstance(body, ast.Attribute): + body.lifetime = assign_expr(scope, sym_table, body.value) + elif isinstance(body, ast.Subscript): + body.lifetime = assign_expr(scope, sym_table, body.value) + elif isinstance(body, ast.Name): + if body.id in sym_table: + body.lifetime = sym_table[body.id] + else: + body.lifetime = Lifetime(scope) + sym_table[body.id] = body.lifetime + else: + body.lifetime = Lifetime(scope) + return body.lifetime + + +def assign_stmt( + scope: int, + sym_table: dict[str, Lifetime], + nodes): + for node in nodes: + if isinstance(node, ast.Assign): + b = assign_expr(scope, sym_table, node.value) + for target in node.targets: + a = assign_expr(scope, sym_table, target) + if a == None and b == None: + continue + if not a.ok(b): + print(ast.unparse(node)) + print(f'{a} <- {b}') + assert False + a.low = min(a.low, b.low) + a.original = max(a.original, b.original) + b.parent = a + elif isinstance(node, ast.If) or isinstance(node, ast.While): + assign_stmt(scope + 1, sym_table, node.body) + assign_stmt(scope + 1, sym_table, node.orelse) + elif isinstance(node, ast.Return): + a = assign_expr(scope, sym_table, node.value) + if a != None and a.fold().original > 1: + print(ast.unparse(node)) + print(a) + assert False + + diff --git a/toy-impl/main.py b/toy-impl/main.py index 904a34c..f5a6348 100644 --- a/toy-impl/main.py +++ b/toy-impl/main.py @@ -7,6 +7,7 @@ from parse_stmt import parse_stmts from primitives import simplest_ctx from top_level import parse_top_level from inheritance import class_fixup +from lifetime import Lifetime, assign_stmt if len(sys.argv) != 2: print('please pass the python script name as argument') @@ -40,9 +41,20 @@ try: if isinstance(ty, SelfType): ty = ctx.types[c] sym_table[n.arg] = ty.subst(subst) - _, _, returned = parse_stmts(ctx, sym_table, sym_table, result, fn.body) - if result is not None and not returned: - raise CustomError('Function may have no return value', fn) + try: + print() + print('checking:') + print(ast.unparse(fn)) + print('typecheck...') + _, _, returned = parse_stmts(ctx, sym_table, sym_table, result, fn.body) + if result is not None and not returned: + raise CustomError('Function may have no return value', fn) + print('lifetime check...') + sym_table = {k: Lifetime(1) for k in sym_table} + assign_stmt(2, sym_table, fn.body) + print('OK!') + except AssertionError: + pass except CustomError as e: print('Error while type checking:') print(e.msg) diff --git a/toy-impl/parse_expr.py b/toy-impl/parse_expr.py index f404440..6b75c2a 100644 --- a/toy-impl/parse_expr.py +++ b/toy-impl/parse_expr.py @@ -17,33 +17,35 @@ def parse_expr(ctx: Context, else: body = expr if isinstance(body, ast.Constant): - return parse_constant(ctx, sym_table, body) - if isinstance(body, ast.UnaryOp): - return parse_unary_ops(ctx, sym_table, body) - if isinstance(body, ast.BinOp): - return parse_bin_ops(ctx, sym_table, body) - if isinstance(body, ast.Name): - return parse_name(ctx, sym_table, body) - if isinstance(body, ast.List): - return parse_list(ctx, sym_table, body) - if isinstance(body, ast.Tuple): - return parse_tuple(ctx, sym_table, body) - if isinstance(body, ast.Attribute): - return parse_attribute(ctx, sym_table, body) - if isinstance(body, ast.BoolOp): - return parse_bool_ops(ctx, sym_table, body) - if isinstance(body, ast.Compare): - return parse_compare(ctx, sym_table, body) - if isinstance(body, ast.Call): - return parse_call(ctx, sym_table, body) - if isinstance(body, ast.Subscript): - return parse_subscript(ctx, sym_table, body) - if isinstance(body, ast.IfExp): - return parse_if_expr(ctx, sym_table, body) - if isinstance(body, ast.ListComp): - return parse_list_comprehension(ctx, sym_table, body) - raise CustomError(f'{body} is not yet supported', body) - + result = parse_constant(ctx, sym_table, body) + elif isinstance(body, ast.UnaryOp): + result = parse_unary_ops(ctx, sym_table, body) + elif isinstance(body, ast.BinOp): + result = parse_bin_ops(ctx, sym_table, body) + elif isinstance(body, ast.Name): + result = parse_name(ctx, sym_table, body) + elif isinstance(body, ast.List): + result = parse_list(ctx, sym_table, body) + elif isinstance(body, ast.Tuple): + result = parse_tuple(ctx, sym_table, body) + elif isinstance(body, ast.Attribute): + result = parse_attribute(ctx, sym_table, body) + elif isinstance(body, ast.BoolOp): + result = parse_bool_ops(ctx, sym_table, body) + elif isinstance(body, ast.Compare): + result = parse_compare(ctx, sym_table, body) + elif isinstance(body, ast.Call): + result = parse_call(ctx, sym_table, body) + elif isinstance(body, ast.Subscript): + result = parse_subscript(ctx, sym_table, body) + elif isinstance(body, ast.IfExp): + result = parse_if_expr(ctx, sym_table, body) + elif isinstance(body, ast.ListComp): + result = parse_list_comprehension(ctx, sym_table, body) + else: + raise CustomError(f'{body} is not yet supported', body) + body.type = result + return result def get_unary_op(op): if isinstance(op, ast.UAdd): @@ -238,6 +240,7 @@ def parse_simple_binding(name, ty): if isinstance(name, ast.Name): if name.id == '_': return {} + name.type = ty return {name.id: ty} elif isinstance(name, ast.Tuple): if not isinstance(ty, TupleType): diff --git a/toy-impl/parse_stmt.py b/toy-impl/parse_stmt.py index 82f1222..116562b 100644 --- a/toy-impl/parse_stmt.py +++ b/toy-impl/parse_stmt.py @@ -50,15 +50,18 @@ def get_target_type(ctx: Context, i = parse_expr(ctx, sym_table, target.slice) if i != ctx.types['int32']: raise CustomError(f'index must be int32', target.slice) + target.type = t.params[0] return t.params[0] elif isinstance(target, ast.Attribute): t = get_target_type(ctx, sym_table, used_sym_table, target.value) if target.attr not in t.fields: raise CustomError(f'{t} has no field {target.attr}', target) + target.type = t.fields[target.attr] return t.fields[target.attr] elif isinstance(target, ast.Name): if target.id not in sym_table: raise CustomError(f'unbounded {target.id}', target) + target.type = sym_table[target.id] return sym_table[target.id] else: raise CustomError(f'assignment to {target} is not supported', target) @@ -76,6 +79,7 @@ def parse_stmt_binding(ctx: Context, f'but is now {ty}', target) if target.id == '_': return {} + target.type = ty return {target.id: ty} elif isinstance(target, ast.Tuple): if not isinstance(ty, TupleType):