nac3-spec/toy-impl/parse_stmt.py

112 lines
4.3 KiB
Python

import ast
import copy
from helper import *
from type_def import *
from inference import *
from parse_expr import parse_expr
def parse_stmts(ctx: Context,
                sym_table: dict[str, Type],
                used_sym_table: dict[str, Type],
                nodes):
    """Type-check a sequence of statements.

    The incoming tables are shallow-copied first, so this function itself
    only updates its own copies.  Each statement is handed to the parser for
    its node kind, and whatever that parser reports is merged into the
    working copies before the next statement is looked at.

    Args:
        ctx: checking context, forwarded unchanged to the statement parsers.
        sym_table: names currently bound, mapped to their types.
        used_sym_table: names mapped to types; forwarded to the statement
            parsers and extended with what they report.
        nodes: iterable of ``ast`` statement nodes.

    Returns:
        ``(sym_table, used_sym_table, returned)`` where the two tables are
        the updated working copies and ``returned`` is ``True`` as soon as
        one statement parser reports a truthy third result (the remaining
        statements are then skipped), ``False`` otherwise.

    Raises:
        CustomError: for any node that is neither ``ast.Assign`` nor
            ``ast.If``.
    """
    # Statement kind -> parser.  Looked up at call time, so the parsers may
    # be defined further down in the module.
    dispatch = (
        (ast.Assign, parse_assign),
        (ast.If, parse_if_stmt),
    )

    scope = copy.copy(sym_table)
    seen = copy.copy(used_sym_table)

    for stmt in nodes:
        handler = next(
            (fn for kind, fn in dispatch if isinstance(stmt, kind)),
            None,
        )
        if handler is None:
            raise CustomError(f'{stmt} is not supported yet')

        bound, used, returned = handler(ctx, scope, seen, stmt)
        scope |= bound
        seen |= used
        if returned:
            return scope, seen, True

    return scope, seen, False
def get_target_type(ctx: Context,
                    sym_table: dict[str, Type],
                    used_sym_table: dict[str, Type],
                    target):
    """Return the type of an assignment target that is not a fresh binding.

    Handles three target shapes, resolving nested targets recursively:

    * ``name``        -- looked up in ``sym_table``;
    * ``value.attr``  -- looked up in the ``fields`` of the type of ``value``;
    * ``value[index]``-- ``value`` must be a ``ListType`` and ``index`` must
      type-check to ``ctx.types['int32']``; the result is the first type
      parameter of the list type.

    Args:
        ctx: checking context; supplies ``ctx.types['int32']``.
        sym_table: names currently bound, mapped to their types.
        used_sym_table: only forwarded to recursive calls.
        target: the ``ast`` node being assigned to.

    Raises:
        CustomError: unknown name, missing field, indexing a non-list,
            slice assignment, non-``int32`` index, or an unsupported target
            node kind.
    """
    if isinstance(target, ast.Name):
        if target.id in sym_table:
            return sym_table[target.id]
        raise CustomError(f'unbounded {target.id}')

    if isinstance(target, ast.Attribute):
        owner_ty = get_target_type(ctx, sym_table, used_sym_table,
                                   target.value)
        if target.attr in owner_ty.fields:
            return owner_ty.fields[target.attr]
        raise CustomError(f'{owner_ty} has no field {target.attr}')

    if isinstance(target, ast.Subscript):
        container_ty = get_target_type(ctx, sym_table, used_sym_table,
                                       target.value)
        if not isinstance(container_ty, ListType):
            raise CustomError(f'cannot index through type {container_ty}')
        if isinstance(target.slice, ast.Slice):
            raise CustomError(f'assignment to slice is not supported')
        # The index expression is type-checked against the plain symbol
        # table, exactly like any other expression.
        index_ty = parse_expr(ctx, sym_table, target.slice)
        if index_ty != ctx.types['int32']:
            raise CustomError(f'index must be int32')
        return container_ty.params[0]

    raise CustomError(f'assignment to {target} is not supported')
def parse_binding(ctx: Context,
                  sym_table: dict[str, Type],
                  used_sym_table: dict[str, Type],
                  target,
                  ty):
    """Match one assignment target against the type being assigned.

    Args:
        ctx: checking context, forwarded to ``get_target_type``.
        sym_table: names currently bound, mapped to their types.
        used_sym_table: names mapped to the type they were previously seen
            with.  NOTE: for tuple targets this dict is updated *in place*
            with every name bound by the pattern.
        target: the ``ast`` target node.
        ty: type of the value being assigned.

    Returns:
        A dict of the names newly bound by this target.  It is empty for the
        name ``_`` and for targets that are neither names nor tuples
        (those only have their existing type checked against ``ty``).

    Raises:
        CustomError: a name was already recorded with a different type, a
            tuple pattern is matched against a non-tuple type or a tuple of
            different length, the same name is bound twice inside one
            pattern, or an existing target's type differs from ``ty``.
    """
    if isinstance(target, ast.Name):
        name = target.id
        if name in used_sym_table and used_sym_table[name] != ty:
            raise CustomError('inconsistent type')
        # '_' is checked like any other name but never produces a binding.
        return {} if name == '_' else {name: ty}

    if not isinstance(target, ast.Tuple):
        # Attribute / subscript / anything else: nothing is bound, the
        # existing target type just has to agree with the assigned type.
        existing_ty = get_target_type(ctx, sym_table, used_sym_table, target)
        if ty != existing_ty:
            raise CustomError(f'inconsistent type')
        return {}

    if not isinstance(ty, TupleType):
        raise CustomError(f'cannot pattern match over {ty}')
    if len(target.elts) != len(ty.params):
        raise CustomError(f'pattern matching length mismatch')

    bindings = {}
    for element, element_ty in zip(target.elts, ty.params):
        fresh = parse_binding(ctx, sym_table, used_sym_table,
                              element, element_ty)
        expected_size = len(bindings) + len(fresh)
        bindings |= fresh
        used_sym_table |= fresh
        # If merging did not grow the dict by len(fresh), some name in
        # `fresh` was already bound earlier in this same pattern.
        if len(bindings) != expected_size:
            raise CustomError(f'variable name clash')
    return bindings
def parse_assign(ctx: Context,
                 sym_table: dict[str, Type],
                 used_sym_table: dict[str, Type],
                 node):
    """Type-check an ``ast.Assign`` statement.

    The right-hand side is type-checked once, then every target of the
    (possibly chained) assignment is matched against that type.

    Names bound by one target are recorded in a local running copy of
    ``used_sym_table`` before the next target is examined.  Without this, a
    chained assignment such as ``x = (x, y) = t`` slipped through: the plain
    name ``x`` was bound to the tuple type, then silently overwritten by the
    element type, while the mirrored ``(x, y) = x = t`` was rejected.  Both
    orders now fail with 'inconsistent type'.

    Args:
        ctx: checking context, forwarded to the expression/binding parsers.
        sym_table: names currently bound, mapped to their types.
        used_sym_table: names mapped to the type they were previously seen
            with; it is copied, not modified, by this function.
        node: the ``ast.Assign`` node.

    Returns:
        ``(bindings, bindings, False)``: the names bound by this statement
        (the same dict serves as both the symbol-table and the used-table
        update) and ``False`` because an assignment never returns.

    Raises:
        CustomError: propagated from ``parse_expr`` / ``parse_binding``.
    """
    # permitted assignment targets:
    # variables, class fields, list elements
    # function evaluation is only allowed within list index
    # may relax later once we settle on lifetime handling
    ty = parse_expr(ctx, sym_table, node.value)

    # Running view of "used" names: the caller's table plus everything bound
    # by the targets processed so far in this statement.
    seen = copy.copy(used_sym_table)
    results = {}
    for target in node.targets:
        bound = parse_binding(ctx, sym_table, seen, target, ty)
        seen |= bound
        results |= bound
    return results, results, False
def parse_if_stmt(ctx: Context,
                  sym_table: dict[str, Type],
                  used_sym_table: dict[str, Type],
                  node):
    """Type-check an ``ast.If`` statement.

    The condition must have type ``ctx.types['bool']``.  The body is checked
    first; the used-name table it reports is merged into ``used_sym_table``
    (in place) before the ``orelse`` branch is checked, so the second branch
    is checked against the types recorded by the first.

    Args:
        ctx: checking context; supplies ``ctx.types['bool']``.
        sym_table: names currently bound, mapped to their types.
        used_sym_table: names mapped to previously seen types.  NOTE: this
            dict is updated in place with the body's used-name table.
        node: the ``ast.If`` node.

    Returns:
        ``(defined, used, returned)`` where ``defined`` holds the names
        present in the symbol tables of *both* branches (typed as in the
        body's table), ``used`` is the union of both branches' used-name
        tables, and ``returned`` is true only if both branches reported
        returning.

    Raises:
        CustomError: the condition is not of type bool, or anything raised
            while checking either branch.
    """
    cond_ty = parse_expr(ctx, sym_table, node.test)
    if cond_ty != ctx.types['bool']:
        raise CustomError(f'condition must be bool instead of {cond_ty}')

    then_syms, then_used, then_returned = parse_stmts(
        ctx, sym_table, used_sym_table, node.body)
    used_sym_table |= then_used
    else_syms, else_used, else_returned = parse_stmts(
        ctx, sym_table, used_sym_table, node.orelse)

    # Only names that exist whichever branch ran survive the statement.
    defined = {}
    for name, name_ty in then_syms.items():
        if name in else_syms:
            defined[name] = name_ty

    return defined, then_used | else_used, then_returned and else_returned