From eb2ddfc617e818486a142075bd0dcbaa8cb3dfcd Mon Sep 17 00:00:00 2001 From: pca006132 Date: Mon, 21 Dec 2020 16:58:40 +0800 Subject: [PATCH] basic statements and if statement --- toy-impl/parse_stmt.py | 111 +++++++++++++++++++++++++++++++++++++++++ toy-impl/test_stmt.py | 98 ++++++++++++++++++++++++++++++++++++ 2 files changed, 209 insertions(+) create mode 100644 toy-impl/parse_stmt.py create mode 100644 toy-impl/test_stmt.py diff --git a/toy-impl/parse_stmt.py b/toy-impl/parse_stmt.py new file mode 100644 index 0000000..f7dccc2 --- /dev/null +++ b/toy-impl/parse_stmt.py @@ -0,0 +1,111 @@ +import ast +import copy +from helper import * +from type_def import * +from inference import * +from parse_expr import parse_expr + +def parse_stmts(ctx: Context, + sym_table: dict[str, Type], + used_sym_table: dict[str, Type], + nodes): + sym_table2 = copy.copy(sym_table) + used_sym_table2 = copy.copy(used_sym_table) + for node in nodes: + if isinstance(node, ast.Assign): + a, b, returned = parse_assign(ctx, sym_table2, used_sym_table2, node) + elif isinstance(node, ast.If): + a, b, returned = parse_if_stmt(ctx, sym_table2, used_sym_table2, node) + else: + raise CustomError(f'{node} is not supported yet') + sym_table2 |= a + used_sym_table2 |= b + if returned: + return sym_table2, used_sym_table2, True + return sym_table2, used_sym_table2, False + +def get_target_type(ctx: Context, + sym_table: dict[str, Type], + used_sym_table: dict[str, Type], + target): + if isinstance(target, ast.Subscript): + t = get_target_type(ctx, sym_table, used_sym_table, target.value) + if not isinstance(t, ListType): + raise CustomError(f'cannot index through type {t}') + if isinstance(target.slice, ast.Slice): + raise CustomError(f'assignment to slice is not supported') + i = parse_expr(ctx, sym_table, target.slice) + if i != ctx.types['int32']: + raise CustomError(f'index must be int32') + return t.params[0] + elif isinstance(target, ast.Attribute): + t = get_target_type(ctx, sym_table, used_sym_table, target.value) + if target.attr not in t.fields: + raise CustomError(f'{t} has no field {target.attr}') + return t.fields[target.attr] + elif isinstance(target, ast.Name): + if target.id not in sym_table: + raise CustomError(f'unbounded {target.id}') + return sym_table[target.id] + else: + raise CustomError(f'assignment to {target} is not supported') + +def parse_binding(ctx: Context, + sym_table: dict[str, Type], + used_sym_table: dict[str, Type], + target, + ty): + if isinstance(target, ast.Name): + if target.id in used_sym_table: + if used_sym_table[target.id] != ty: + raise CustomError('inconsistent type') + if target.id == '_': + return {} + return {target.id: ty} + elif isinstance(target, ast.Tuple): + if not isinstance(ty, TupleType): + raise CustomError(f'cannot pattern match over {ty}') + if len(target.elts) != len(ty.params): + raise CustomError(f'pattern matching length mismatch') + result = {} + for x, y in zip(target.elts, ty.params): + new = parse_binding(ctx, sym_table, used_sym_table, x, y) + old_len = len(result) + result |= new + used_sym_table |= new + if len(result) != old_len + len(new): + raise CustomError(f'variable name clash') + return result + else: + t = get_target_type(ctx, sym_table, used_sym_table, target) + if ty != t: + raise CustomError(f'inconsistent type') + return {} + +def parse_assign(ctx: Context, + sym_table: dict[str, Type], + used_sym_table: dict[str, Type], + node): + # permitted assignment targets: + # variables, class fields, list elements + # function evaluation is only allowed within list index + # may relax later after when we settle on lifetime handling + ty = parse_expr(ctx, sym_table, node.value) + results = {} + for target in node.targets: + results |= parse_binding(ctx, sym_table, used_sym_table, target, ty) + return results, results, False + +def parse_if_stmt(ctx: Context, + sym_table: dict[str, Type], + used_sym_table: dict[str, Type], + node): + test = parse_expr(ctx, sym_table, node.test) + if test != ctx.types['bool']: + raise CustomError(f'condition must be bool instead of {test}') + a, b, r = parse_stmts(ctx, sym_table, used_sym_table, node.body) + used_sym_table |= b + a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table, node.orelse) + defined = {k: a[k] for k in a if k in a1} + return defined, b | b1, r and r1 + diff --git a/toy-impl/test_stmt.py b/toy-impl/test_stmt.py new file mode 100644 index 0000000..2027048 --- /dev/null +++ b/toy-impl/test_stmt.py @@ -0,0 +1,98 @@ +import ast +from type_def import * +from inference import * +from helper import * +from parse_stmt import * + +types = { + 'int32': PrimitiveType('int32'), + 'int64': PrimitiveType('int64'), + 'str': PrimitiveType('str'), + 'bool': PrimitiveType('bool') +} + +i32 = types['int32'] +i64 = types['int64'] +s = types['str'] +b = types['bool'] + +variables = { + 'X': TypeVariable('X', []), + 'Y': TypeVariable('Y', []), + 'I': TypeVariable('I', [i32, i64]), + 'A': TypeVariable('A', [i32, i64, s]), +} + +X = variables['X'] +Y = variables['Y'] +I = variables['I'] +A = variables['A'] + +i32.methods['__init__'] = ([SelfType(), I], None, set()) +i32.methods['__add__'] = ([SelfType(), i32], i32, set()) +i32.methods['__sub__'] = ([SelfType(), i32], i32, set()) +i32.methods['__lt__'] = ([SelfType(), i32], b, set()) +i32.methods['__gt__'] = ([SelfType(), i32], b, set()) +i32.methods['__eq__'] = ([SelfType(), i32], b, set()) +i32.methods['__ne__'] = ([SelfType(), i32], b, set()) +i32.methods['__le__'] = ([SelfType(), i32], b, set()) +i32.methods['__ge__'] = ([SelfType(), i32], b, set()) + +i64.methods['__init__'] = ([SelfType(), I], None, set()) +i64.methods['__add__'] = ([SelfType(), i64], i64, set()) +i64.methods['__sub__'] = ([SelfType(), i64], i64, set()) +i64.methods['__lt__'] = ([SelfType(), i64], b, set()) +i64.methods['__gt__'] = ([SelfType(), i64], b, set()) +i64.methods['__eq__'] = ([SelfType(), i64], b, set()) +i64.methods['__ne__'] = ([SelfType(), i64], b, set()) +i64.methods['__le__'] = ([SelfType(), i64], b, set()) +i64.methods['__ge__'] = ([SelfType(), i64], b, set()) + + +ctx = Context(variables, types) + +def test_stmt(stmt, sym_table = {}): + print(f'Testing {stmt} w.r.t. {stringify_subst(sym_table)}') + try: + tree = ast.parse(stmt) + a, b, _ = parse_stmts(ctx, sym_table, sym_table, tree.body) + print(stringify_subst(a)) + except CustomError as err: + print(f'error: {err.msg}') + +test_stmt('a, b = 1, 2', {}) +test_stmt('a, b = 1, [1, 2, 3]', {}) +test_stmt('a, b[c] = 1, 2', {'b': ListType(i32), 'c': i32}) +test_stmt('a, b[c] = 1, [1, 2]', {'b': ListType(i32), 'c': i32}) +test_stmt('b = [1, 2, 3]\nc = 1\na, b[c] = 1, 2\na = 2') + +test_stmt(""" +if True: + a = 1 +else: + a = 2 + b = 1 +c = a +""") + +test_stmt(""" +if True: + a = 1 +else: + a = 2 + b = 1 +c = a +d = b +""") + +test_stmt(""" +if True: + a = 1 +else: + a = 2 + b = 1 +c = a +b = [1, 2, 3] +""") + +