nac3-spec/toy-impl/parse_stmt.py

226 lines
9.6 KiB
Python
Raw Permalink Normal View History

2020-12-21 16:58:40 +08:00
import ast
import copy
from helper import *
from type_def import *
from inference import *
2020-12-23 11:22:17 +08:00
from top_level import parse_type
2020-12-22 15:14:34 +08:00
from parse_expr import parse_expr, parse_simple_binding
2020-12-21 16:58:40 +08:00
def parse_stmts(ctx: Context,
                sym_table: dict[str, Type],
                used_sym_table: dict[str, Type],
                return_ty: Type,
                nodes):
    """Type-check a sequence of statements.

    Both tables are shallow-copied up front; only the copies are updated
    and handed to the per-statement handlers, so this function does not
    modify the dicts it was given.

    Args:
        ctx: context forwarded to the expression/statement handlers.
        sym_table: name -> type mapping visible at the start of the block.
        used_sym_table: name -> type mapping consulted by the binding
            helpers to reject rebinding a name with a different type.
        return_ty: declared return type forwarded to the handlers.
        nodes: iterable of ``ast`` statement nodes.

    Returns:
        ``(sym_table, used_sym_table, returned)``: the updated copies and
        ``True`` if a statement reported that control returns; statements
        after that point are not examined.

    Raises:
        CustomError: for statement kinds that have no handler.
    """
    defined = copy.copy(sym_table)
    used = copy.copy(used_sym_table)

    for stmt in nodes:
        # Statements that never introduce bindings.
        if isinstance(stmt, (ast.Break, ast.Continue)):
            continue
        if isinstance(stmt, ast.Expr):
            parse_expr(ctx, defined, stmt.value)
            continue

        # Pick the handler for statements that may bind names or return.
        if isinstance(stmt, ast.Assign):
            handler = parse_assign
        elif isinstance(stmt, ast.If):
            handler = parse_if_stmt
        elif isinstance(stmt, ast.While):
            handler = parse_while_stmt
        elif isinstance(stmt, ast.For):
            handler = parse_for_stmt
        elif isinstance(stmt, ast.Return):
            handler = parse_return_stmt
        else:
            raise CustomError(f'{stmt} is not supported yet', stmt)

        new_defined, new_used, returned = handler(
            ctx, defined, used, return_ty, stmt)
        defined |= new_defined
        used |= new_used
        if returned:
            return defined, used, True

    return defined, used, False
def get_target_type(ctx: Context,
                    sym_table: dict[str, Type],
                    used_sym_table: dict[str, Type],
                    target):
    """Return the type of an already-existing assignment target.

    Supported targets are plain names, attribute accesses and list
    element accesses; the resolved type is also stored on the node as
    ``target.type``.

    Args:
        ctx: context used for expression parsing and for ``ctx.types``.
        sym_table: name -> type mapping used to resolve names.
        used_sym_table: only forwarded to the recursive calls.
        target: the ``ast`` node being assigned to.

    Returns:
        The type of the location denoted by ``target``.

    Raises:
        CustomError: unknown name, missing field, indexing a non-list,
            slice assignment, non-int32 index, or unsupported target kind.
    """
    if isinstance(target, ast.Name):
        if target.id not in sym_table:
            raise CustomError(f'unbounded {target.id}', target)
        resolved = sym_table[target.id]
    elif isinstance(target, ast.Attribute):
        # Resolve the object first, then look the field up on its type.
        t = get_target_type(ctx, sym_table, used_sym_table, target.value)
        if target.attr not in t.fields:
            raise CustomError(f'{t} has no field {target.attr}', target)
        resolved = t.fields[target.attr]
    elif isinstance(target, ast.Subscript):
        t = get_target_type(ctx, sym_table, used_sym_table, target.value)
        if not isinstance(t, ListType):
            raise CustomError(f'cannot index through type {t}', target)
        if isinstance(target.slice, ast.Slice):
            raise CustomError(f'assignment to slice is not supported', target)
        index_ty = parse_expr(ctx, sym_table, target.slice)
        if index_ty != ctx.types['int32']:
            raise CustomError(f'index must be int32', target.slice)
        # The element type is the list's first type parameter.
        resolved = t.params[0]
    else:
        raise CustomError(f'assignment to {target} is not supported', target)

    target.type = resolved
    return resolved
2020-12-21 16:58:40 +08:00
2020-12-22 15:14:34 +08:00
def parse_stmt_binding(ctx: Context,
                       sym_table: dict[str, Type],
                       used_sym_table: dict[str, Type],
                       target,
                       ty):
    """Bind an assignment target to the type ``ty``.

    * A name yields ``{name: ty}`` (``_`` yields nothing) after checking
      that any earlier entry in ``used_sym_table`` has the same type.
    * A tuple pattern is matched element-wise against a ``TupleType``;
      each element's bindings are merged into ``used_sym_table`` in place.
    * Any other target must already exist with exactly the type ``ty``
      and introduces no binding.

    Returns:
        Mapping of newly bound names to their types.

    Raises:
        CustomError: on a type inconsistency, a pattern/type shape
            mismatch, or a name bound twice in one pattern.
    """
    if isinstance(target, ast.Name):
        previously_other = (target.id in used_sym_table
                            and used_sym_table[target.id] != ty)
        if previously_other:
            raise CustomError(f'inconsistent type, {target.id} was '
                              f'defined to be {used_sym_table[target.id]}, '
                              f'but is now {ty}', target)
        if target.id == '_':
            # Placeholder name: type-checked above but never bound.
            return {}
        target.type = ty
        return {target.id: ty}

    if isinstance(target, ast.Tuple):
        if not isinstance(ty, TupleType):
            raise CustomError(f'cannot pattern match over {ty}', target)
        if len(target.elts) != len(ty.params):
            raise CustomError(f'pattern matching length mismatch', target)
        bound = {}
        for elt, elt_ty in zip(target.elts, ty.params):
            fresh = parse_stmt_binding(ctx, sym_table, used_sym_table,
                                       elt, elt_ty)
            # A name shared with an earlier element is a clash; record it
            # before merging, but merge first so the tables end up in the
            # same state as they would without the early check.
            clash = not bound.keys().isdisjoint(fresh)
            bound |= fresh
            used_sym_table |= fresh
            if clash:
                raise CustomError(f'variable name clash', target)
        return bound

    t = get_target_type(ctx, sym_table, used_sym_table, target)
    if ty != t:
        raise CustomError(f'type mismatch', target)
    return {}
def parse_assign(ctx: Context,
                 sym_table: dict[str, Type],
                 used_sym_table: dict[str, Type],
                 return_ty: Type,
                 node):
    """Type-check an assignment statement.

    The right-hand side is typed once and every target of the statement
    is bound to that type.  ``return_ty`` is accepted for signature
    parity with the other statement handlers and is not used.

    Returns:
        ``(bindings, bindings, False)``: the same dict of new bindings is
        reported as both the defined and the used symbols; an assignment
        never returns.
    """
    # permitted assignment targets:
    # variables, class fields, list elements
    # function evaluation is only allowed within list index
    # may relax later after when we settle on lifetime handling
    value_ty = parse_expr(ctx, sym_table, node.value)
    bindings = {}
    for target in node.targets:
        bindings.update(
            parse_stmt_binding(ctx, sym_table, used_sym_table,
                               target, value_ty))
    return bindings, bindings, False
