list comprehension

This commit is contained in:
pca006132 2020-12-21 11:50:20 +08:00 committed by pca006132
parent 5d679d88b5
commit dd02c795c7
2 changed files with 39 additions and 2 deletions

View File

@ -1,4 +1,5 @@
import ast import ast
import copy
from helper import * from helper import *
from type_def import * from type_def import *
from inference import * from inference import *
@ -6,8 +7,7 @@ from inference import *
# we assume having the following types: # we assume having the following types:
# bool, int32 with associated operations # bool, int32 with associated operations
# not handled now: slice, comprehensions, named expression, if expression, type # not handled now: named expression, if expression, type guard
# guard
def parse_expr(ctx: Context, def parse_expr(ctx: Context,
sym_table: dict[str, Type], sym_table: dict[str, Type],
@ -40,6 +40,8 @@ def parse_expr(ctx: Context,
return parse_subscript(ctx, sym_table, body) return parse_subscript(ctx, sym_table, body)
if isinstance(body, ast.IfExp): if isinstance(body, ast.IfExp):
return parse_if_expr(ctx, sym_table, body) return parse_if_expr(ctx, sym_table, body)
if isinstance(body, ast.ListComp):
return parse_list_comprehension(ctx, sym_table, body)
raise CustomError(f'{body} is not yet supported') raise CustomError(f'{body} is not yet supported')
@ -204,3 +206,35 @@ def parse_if_expr(ctx: Context,
raise CustomError(f'divergent type for if expression: {ty1} != {ty2}') raise CustomError(f'divergent type for if expression: {ty1} != {ty2}')
return ty1 return ty1
def parse_binding(name, ty):
if isinstance(name, ast.Name):
return {name.id: ty}
elif isinstance(name, ast.Tuple):
if not isinstance(ty, TupleType):
raise CustomError(f'cannot pattern match over {ty}')
if len(name.elts) != len(ty.params):
raise CustomError(f'pattern matching length mismatch')
result = {}
for x, y in zip(name.elts, ty.params):
result |= parse_name(x, y)
return result
else:
raise CustomError(f'binding to {name} is not supported')
def parse_list_comprehension(ctx: Context,
sym_table: dict[str, Type],
node):
if len(node.generators) != 1:
raise CustomError('list comprehension with more than 1 for loop is not supported')
if node.generators[0].is_async:
raise CustomError('async list comprehension is not supported')
ty = parse_expr(ctx, sym_table, node.generators[0].iter)
if not isinstance(ty, ListType):
raise CustomError(f'unable to iterate over {ty}')
sym_table2 = sym_table | parse_binding(node.generators[0].target, ty.params[0])
b = ctx.types['bool']
for c in node.generators[0].ifs:
if parse_expr(ctx, sym_table2, c) != b:
raise CustomError(f'condition should be of boolean type')
return ListType(parse_expr(ctx, sym_table2, node.elt))

View File

@ -79,6 +79,9 @@ test_expr('a == a and 1 == 2', {'a': I})
test_expr('1 if a == b else 0', {'a': I, 'b': I}) test_expr('1 if a == b else 0', {'a': I, 'b': I})
test_expr('a if a == b else 1', {'a': I, 'b': I}) test_expr('a if a == b else 1', {'a': I, 'b': I})
test_expr('a if a == b else b', {'a': I, 'b': I}) test_expr('a if a == b else b', {'a': I, 'b': I})
test_expr('[x for x in [1, 2, 3]]', {})
test_expr('[1 for x in [1, 2, 3] if x > 2]', {})
test_expr('[a + a for x in [1, 2, 3] if x > 2]', {'a': I})
test_classes = """ test_classes = """
class Foo: class Foo: