From 69cf20cb91b79d6ed154a841e9f2834c1a83c630 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Tue, 22 Dec 2020 15:14:34 +0800 Subject: [PATCH] implemented basic statements --- toy-impl/parse_expr.py | 12 +++--- toy-impl/parse_stmt.py | 84 ++++++++++++++++++++++++++++++++++++------ toy-impl/test_stmt.py | 44 ++++++++++++++++++++-- 3 files changed, 118 insertions(+), 22 deletions(-) diff --git a/toy-impl/parse_expr.py b/toy-impl/parse_expr.py index 98e34d8..1015c3a 100644 --- a/toy-impl/parse_expr.py +++ b/toy-impl/parse_expr.py @@ -19,7 +19,7 @@ def parse_expr(ctx: Context, if isinstance(body, ast.Constant): return parse_constant(ctx, sym_table, body) if isinstance(body, ast.UnaryOp): - return parse_unary_op(ctx, sym_table, body) + return parse_unary_ops(ctx, sym_table, body) if isinstance(body, ast.BinOp): return parse_bin_ops(ctx, sym_table, body) if isinstance(body, ast.Name): @@ -131,7 +131,7 @@ def parse_bin_ops(ctx: Context, def parse_unary_ops(ctx: Context, sym_table: dict[str, Type], node): - t = parse_expr(node.operand) + t = parse_expr(ctx, sym_table, node.operand) if isinstance(node.op, ast.Not): b = ctx.types['bool'] if t != b: @@ -206,7 +206,7 @@ def parse_if_expr(ctx: Context, raise CustomError(f'divergent type for if expression: {ty1} != {ty2}') return ty1 -def parse_binding(name, ty): +def parse_simple_binding(name, ty): if isinstance(name, ast.Name): if name.id == '_': return {} @@ -218,9 +218,9 @@ def parse_binding(name, ty): raise CustomError(f'pattern matching length mismatch') result = {} for x, y in zip(name.elts, ty.params): - binding = parse_binding(x, y) + binding = parse_simple_binding(x, y) expected = len(result) + len(binding) - result |= parse_binding(x, y) + result |= parse_simple_binding(x, y) if len(result) != expected: raise CustomError('variable name clash') return result @@ -237,7 +237,7 @@ def parse_list_comprehension(ctx: Context, ty = parse_expr(ctx, sym_table, node.generators[0].iter) if not isinstance(ty, ListType): raise CustomError(f'unable to iterate over {ty}') - sym_table2 = sym_table | parse_binding(node.generators[0].target, ty.params[0]) + sym_table2 = sym_table | parse_simple_binding(node.generators[0].target, ty.params[0]) b = ctx.types['bool'] for c in node.generators[0].ifs: if parse_expr(ctx, sym_table2, c) != b: diff --git a/toy-impl/parse_stmt.py b/toy-impl/parse_stmt.py index f7dccc2..5b84044 100644 --- a/toy-impl/parse_stmt.py +++ b/toy-impl/parse_stmt.py @@ -3,19 +3,28 @@ import copy from helper import * from type_def import * from inference import * -from parse_expr import parse_expr +from parse_expr import parse_expr, parse_simple_binding def parse_stmts(ctx: Context, sym_table: dict[str, Type], used_sym_table: dict[str, Type], + return_ty: Type, nodes): sym_table2 = copy.copy(sym_table) used_sym_table2 = copy.copy(used_sym_table) for node in nodes: if isinstance(node, ast.Assign): - a, b, returned = parse_assign(ctx, sym_table2, used_sym_table2, node) + a, b, returned = parse_assign(ctx, sym_table2, used_sym_table2, return_ty, node) elif isinstance(node, ast.If): - a, b, returned = parse_if_stmt(ctx, sym_table2, used_sym_table2, node) + a, b, returned = parse_if_stmt(ctx, sym_table2, used_sym_table2, return_ty, node) + elif isinstance(node, ast.While): + a, b, returned = parse_while_stmt(ctx, sym_table2, used_sym_table2, return_ty, node) + elif isinstance(node, ast.For): + a, b, returned = parse_for_stmt(ctx, sym_table2, used_sym_table2, return_ty, node) + elif isinstance(node, ast.Return): + a, b, returned = parse_return_stmt(ctx, sym_table2, used_sym_table2, return_ty, node) + elif isinstance(node, ast.Break) or isinstance(node, ast.Continue): + continue else: raise CustomError(f'{node} is not supported yet') sym_table2 |= a @@ -50,11 +59,11 @@ def get_target_type(ctx: Context, else: raise CustomError(f'assignment to {target} is not supported') -def parse_binding(ctx: Context, - sym_table: dict[str, Type], - used_sym_table: dict[str, Type], - target, - ty): +def parse_stmt_binding(ctx: Context, + sym_table: dict[str, Type], + used_sym_table: dict[str, Type], + target, + ty): if isinstance(target, ast.Name): if target.id in used_sym_table: if used_sym_table[target.id] != ty: @@ -69,7 +78,7 @@ def parse_binding(ctx: Context, raise CustomError(f'pattern matching length mismatch') result = {} for x, y in zip(target.elts, ty.params): - new = parse_binding(ctx, sym_table, used_sym_table, x, y) + new = parse_stmt_binding(ctx, sym_table, used_sym_table, x, y) old_len = len(result) result |= new used_sym_table |= new @@ -85,6 +94,7 @@ def parse_binding(ctx: Context, def parse_assign(ctx: Context, sym_table: dict[str, Type], used_sym_table: dict[str, Type], + return_ty: Type, node): # permitted assignment targets: # variables, class fields, list elements @@ -93,19 +103,69 @@ def parse_assign(ctx: Context, ty = parse_expr(ctx, sym_table, node.value) results = {} for target in node.targets: - results |= parse_binding(ctx, sym_table, used_sym_table, target, ty) + results |= parse_stmt_binding(ctx, sym_table, used_sym_table, target, ty) return results, results, False def parse_if_stmt(ctx: Context, sym_table: dict[str, Type], used_sym_table: dict[str, Type], + return_ty: Type, node): test = parse_expr(ctx, sym_table, node.test) if test != ctx.types['bool']: raise CustomError(f'condition must be bool instead of {test}') - a, b, r = parse_stmts(ctx, sym_table, used_sym_table, node.body) + a, b, r = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.body) used_sym_table |= b - a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table, node.orelse) + a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.orelse) defined = {k: a[k] for k in a if k in a1} return defined, b | b1, r and r1 +def parse_for_stmt(ctx: Context, + sym_table: dict[str, Type], + used_sym_table: dict[str, Type], + return_ty: Type, + node): + ty = parse_expr(ctx, sym_table, node.iter) + if not isinstance(ty, ListType): + raise CustomError('only iteration over list is supported') + binding = parse_simple_binding(node.target, ty.params[0]) + for key, value in binding.items(): + if key in used_sym_table: + if value != used_sym_table[key]: + raise CustomError('inconsistent type') + a, b, r = parse_stmts(ctx, sym_table | binding, used_sym_table | binding, + return_ty, node.body) + a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table | b, + return_ty, node.orelse) + defined = {k: a[k] for k in a if k in a1} + return defined, b | b1, r and r1 + +def parse_while_stmt(ctx: Context, + sym_table: dict[str, Type], + used_sym_table: dict[str, Type], + return_ty: Type, + node): + ty = parse_expr(ctx, sym_table, node.test) + if ty != ctx.types['bool']: + raise CustomError('condition must be bool') + # more sophisticated return analysis is needed... + a, b, r = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.body) + a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table | b, + return_ty, node.orelse) + defined = {k: a[k] for k in a if k in a1} + return defined, b | b1, r and r1 + +def parse_return_stmt(ctx: Context, + sym_table: dict[str, Type], + used_sym_table: dict[str, Type], + return_ty: Type, + node): + if return_ty is None: + if node.value is not None: + raise CustomError('no return value is allowed') + return {}, {}, True + ty = parse_expr(ctx, sym_table, node.value) + if ty != return_ty: + raise CustomError(f'expected returning {return_ty} but got {ty}') + return {}, {}, True + diff --git a/toy-impl/test_stmt.py b/toy-impl/test_stmt.py index 2027048..9534641 100644 --- a/toy-impl/test_stmt.py +++ b/toy-impl/test_stmt.py @@ -31,6 +31,7 @@ A = variables['A'] i32.methods['__init__'] = ([SelfType(), I], None, set()) i32.methods['__add__'] = ([SelfType(), i32], i32, set()) i32.methods['__sub__'] = ([SelfType(), i32], i32, set()) +i32.methods['__neg__'] = ([SelfType()], i32, set()) i32.methods['__lt__'] = ([SelfType(), i32], b, set()) i32.methods['__gt__'] = ([SelfType(), i32], b, set()) i32.methods['__eq__'] = ([SelfType(), i32], b, set()) @@ -41,6 +42,7 @@ i32.methods['__ge__'] = ([SelfType(), i32], b, set()) i64.methods['__init__'] = ([SelfType(), I], None, set()) i64.methods['__add__'] = ([SelfType(), i64], i64, set()) i64.methods['__sub__'] = ([SelfType(), i64], i64, set()) +i64.methods['__neg__'] = ([SelfType()], i64, set()) i64.methods['__lt__'] = ([SelfType(), i64], b, set()) i64.methods['__gt__'] = ([SelfType(), i64], b, set()) i64.methods['__eq__'] = ([SelfType(), i64], b, set()) @@ -51,14 +53,17 @@ i64.methods['__ge__'] = ([SelfType(), i64], b, set()) ctx = Context(variables, types) -def test_stmt(stmt, sym_table = {}): - print(f'Testing {stmt} w.r.t. {stringify_subst(sym_table)}') +def test_stmt(stmt, sym_table = {}, return_ty = None): + print(f'Testing:\n{stmt}\n\nw.r.t. {stringify_subst(sym_table)}') try: tree = ast.parse(stmt) - a, b, _ = parse_stmts(ctx, sym_table, sym_table, tree.body) - print(stringify_subst(a)) + a, b, returned = parse_stmts(ctx, sym_table, sym_table, return_ty, tree.body) + print(f'defined variables: {stringify_subst(a)}') + print(f'returned: {returned}') + print('---') except CustomError as err: print(f'error: {err.msg}') + print('---') test_stmt('a, b = 1, 2', {}) test_stmt('a, b = 1, [1, 2, 3]', {}) @@ -95,4 +100,35 @@ c = a b = [1, 2, 3] """) +test_stmt(""" +c = 0 +for i in [1, 2, 3]: + c = c + i +""") + +test_stmt(""" +c = 0 +for i in [1, 2, 3]: + c = c + i +if c > 0: + return c +else: + return -c +""", {}, i32) + +test_stmt(""" +c = 0 +for i in [1, 2, 3]: + c = c + i +if c > 0: + return c +""", {}, i32) + +test_stmt(""" +c = i = 0 +for i in [True, True, False]: + c = c + 1 +if c > 0: + return c +""", {}, i32)