list comprehension
This commit is contained in:
parent
5d679d88b5
commit
dd02c795c7
@ -1,4 +1,5 @@
|
||||
import ast
|
||||
import copy
|
||||
from helper import *
|
||||
from type_def import *
|
||||
from inference import *
|
||||
@ -6,8 +7,7 @@ from inference import *
|
||||
# we assume having the following types:
|
||||
# bool, int32 with associated operations
|
||||
|
||||
# not handled now: slice, comprehensions, named expression, if expression, type
|
||||
# guard
|
||||
# not handled now: named expression, if expression, type guard
|
||||
|
||||
def parse_expr(ctx: Context,
|
||||
sym_table: dict[str, Type],
|
||||
@ -40,6 +40,8 @@ def parse_expr(ctx: Context,
|
||||
return parse_subscript(ctx, sym_table, body)
|
||||
if isinstance(body, ast.IfExp):
|
||||
return parse_if_expr(ctx, sym_table, body)
|
||||
if isinstance(body, ast.ListComp):
|
||||
return parse_list_comprehension(ctx, sym_table, body)
|
||||
raise CustomError(f'{body} is not yet supported')
|
||||
|
||||
|
||||
@ -204,3 +206,35 @@ def parse_if_expr(ctx: Context,
|
||||
raise CustomError(f'divergent type for if expression: {ty1} != {ty2}')
|
||||
return ty1
|
||||
|
||||
def parse_binding(name, ty):
|
||||
if isinstance(name, ast.Name):
|
||||
return {name.id: ty}
|
||||
elif isinstance(name, ast.Tuple):
|
||||
if not isinstance(ty, TupleType):
|
||||
raise CustomError(f'cannot pattern match over {ty}')
|
||||
if len(name.elts) != len(ty.params):
|
||||
raise CustomError(f'pattern matching length mismatch')
|
||||
result = {}
|
||||
for x, y in zip(name.elts, ty.params):
|
||||
result |= parse_name(x, y)
|
||||
return result
|
||||
else:
|
||||
raise CustomError(f'binding to {name} is not supported')
|
||||
|
||||
def parse_list_comprehension(ctx: Context,
|
||||
sym_table: dict[str, Type],
|
||||
node):
|
||||
if len(node.generators) != 1:
|
||||
raise CustomError('list comprehension with more than 1 for loop is not supported')
|
||||
if node.generators[0].is_async:
|
||||
raise CustomError('async list comprehension is not supported')
|
||||
ty = parse_expr(ctx, sym_table, node.generators[0].iter)
|
||||
if not isinstance(ty, ListType):
|
||||
raise CustomError(f'unable to iterate over {ty}')
|
||||
sym_table2 = sym_table | parse_binding(node.generators[0].target, ty.params[0])
|
||||
b = ctx.types['bool']
|
||||
for c in node.generators[0].ifs:
|
||||
if parse_expr(ctx, sym_table2, c) != b:
|
||||
raise CustomError(f'condition should be of boolean type')
|
||||
return ListType(parse_expr(ctx, sym_table2, node.elt))
|
||||
|
||||
|
@ -79,6 +79,9 @@ test_expr('a == a and 1 == 2', {'a': I})
|
||||
test_expr('1 if a == b else 0', {'a': I, 'b': I})
|
||||
test_expr('a if a == b else 1', {'a': I, 'b': I})
|
||||
test_expr('a if a == b else b', {'a': I, 'b': I})
|
||||
test_expr('[x for x in [1, 2, 3]]', {})
|
||||
test_expr('[1 for x in [1, 2, 3] if x > 2]', {})
|
||||
test_expr('[a + a for x in [1, 2, 3] if x > 2]', {'a': I})
|
||||
|
||||
test_classes = """
|
||||
class Foo:
|
||||
|
Loading…
Reference in New Issue
Block a user