175 lines
7.4 KiB
Python
175 lines
7.4 KiB
Python
import ast
|
|
import copy
|
|
from helper import *
|
|
from type_def import *
|
|
from inference import *
|
|
from parse_expr import parse_expr, parse_simple_binding
|
|
|
|
def parse_stmts(ctx: Context,
                sym_table: dict[str, Type],
                used_sym_table: dict[str, Type],
                return_ty: Type,
                nodes):
    """Type-check a sequence of statements in order.

    Works on shallow copies of both symbol tables, folding in the bindings
    each statement reports. ``break``/``continue`` are skipped; any other
    unsupported statement raises ``CustomError``.

    Returns:
        ``(defined, used, returned)`` where ``returned`` is True as soon as a
        statement reports that it always returns; the remaining statements
        are then not examined.
    """
    handlers = (
        (ast.Assign, parse_assign),
        (ast.If, parse_if_stmt),
        (ast.While, parse_while_stmt),
        (ast.For, parse_for_stmt),
        (ast.Return, parse_return_stmt),
    )
    defined = copy.copy(sym_table)
    used = copy.copy(used_sym_table)
    for stmt in nodes:
        if isinstance(stmt, (ast.Break, ast.Continue)):
            continue
        handler = next((fn for cls, fn in handlers if isinstance(stmt, cls)),
                       None)
        if handler is None:
            raise CustomError(f'{stmt} is not supported yet', stmt)
        new_defs, new_uses, returned = handler(ctx, defined, used,
                                               return_ty, stmt)
        defined |= new_defs
        used |= new_uses
        if returned:
            return defined, used, True
    return defined, used, False
|
|
|
|
def get_target_type(ctx: Context,
                    sym_table: dict[str, Type],
                    used_sym_table: dict[str, Type],
                    target):
    """Return the type of an already-existing assignment target.

    Supports names, list element subscripts (``xs[i]``) and field accesses
    (``obj.f``), recursing through the base expression of subscripts and
    attributes.

    Raises:
        CustomError: if the base name is unbound, a subscript is applied to
            a non-list, the index is a slice or not ``int32``, the base type
            has no such field, or the target form is not supported.
    """
    if isinstance(target, ast.Subscript):
        t = get_target_type(ctx, sym_table, used_sym_table, target.value)
        if not isinstance(t, ListType):
            raise CustomError(f'cannot index through type {t}', target)
        if isinstance(target.slice, ast.Slice):
            raise CustomError('assignment to slice is not supported', target)
        i = parse_expr(ctx, sym_table, target.slice)
        if i != ctx.types['int32']:
            raise CustomError('index must be int32', target.slice)
        # element type of the list
        return t.params[0]
    if isinstance(target, ast.Attribute):
        t = get_target_type(ctx, sym_table, used_sym_table, target.value)
        # A base type without a `fields` mapping (presumably any non-class
        # type; type_def is not shown here) must be reported as a CustomError
        # rather than letting an AttributeError escape.
        fields = getattr(t, 'fields', None)
        if fields is None or target.attr not in fields:
            raise CustomError(f'{t} has no field {target.attr}', target)
        return fields[target.attr]
    if isinstance(target, ast.Name):
        if target.id not in sym_table:
            raise CustomError(f'unbounded {target.id}', target)
        return sym_table[target.id]
    raise CustomError(f'assignment to {target} is not supported', target)
|
|
|
|
def parse_stmt_binding(ctx: Context,
                       sym_table: dict[str, Type],
                       used_sym_table: dict[str, Type],
                       target,
                       ty):
    """Bind an assignment target to the value type ``ty``.

    - A name binds a variable; its type must agree with any type already
      recorded for it in ``used_sym_table``. ``_`` discards the value and
      binds nothing.
    - A tuple pattern destructures a ``TupleType`` element-wise; each name
      may appear at most once in the pattern.
    - Any other target (subscript, attribute) must already exist with type
      ``ty`` (see ``get_target_type``) and binds nothing.

    NOTE: in the tuple case the new bindings are also merged into
    ``used_sym_table`` in place, so later elements see them.

    Returns:
        dict mapping each newly bound name to its type.

    Raises:
        CustomError: on a type mismatch, a length mismatch, a repeated name
            in a pattern, or an unsupported target.
    """
    if isinstance(target, ast.Name):
        # `_` binds nothing, so a type recorded for it earlier (if any) must
        # not constrain the value being discarded now.
        if target.id == '_':
            return {}
        if target.id in used_sym_table:
            if used_sym_table[target.id] != ty:
                raise CustomError(f'inconsistent type, {target.id} was '
                                  f'defined to be {used_sym_table[target.id]}, '
                                  f'but is now {ty}', target)
        return {target.id: ty}
    if isinstance(target, ast.Tuple):
        if not isinstance(ty, TupleType):
            raise CustomError(f'cannot pattern match over {ty}', target)
        if len(target.elts) != len(ty.params):
            raise CustomError('pattern matching length mismatch', target)
        result = {}
        for elt, elt_ty in zip(target.elts, ty.params):
            new = parse_stmt_binding(ctx, sym_table, used_sym_table, elt, elt_ty)
            old_len = len(result)
            result |= new
            used_sym_table |= new
            # any overlap means the same name occurs twice in the pattern
            if len(result) != old_len + len(new):
                raise CustomError('variable name clash', target)
        return result
    t = get_target_type(ctx, sym_table, used_sym_table, target)
    if ty != t:
        raise CustomError('type mismatch', target)
    return {}
|
|
|
|
def parse_assign(ctx: Context,
                 sym_table: dict[str, Type],
                 used_sym_table: dict[str, Type],
                 return_ty: Type,
                 node):
    """Type-check an assignment ``t1 = t2 = ... = value``.

    Returns:
        ``(bindings, bindings, False)``: the names bound by all targets
        (the same dict serves as both the defined and used result), and
        False because an assignment never returns from the function.
    """
    # permitted assignment targets:
    # variables, class fields, list elements
    # function evaluation is only allowed within list index
    # may relax later after when we settle on lifetime handling
    ty = parse_expr(ctx, sym_table, node.value)
    results = {}
    for target in node.targets:
        # Check each target against the names bound by earlier targets, so a
        # name bound twice in one statement must get one type in either
        # order. Previously `(x, y) = x = t` was rejected, but
        # `x = (x, y) = t` silently recorded only the element type for x
        # although x is first assigned the whole tuple.
        results |= parse_stmt_binding(ctx, sym_table,
                                      used_sym_table | results, target, ty)
    return results, results, False
|
|
|
|
def parse_if_stmt(ctx: Context,
                  sym_table: dict[str, Type],
                  used_sym_table: dict[str, Type],
                  return_ty: Type,
                  node):
    """Type-check an ``if``/``else`` statement.

    Both branches are checked against ``sym_table``. The else branch also
    sees the types recorded by the if branch, so a variable assigned in both
    branches gets one consistent type.

    Returns:
        ``(defined, used, returned)``. ``defined`` holds the bindings that
        reach the code after the statement, ``used`` holds every binding
        recorded in either branch, and ``returned`` is True only if both
        branches always return.

    Raises:
        CustomError: if the condition is not ``bool`` (plus anything raised
            while checking the branches).
    """
    test = parse_expr(ctx, sym_table, node.test)
    if test != ctx.types['bool']:
        raise CustomError(f'condition must be bool instead of {test}', node.test)

    body_defs, body_used, body_returns = parse_stmts(
        ctx, sym_table, used_sym_table, return_ty, node.body)
    # Merge into a new dict instead of mutating the caller's table.
    else_defs, else_used, else_returns = parse_stmts(
        ctx, sym_table, used_sym_table | body_used, return_ty, node.orelse)

    # A branch that always returns never falls through, so only the other
    # branch's bindings reach the code after the statement.
    if body_returns and not else_returns:
        defined = else_defs
    elif else_returns and not body_returns:
        defined = body_defs
    else:
        defined = {k: body_defs[k] for k in body_defs if k in else_defs}
    return defined, body_used | else_used, body_returns and else_returns
|
|
|
|
def _may_break_out(stmts) -> bool:
    """Return True if a ``break`` among ``stmts`` can leave the owning loop.

    A ``break`` inside a nested loop's body only exits that nested loop, but
    one in a nested loop's ``else`` clause, or inside an ``if``, exits the
    loop that owns ``stmts``.
    """
    for stmt in stmts:
        if isinstance(stmt, ast.Break):
            return True
        if isinstance(stmt, (ast.For, ast.While)):
            if _may_break_out(stmt.orelse):
                return True
        elif isinstance(stmt, ast.If):
            if _may_break_out(stmt.body) or _may_break_out(stmt.orelse):
                return True
    return False


def _loop_defined(sym_table, body_defs, else_defs, body):
    """Return the bindings guaranteed to exist after a loop.

    Without a ``break``, keep the conservative rule: names defined both in
    the body and in the ``else`` clause. With a ``break``, the ``else``
    clause may be skipped and the body may stop before any later assignment
    (``parse_stmts`` does not track this), so only names that already
    existed before the loop survive.
    """
    defined = {k: body_defs[k] for k in body_defs if k in else_defs}
    if _may_break_out(body):
        defined = {k: v for k, v in defined.items() if k in sym_table}
    return defined


def parse_for_stmt(ctx: Context,
                   sym_table: dict[str, Type],
                   used_sym_table: dict[str, Type],
                   return_ty: Type,
                   node):
    """Type-check ``for target in iterable: ... else: ...``.

    Only iteration over a ``ListType`` is accepted; the loop target is bound
    (via ``parse_simple_binding``) to ``ty.params[0]`` for the body.

    Returns:
        ``(defined, used, False)``; the loop is never treated as definitely
        returning.

    Raises:
        CustomError: if the iterable is not a list or the loop target
            conflicts with a previously recorded type.
    """
    ty = parse_expr(ctx, sym_table, node.iter)
    if not isinstance(ty, ListType):
        raise CustomError('only iteration over list is supported', node.iter)
    binding = parse_simple_binding(node.target, ty.params[0])
    for key, value in binding.items():
        if key in used_sym_table and value != used_sym_table[key]:
            raise CustomError('inconsistent type', node)
    # more sophisticated return analysis is needed...
    a, b, _ = parse_stmts(ctx, sym_table | binding, used_sym_table | binding,
                          return_ty, node.body)
    a1, b1, _ = parse_stmts(ctx, sym_table, used_sym_table | b,
                            return_ty, node.orelse)
    return _loop_defined(sym_table, a, a1, node.body), b | b1, False


def parse_while_stmt(ctx: Context,
                     sym_table: dict[str, Type],
                     used_sym_table: dict[str, Type],
                     return_ty: Type,
                     node):
    """Type-check ``while cond: ... else: ...``.

    Returns:
        ``(defined, used, False)``; the loop is never treated as definitely
        returning.

    Raises:
        CustomError: if the condition is not ``bool``.
    """
    ty = parse_expr(ctx, sym_table, node.test)
    if ty != ctx.types['bool']:
        raise CustomError('condition must be bool', node.test)
    # more sophisticated return analysis is needed...
    a, b, _ = parse_stmts(ctx, sym_table, used_sym_table, return_ty, node.body)
    a1, b1, _ = parse_stmts(ctx, sym_table, used_sym_table | b,
                            return_ty, node.orelse)
    return _loop_defined(sym_table, a, a1, node.body), b | b1, False
|
|
|
|
def parse_return_stmt(ctx: Context,
                      sym_table: dict[str, Type],
                      used_sym_table: dict[str, Type],
                      return_ty: Type,
                      node):
    """Type-check a ``return`` statement against the function's return type.

    ``return_ty`` of None means the function returns nothing, so only a bare
    ``return`` is allowed. Otherwise a value of exactly ``return_ty`` is
    required.

    Returns:
        ``({}, {}, True)``: no bindings, and the statement always returns.

    Raises:
        CustomError: if a value is returned where none is allowed, a value is
            missing, or its type differs from ``return_ty``.
    """
    if return_ty is None:
        if node.value is not None:
            raise CustomError('no return value is allowed', node)
        return {}, {}, True
    if node.value is None:
        # A bare `return` would otherwise pass None to parse_expr; report
        # the missing value explicitly instead.
        raise CustomError(f'expected returning {return_ty} but got nothing',
                          node)
    ty = parse_expr(ctx, sym_table, node.value)
    if ty != return_ty:
        raise CustomError(f'expected returning {return_ty} but got {ty}', node)
    return {}, {}, True
|
|
|