def parse_if_stmt(ctx: Context,
                  sym_table: dict[str, Type],
                  used_sym_table: dict[str, Type],
                  return_ty: Type,
                  node):
    """Type-check an ``if`` statement, with support for type guards.

    A test of the form ``type(x) == T`` or ``type(x) != T`` where ``x``
    is a type variable is a *type guard*: while one branch is checked the
    variable's constraints are narrowed to ``[T]``, and while the other
    is checked they are the remaining constraints.  For ``==`` the body
    gets ``[T]``; for ``!=`` the branches are swapped.  The variable's
    original constraints are always restored afterwards, also when
    checking a branch raises.

    Any other test must have type ``ctx.types['bool']``.

    Note: ``used_sym_table`` is updated in place with the symbols used by
    the body before the ``else`` branch is checked.

    Returns:
        ``(defined, used, returned)`` where ``defined`` holds the names
        present in the results of both branches, ``used`` is the union of
        both branches' used symbols, and ``returned`` is true only if
        both branches return.

    Raises:
        CustomError: non-bool condition, a type guard on something that
            is not a type variable, or a guard type that is not among the
            variable's constraints.
    """
    cond = node.test
    is_type_guard = (
        isinstance(cond, ast.Compare)
        and isinstance(cond.left, ast.Call)
        and isinstance(cond.left.func, ast.Name)
        and cond.left.func.id == 'type'
        and len(cond.left.args) == 1
        and len(cond.ops) == 1
        and isinstance(cond.ops[0], (ast.Eq, ast.NotEq))
    )

    if is_type_guard:
        t = parse_expr(ctx, sym_table, cond.left.args[0])
        if not isinstance(t, TypeVariable):
            raise CustomError(
                'type guard only support type variables',
                cond)
        t1, _ = parse_type(ctx, cond.comparators[0])
        # An empty constraint list skips the membership check.
        if len(t.constraints) > 0 and t1 not in t.constraints:
            raise CustomError(
                f'{t1} is not a possible instance of {t}',
                cond)

        saved_constraints = t.constraints
        body_constraints = [t1]
        orelse_constraints = [v for v in saved_constraints if v != t1]
        if isinstance(cond.ops[0], ast.NotEq):
            # `type(x) != T`: the body is the branch where x is NOT T.
            body_constraints, orelse_constraints = \
                orelse_constraints, body_constraints

        try:
            t.constraints = body_constraints
            a, b, r = parse_stmts(ctx, sym_table, used_sym_table,
                                  return_ty, node.body)
            used_sym_table |= b
            t.constraints = orelse_constraints
            a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table,
                                     return_ty, node.orelse)
        finally:
            # Undo the temporary narrowing no matter how we leave.
            t.constraints = saved_constraints
    else:
        test = parse_expr(ctx, sym_table, cond)
        if test != ctx.types['bool']:
            raise CustomError(f'condition must be bool instead of {test}', cond)
        a, b, r = parse_stmts(ctx, sym_table, used_sym_table,
                              return_ty, node.body)
        used_sym_table |= b
        a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table,
                                 return_ty, node.orelse)

    # Only names available whichever branch ran count as defined.
    defined = {k: a[k] for k in a if k in a1}
    return defined, b | b1, r and r1
2020-12-22 15:14:34 +08:00
def parse_for_stmt(ctx: Context,
                   sym_table: dict[str, Type],
                   used_sym_table: dict[str, Type],
                   return_ty: Type,
                   node):
    """Type-check a ``for`` loop over a list.

    The iterable must be a ``ListType`` (a type variable with exactly one
    constraint is replaced by that constraint first).  The loop target is
    bound to the list's element type and must agree with any type already
    recorded for the same name in ``used_sym_table``.  The body sees the
    loop bindings; the ``else`` branch does not.

    Returns:
        ``(defined, used, False)`` where ``defined`` holds the names
        present in the results of both the body and the ``else`` branch.
        The loop is never reported as returning.

    Raises:
        CustomError: non-list iterable or an inconsistent target type.
    """
    iter_ty = parse_expr(ctx, sym_table, node.iter)
    if isinstance(iter_ty, TypeVariable) and len(iter_ty.constraints) == 1:
        iter_ty = iter_ty.constraints[0]
    if not isinstance(iter_ty, ListType):
        raise CustomError('only iteration over list is supported', node.iter)

    loop_vars = parse_simple_binding(node.target, iter_ty.params[0])
    conflicting = any(
        name in used_sym_table and elem_ty != used_sym_table[name]
        for name, elem_ty in loop_vars.items()
    )
    if conflicting:
        raise CustomError('inconsistent type', node)

    # more sophisticated return analysis is needed...
    body_defined, body_used, _ = parse_stmts(
        ctx,
        sym_table | loop_vars,
        used_sym_table | loop_vars,
        return_ty,
        node.body)
    else_defined, else_used, _ = parse_stmts(
        ctx,
        sym_table,
        used_sym_table | body_used,
        return_ty,
        node.orelse)

    defined = {name: ty for name, ty in body_defined.items()
               if name in else_defined}
    return defined, body_used | else_used, False
2020-12-21 16:58:40 +08:00
2020-12-22 15:14:34 +08:00
def parse_while_stmt(ctx: Context,
                     sym_table: dict[str, Type],
                     used_sym_table: dict[str, Type],
                     return_ty: Type,
                     node):
    """Type-check a ``while`` loop.

    The condition must have type ``ctx.types['bool']``.  The body is
    checked first; the ``else`` branch is then checked against the
    original ``sym_table`` plus the symbols the body used.

    Returns:
        ``(defined, used, False)`` where ``defined`` holds the names
        present in the results of both the body and the ``else`` branch.
        The loop is never reported as returning.

    Raises:
        CustomError: if the condition is not bool.
    """
    cond_ty = parse_expr(ctx, sym_table, node.test)
    if cond_ty != ctx.types['bool']:
        raise CustomError('condition must be bool', node.test)

    # more sophisticated return analysis is needed...
    body_defined, body_used, _ = parse_stmts(
        ctx, sym_table, used_sym_table, return_ty, node.body)
    else_defined, else_used, _ = parse_stmts(
        ctx, sym_table, used_sym_table | body_used, return_ty, node.orelse)

    defined = {name: ty for name, ty in body_defined.items()
               if name in else_defined}
    return defined, body_used | else_used, False
2020-12-22 15:14:34 +08:00
def parse_return_stmt(ctx: Context,
                      sym_table: dict[str, Type],
                      used_sym_table: dict[str, Type],
                      return_ty: Type,
                      node):
    """Type-check a ``return`` statement against ``return_ty``.

    Accepted forms:

    * ``return_ty is None``: only a bare ``return`` is allowed.
    * ``return self`` when ``self`` is in scope and ``return_ty`` is a
      ``SelfType``.
    * a value whose type equals ``return_ty``, or equals the single
      constraint of ``return_ty`` when that is a type variable.

    A bare ``return`` where a value is expected is rejected up front
    instead of handing ``None`` to the expression parser.

    Returns:
        ``({}, {}, True)``: a return binds nothing and always returns.

    Raises:
        CustomError: if the returned value (or its absence) does not
            match ``return_ty``.
    """
    if return_ty is None:
        if node.value is not None:
            raise CustomError('no return value is allowed', node)
        return {}, {}, True

    if node.value is None:
        # `return` with no value in a function that must return something.
        raise CustomError(f'expected returning {return_ty} but got None', node)

    ty = parse_expr(ctx, sym_table, node.value)

    returns_self = (
        isinstance(node.value, ast.Name)
        and node.value.id == 'self'
        and 'self' in sym_table
        and isinstance(return_ty, SelfType)
    )
    if returns_self:
        return {}, {}, True

    if ty != return_ty:
        # A type variable narrowed to a single constraint stands for
        # exactly that type.
        narrowed_match = (
            isinstance(return_ty, TypeVariable)
            and len(return_ty.constraints) == 1
            and return_ty.constraints[0] == ty
        )
        if not narrowed_match:
            raise CustomError(f'expected returning {return_ty} but got {ty}', node)
    return {}, {}, True