diff --git a/toy-impl/parse_expr.py b/toy-impl/parse_expr.py index f673bc1..cf885a4 100644 --- a/toy-impl/parse_expr.py +++ b/toy-impl/parse_expr.py @@ -1,4 +1,5 @@ import ast +import copy from helper import * from type_def import * from inference import * @@ -6,8 +7,7 @@ from inference import * # we assume having the following types: # bool, int32 with associated operations -# not handled now: slice, comprehensions, named expression, if expression, type -# guard +# not handled now: named expression, if expression, type guard def parse_expr(ctx: Context, sym_table: dict[str, Type], @@ -40,6 +40,8 @@ def parse_expr(ctx: Context, return parse_subscript(ctx, sym_table, body) if isinstance(body, ast.IfExp): return parse_if_expr(ctx, sym_table, body) + if isinstance(body, ast.ListComp): + return parse_list_comprehension(ctx, sym_table, body) raise CustomError(f'{body} is not yet supported') @@ -204,3 +206,35 @@ def parse_if_expr(ctx: Context, raise CustomError(f'divergent type for if expression: {ty1} != {ty2}') return ty1 +def parse_binding(name, ty): + if isinstance(name, ast.Name): + return {name.id: ty} + elif isinstance(name, ast.Tuple): + if not isinstance(ty, TupleType): + raise CustomError(f'cannot pattern match over {ty}') + if len(name.elts) != len(ty.params): + raise CustomError(f'pattern matching length mismatch') + result = {} + for x, y in zip(name.elts, ty.params): + result |= parse_name(x, y) + return result + else: + raise CustomError(f'binding to {name} is not supported') + +def parse_list_comprehension(ctx: Context, + sym_table: dict[str, Type], + node): + if len(node.generators) != 1: + raise CustomError('list comprehension with more than 1 for loop is not supported') + if node.generators[0].is_async: + raise CustomError('async list comprehension is not supported') + ty = parse_expr(ctx, sym_table, node.generators[0].iter) + if not isinstance(ty, ListType): + raise CustomError(f'unable to iterate over {ty}') + sym_table2 = sym_table | parse_binding(node.generators[0].target, ty.params[0]) + b = ctx.types['bool'] + for c in node.generators[0].ifs: + if parse_expr(ctx, sym_table2, c) != b: + raise CustomError(f'condition should be of boolean type') + return ListType(parse_expr(ctx, sym_table2, node.elt)) + diff --git a/toy-impl/test_expr.py b/toy-impl/test_expr.py index 5794a8e..14155c5 100644 --- a/toy-impl/test_expr.py +++ b/toy-impl/test_expr.py @@ -79,6 +79,9 @@ test_expr('a == a and 1 == 2', {'a': I}) test_expr('1 if a == b else 0', {'a': I, 'b': I}) test_expr('a if a == b else 1', {'a': I, 'b': I}) test_expr('a if a == b else b', {'a': I, 'b': I}) +test_expr('[x for x in [1, 2, 3]]', {}) +test_expr('[1 for x in [1, 2, 3] if x > 2]', {}) +test_expr('[a + a for x in [1, 2, 3] if x > 2]', {'a': I}) test_classes = """ class Foo: