implemented basic statements
This commit is contained in:
parent
eb2ddfc617
commit
69cf20cb91
|
@ -19,7 +19,7 @@ def parse_expr(ctx: Context,
|
||||||
if isinstance(body, ast.Constant):
|
if isinstance(body, ast.Constant):
|
||||||
return parse_constant(ctx, sym_table, body)
|
return parse_constant(ctx, sym_table, body)
|
||||||
if isinstance(body, ast.UnaryOp):
|
if isinstance(body, ast.UnaryOp):
|
||||||
return parse_unary_op(ctx, sym_table, body)
|
return parse_unary_ops(ctx, sym_table, body)
|
||||||
if isinstance(body, ast.BinOp):
|
if isinstance(body, ast.BinOp):
|
||||||
return parse_bin_ops(ctx, sym_table, body)
|
return parse_bin_ops(ctx, sym_table, body)
|
||||||
if isinstance(body, ast.Name):
|
if isinstance(body, ast.Name):
|
||||||
|
@ -131,7 +131,7 @@ def parse_bin_ops(ctx: Context,
|
||||||
def parse_unary_ops(ctx: Context,
|
def parse_unary_ops(ctx: Context,
|
||||||
sym_table: dict[str, Type],
|
sym_table: dict[str, Type],
|
||||||
node):
|
node):
|
||||||
t = parse_expr(node.operand)
|
t = parse_expr(ctx, sym_table, node.operand)
|
||||||
if isinstance(node.op, ast.Not):
|
if isinstance(node.op, ast.Not):
|
||||||
b = ctx.types['bool']
|
b = ctx.types['bool']
|
||||||
if t != b:
|
if t != b:
|
||||||
|
@ -206,7 +206,7 @@ def parse_if_expr(ctx: Context,
|
||||||
raise CustomError(f'divergent type for if expression: {ty1} != {ty2}')
|
raise CustomError(f'divergent type for if expression: {ty1} != {ty2}')
|
||||||
return ty1
|
return ty1
|
||||||
|
|
||||||
def parse_binding(name, ty):
|
def parse_simple_binding(name, ty):
|
||||||
if isinstance(name, ast.Name):
|
if isinstance(name, ast.Name):
|
||||||
if name.id == '_':
|
if name.id == '_':
|
||||||
return {}
|
return {}
|
||||||
|
@ -218,9 +218,9 @@ def parse_binding(name, ty):
|
||||||
raise CustomError(f'pattern matching length mismatch')
|
raise CustomError(f'pattern matching length mismatch')
|
||||||
result = {}
|
result = {}
|
||||||
for x, y in zip(name.elts, ty.params):
|
for x, y in zip(name.elts, ty.params):
|
||||||
binding = parse_binding(x, y)
|
binding = parse_simple_binding(x, y)
|
||||||
expected = len(result) + len(binding)
|
expected = len(result) + len(binding)
|
||||||
result |= parse_binding(x, y)
|
result |= parse_simple_binding(x, y)
|
||||||
if len(result) != expected:
|
if len(result) != expected:
|
||||||
raise CustomError('variable name clash')
|
raise CustomError('variable name clash')
|
||||||
return result
|
return result
|
||||||
|
@ -237,7 +237,7 @@ def parse_list_comprehension(ctx: Context,
|
||||||
ty = parse_expr(ctx, sym_table, node.generators[0].iter)
|
ty = parse_expr(ctx, sym_table, node.generators[0].iter)
|
||||||
if not isinstance(ty, ListType):
|
if not isinstance(ty, ListType):
|
||||||
raise CustomError(f'unable to iterate over {ty}')
|
raise CustomError(f'unable to iterate over {ty}')
|
||||||
sym_table2 = sym_table | parse_binding(node.generators[0].target, ty.params[0])
|
sym_table2 = sym_table | parse_simple_binding(node.generators[0].target, ty.params[0])
|
||||||
b = ctx.types['bool']
|
b = ctx.types['bool']
|
||||||
for c in node.generators[0].ifs:
|
for c in node.generators[0].ifs:
|
||||||
if parse_expr(ctx, sym_table2, c) != b:
|
if parse_expr(ctx, sym_table2, c) != b:
|
||||||
|
|
|
@ -3,19 +3,28 @@ import copy
|
||||||
from helper import *
|
from helper import *
|
||||||
from type_def import *
|
from type_def import *
|
||||||
from inference import *
|
from inference import *
|
||||||
from parse_expr import parse_expr
|
from parse_expr import parse_expr, parse_simple_binding
|
||||||
|
|
||||||
def parse_stmts(ctx: Context,
|
def parse_stmts(ctx: Context,
|
||||||
sym_table: dict[str, Type],
|
sym_table: dict[str, Type],
|
||||||
used_sym_table: dict[str, Type],
|
used_sym_table: dict[str, Type],
|
||||||
|
return_ty: Type,
|
||||||
nodes):
|
nodes):
|
||||||
sym_table2 = copy.copy(sym_table)
|
sym_table2 = copy.copy(sym_table)
|
||||||
used_sym_table2 = copy.copy(used_sym_table)
|
used_sym_table2 = copy.copy(used_sym_table)
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
if isinstance(node, ast.Assign):
|
if isinstance(node, ast.Assign):
|
||||||
a, b, returned = parse_assign(ctx, sym_table2, used_sym_table2, node)
|
a, b, returned = parse_assign(ctx, sym_table2, used_sym_table2, return_ty, node)
|
||||||
elif isinstance(node, ast.If):
|
elif isinstance(node, ast.If):
|
||||||
a, b, returned = parse_if_stmt(ctx, sym_table2, used_sym_table2, node)
|
a, b, returned = parse_if_stmt(ctx, sym_table2, used_sym_table2, return_ty, node)
|
||||||
|
elif isinstance(node, ast.While):
|
||||||
|
a, b, returned = parse_while_stmt(ctx, sym_table2, used_sym_table2, return_ty, node)
|
||||||
|
elif isinstance(node, ast.For):
|
||||||
|
a, b, returned = parse_for_stmt(ctx, sym_table2, used_sym_table2, return_ty, node)
|
||||||
|
elif isinstance(node, ast.Return):
|
||||||
|
a, b, returned = parse_return_stmt(ctx, sym_table2, used_sym_table2, return_ty, node)
|
||||||
|
elif isinstance(node, ast.Break) or isinstance(node, ast.Continue):
|
||||||
|
continue
|
||||||
else:
|
else:
|
||||||
raise CustomError(f'{node} is not supported yet')
|
raise CustomError(f'{node} is not supported yet')
|
||||||
sym_table2 |= a
|
sym_table2 |= a
|
||||||
|
@ -50,11 +59,11 @@ def get_target_type(ctx: Context,
|
||||||
else:
|
else:
|
||||||
raise CustomError(f'assignment to {target} is not supported')
|
raise CustomError(f'assignment to {target} is not supported')
|
||||||
|
|
||||||
def parse_binding(ctx: Context,
|
def parse_stmt_binding(ctx: Context,
|
||||||
sym_table: dict[str, Type],
|
sym_table: dict[str, Type],
|
||||||
used_sym_table: dict[str, Type],
|
used_sym_table: dict[str, Type],
|
||||||
target,
|
target,
|
||||||
ty):
|
ty):
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
if target.id in used_sym_table:
|
if target.id in used_sym_table:
|
||||||
if used_sym_table[target.id] != ty:
|
if used_sym_table[target.id] != ty:
|
||||||
|
@ -69,7 +78,7 @@ def parse_binding(ctx: Context,
|
||||||
raise CustomError(f'pattern matching length mismatch')
|
raise CustomError(f'pattern matching length mismatch')
|
||||||
result = {}
|
result = {}
|
||||||
for x, y in zip(target.elts, ty.params):
|
for x, y in zip(target.elts, ty.params):
|
||||||
new = parse_binding(ctx, sym_table, used_sym_table, x, y)
|
new = parse_stmt_binding(ctx, sym_table, used_sym_table, x, y)
|
||||||
old_len = len(result)
|
old_len = len(result)
|
||||||
result |= new
|
result |= new
|
||||||
used_sym_table |= new
|
used_sym_table |= new
|
||||||
|
@ -85,6 +94,7 @@ def parse_binding(ctx: Context,
|
||||||
def parse_assign(ctx: Context,
|
def parse_assign(ctx: Context,
|
||||||
sym_table: dict[str, Type],
|
sym_table: dict[str, Type],
|
||||||
used_sym_table: dict[str, Type],
|
used_sym_table: dict[str, Type],
|
||||||
|
return_ty: Type,
|
||||||
node):
|
node):
|
||||||
# permitted assignment targets:
|
# permitted assignment targets:
|
||||||
# variables, class fields, list elements
|
# variables, class fields, list elements
|
||||||
|
@ -93,19 +103,69 @@ def parse_assign(ctx: Context,
|
||||||
ty = parse_expr(ctx, sym_table, node.value)
|
ty = parse_expr(ctx, sym_table, node.value)
|
||||||
results = {}
|
results = {}
|
||||||
for target in node.targets:
|
for target in node.targets:
|
||||||
results |= parse_binding(ctx, sym_table, used_sym_table, target, ty)
|
results |= parse_stmt_binding(ctx, sym_table, used_sym_table, target, ty)
|
||||||
return results, results, False
|
return results, results, False
|
||||||
|
|
||||||
def parse_if_stmt(ctx: Context,
|
def parse_if_stmt(ctx: Context,
|
||||||
sym_table: dict[str, Type],
|
sym_table: dict[str, Type],
|
||||||
used_sym_table: dict[str, Type],
|
used_sym_table: dict[str, Type],
|
||||||
|
return_ty: Type,
|
||||||
node):
|
node):
|
||||||
test = parse_expr(ctx, sym_table, node.test)
|
test = parse_expr(ctx, sym_table, node.test)
|
||||||
if test != ctx.types['bool']:
|
if test != ctx.types['bool']:
|
||||||
raise CustomError(f'condition must be bool instead of {test}')
|
raise CustomError(f'condition must be bool instead of {test}')
|
||||||
a, b, r = parse_stmts(ctx, sym_table, used_sym_table, node.body)
|
a, b, r = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.body)
|
||||||
used_sym_table |= b
|
used_sym_table |= b
|
||||||
a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table, node.orelse)
|
a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.orelse)
|
||||||
defined = {k: a[k] for k in a if k in a1}
|
defined = {k: a[k] for k in a if k in a1}
|
||||||
return defined, b | b1, r and r1
|
return defined, b | b1, r and r1
|
||||||
|
|
||||||
|
def parse_for_stmt(ctx: Context,
|
||||||
|
sym_table: dict[str, Type],
|
||||||
|
used_sym_table: dict[str, Type],
|
||||||
|
return_ty: Type,
|
||||||
|
node):
|
||||||
|
ty = parse_expr(ctx, sym_table, node.iter)
|
||||||
|
if not isinstance(ty, ListType):
|
||||||
|
raise CustomError('only iteration over list is supported')
|
||||||
|
binding = parse_simple_binding(node.target, ty.params[0])
|
||||||
|
for key, value in binding.items():
|
||||||
|
if key in used_sym_table:
|
||||||
|
if value != used_sym_table[key]:
|
||||||
|
raise CustomError('inconsistent type')
|
||||||
|
a, b, r = parse_stmts(ctx, sym_table | binding, used_sym_table | binding,
|
||||||
|
return_ty, node.body)
|
||||||
|
a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table | b,
|
||||||
|
return_ty, node.orelse)
|
||||||
|
defined = {k: a[k] for k in a if k in a1}
|
||||||
|
return defined, b | b1, r and r1
|
||||||
|
|
||||||
|
def parse_while_stmt(ctx: Context,
|
||||||
|
sym_table: dict[str, Type],
|
||||||
|
used_sym_table: dict[str, Type],
|
||||||
|
return_ty: Type,
|
||||||
|
node):
|
||||||
|
ty = parse_expr(ctx, sym_table, node.test)
|
||||||
|
if ty != ctx.types['bool']:
|
||||||
|
raise CustomError('condition must be bool')
|
||||||
|
# more sophisticated return analysis is needed...
|
||||||
|
a, b, r = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.body)
|
||||||
|
a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table | b,
|
||||||
|
return_ty, node.orelse)
|
||||||
|
defined = {k: a[k] for k in a if k in a1}
|
||||||
|
return defined, b | b1, r and r1
|
||||||
|
|
||||||
|
def parse_return_stmt(ctx: Context,
|
||||||
|
sym_table: dict[str, Type],
|
||||||
|
used_sym_table: dict[str, Type],
|
||||||
|
return_ty: Type,
|
||||||
|
node):
|
||||||
|
if return_ty is None:
|
||||||
|
if node.value is not None:
|
||||||
|
raise CustomError('no return value is allowed')
|
||||||
|
return {}, {}, True
|
||||||
|
ty = parse_expr(ctx, sym_table, node.value)
|
||||||
|
if ty != return_ty:
|
||||||
|
raise CustomError(f'expected returning {return_ty} but got {ty}')
|
||||||
|
return {}, {}, True
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,7 @@ A = variables['A']
|
||||||
i32.methods['__init__'] = ([SelfType(), I], None, set())
|
i32.methods['__init__'] = ([SelfType(), I], None, set())
|
||||||
i32.methods['__add__'] = ([SelfType(), i32], i32, set())
|
i32.methods['__add__'] = ([SelfType(), i32], i32, set())
|
||||||
i32.methods['__sub__'] = ([SelfType(), i32], i32, set())
|
i32.methods['__sub__'] = ([SelfType(), i32], i32, set())
|
||||||
|
i32.methods['__neg__'] = ([SelfType()], i32, set())
|
||||||
i32.methods['__lt__'] = ([SelfType(), i32], b, set())
|
i32.methods['__lt__'] = ([SelfType(), i32], b, set())
|
||||||
i32.methods['__gt__'] = ([SelfType(), i32], b, set())
|
i32.methods['__gt__'] = ([SelfType(), i32], b, set())
|
||||||
i32.methods['__eq__'] = ([SelfType(), i32], b, set())
|
i32.methods['__eq__'] = ([SelfType(), i32], b, set())
|
||||||
|
@ -41,6 +42,7 @@ i32.methods['__ge__'] = ([SelfType(), i32], b, set())
|
||||||
i64.methods['__init__'] = ([SelfType(), I], None, set())
|
i64.methods['__init__'] = ([SelfType(), I], None, set())
|
||||||
i64.methods['__add__'] = ([SelfType(), i64], i64, set())
|
i64.methods['__add__'] = ([SelfType(), i64], i64, set())
|
||||||
i64.methods['__sub__'] = ([SelfType(), i64], i64, set())
|
i64.methods['__sub__'] = ([SelfType(), i64], i64, set())
|
||||||
|
i64.methods['__neg__'] = ([SelfType()], i64, set())
|
||||||
i64.methods['__lt__'] = ([SelfType(), i64], b, set())
|
i64.methods['__lt__'] = ([SelfType(), i64], b, set())
|
||||||
i64.methods['__gt__'] = ([SelfType(), i64], b, set())
|
i64.methods['__gt__'] = ([SelfType(), i64], b, set())
|
||||||
i64.methods['__eq__'] = ([SelfType(), i64], b, set())
|
i64.methods['__eq__'] = ([SelfType(), i64], b, set())
|
||||||
|
@ -51,14 +53,17 @@ i64.methods['__ge__'] = ([SelfType(), i64], b, set())
|
||||||
|
|
||||||
ctx = Context(variables, types)
|
ctx = Context(variables, types)
|
||||||
|
|
||||||
def test_stmt(stmt, sym_table = {}):
|
def test_stmt(stmt, sym_table = {}, return_ty = None):
|
||||||
print(f'Testing {stmt} w.r.t. {stringify_subst(sym_table)}')
|
print(f'Testing:\n{stmt}\n\nw.r.t. {stringify_subst(sym_table)}')
|
||||||
try:
|
try:
|
||||||
tree = ast.parse(stmt)
|
tree = ast.parse(stmt)
|
||||||
a, b, _ = parse_stmts(ctx, sym_table, sym_table, tree.body)
|
a, b, returned = parse_stmts(ctx, sym_table, sym_table, return_ty, tree.body)
|
||||||
print(stringify_subst(a))
|
print(f'defined variables: {stringify_subst(a)}')
|
||||||
|
print(f'returned: {returned}')
|
||||||
|
print('---')
|
||||||
except CustomError as err:
|
except CustomError as err:
|
||||||
print(f'error: {err.msg}')
|
print(f'error: {err.msg}')
|
||||||
|
print('---')
|
||||||
|
|
||||||
test_stmt('a, b = 1, 2', {})
|
test_stmt('a, b = 1, 2', {})
|
||||||
test_stmt('a, b = 1, [1, 2, 3]', {})
|
test_stmt('a, b = 1, [1, 2, 3]', {})
|
||||||
|
@ -95,4 +100,35 @@ c = a
|
||||||
b = [1, 2, 3]
|
b = [1, 2, 3]
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
test_stmt("""
|
||||||
|
c = 0
|
||||||
|
for i in [1, 2, 3]:
|
||||||
|
c = c + i
|
||||||
|
""")
|
||||||
|
|
||||||
|
test_stmt("""
|
||||||
|
c = 0
|
||||||
|
for i in [1, 2, 3]:
|
||||||
|
c = c + i
|
||||||
|
if c > 0:
|
||||||
|
return c
|
||||||
|
else:
|
||||||
|
return -c
|
||||||
|
""", {}, i32)
|
||||||
|
|
||||||
|
test_stmt("""
|
||||||
|
c = 0
|
||||||
|
for i in [1, 2, 3]:
|
||||||
|
c = c + i
|
||||||
|
if c > 0:
|
||||||
|
return c
|
||||||
|
""", {}, i32)
|
||||||
|
|
||||||
|
test_stmt("""
|
||||||
|
c = i = 0
|
||||||
|
for i in [True, True, False]:
|
||||||
|
c = c + 1
|
||||||
|
if c > 0:
|
||||||
|
return c
|
||||||
|
""", {}, i32)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue