import ast
import copy
from helper import *
from type_def import *
from inference import *
from parse_expr import parse_expr


def parse_stmts(
    ctx: Context,
    sym_table: dict[str, Type],
    used_sym_table: dict[str, Type],
    nodes,
) -> tuple[dict[str, Type], dict[str, Type], bool]:
    """Type-check a sequence of statements.

    Only ``ast.Assign`` and ``ast.If`` statements are supported; any other
    statement raises ``CustomError``.  The input tables are not modified:
    each statement is checked against working copies that accumulate the
    bindings produced by the statements before it.

    Args:
        ctx: checker context (``ctx.types`` is consulted by the helpers).
        sym_table: names currently in scope, mapped to their types.
        used_sym_table: names bound so far, mapped to their types; used to
            reject re-binding a name with a different type.
        nodes: the statements to check, in order.

    Returns:
        ``(sym_table, used_sym_table, returned)`` as they stand after the
        statements.  ``returned`` is True when a statement reported that
        control flow returns; the statements after it are skipped.

    Raises:
        CustomError: on an unsupported statement or a type error.
    """
    scope = copy.copy(sym_table)
    used = copy.copy(used_sym_table)
    for node in nodes:
        if isinstance(node, ast.Assign):
            defined, bound, returned = parse_assign(ctx, scope, used, node)
        elif isinstance(node, ast.If):
            defined, bound, returned = parse_if_stmt(ctx, scope, used, node)
        else:
            raise CustomError(f'{node} is not supported yet')
        scope |= defined
        used |= bound
        if returned:
            # Remaining statements are treated as unreachable and not checked.
            return scope, used, True
    return scope, used, False


def get_target_type(
    ctx: Context,
    sym_table: dict[str, Type],
    used_sym_table: dict[str, Type],
    target,
) -> Type:
    """Return the type of an existing assignable location.

    Supported targets are variables already present in ``sym_table``,
    attributes (``obj.attr``) looked up in the owner type's ``fields``, and
    list elements (``xs[i]``) whose index expression has type int32.
    Slice assignment is rejected.

    Raises:
        CustomError: if the target kind is unsupported, a variable is not
            bound, a non-list is indexed, the index is not int32, or the
            owner type has no such field.
    """
    if isinstance(target, ast.Subscript):
        container = get_target_type(ctx, sym_table, used_sym_table, target.value)
        if not isinstance(container, ListType):
            raise CustomError(f'cannot index through type {container}')
        if isinstance(target.slice, ast.Slice):
            raise CustomError('assignment to slice is not supported')
        # The index is an arbitrary expression, so type-check it in full.
        index = parse_expr(ctx, sym_table, target.slice)
        if index != ctx.types['int32']:
            raise CustomError('index must be int32')
        return container.params[0]
    if isinstance(target, ast.Attribute):
        owner = get_target_type(ctx, sym_table, used_sym_table, target.value)
        # A type without a ``fields`` mapping has no attributes at all;
        # report that as a checker error instead of leaking AttributeError.
        fields = getattr(owner, 'fields', None)
        if fields is None or target.attr not in fields:
            raise CustomError(f'{owner} has no field {target.attr}')
        return fields[target.attr]
    if isinstance(target, ast.Name):
        if target.id not in sym_table:
            raise CustomError(f'unbounded {target.id}')
        return sym_table[target.id]
    raise CustomError(f'assignment to {target} is not supported')


def parse_binding(
    ctx: Context,
    sym_table: dict[str, Type],
    used_sym_table: dict[str, Type],
    target,
    ty: Type,
) -> dict[str, Type]:
    """Bind an assignment target to a value of type ``ty``.

    * A name binds a variable of type ``ty``; ``_`` binds nothing.
    * A tuple pattern is destructured element-wise against a ``TupleType``.
    * Any other target must be an existing location (see
      ``get_target_type``) whose type equals ``ty``; it binds nothing.

    For tuple patterns the new bindings are also merged into
    ``used_sym_table`` in place, so later elements of the pattern are
    checked against the earlier ones.

    Returns:
        The new variable bindings introduced by ``target``.

    Raises:
        CustomError: on a type mismatch, a non-tuple value for a tuple
            pattern, an arity mismatch, or a name bound twice in one pattern.
    """
    if isinstance(target, ast.Name):
        # A name that was already bound must keep the same type.
        if target.id in used_sym_table and used_sym_table[target.id] != ty:
            raise CustomError('inconsistent type')
        if target.id == '_':
            return {}
        return {target.id: ty}
    if isinstance(target, ast.Tuple):
        if not isinstance(ty, TupleType):
            raise CustomError(f'cannot pattern match over {ty}')
        if len(target.elts) != len(ty.params):
            raise CustomError('pattern matching length mismatch')
        result = {}
        for elt, elt_ty in zip(target.elts, ty.params):
            new = parse_binding(ctx, sym_table, used_sym_table, elt, elt_ty)
            if result.keys() & new.keys():
                raise CustomError('variable name clash')
            result |= new
            used_sym_table |= new
        return result
    # Fields and list elements: the location must already have type ``ty``.
    if ty != get_target_type(ctx, sym_table, used_sym_table, target):
        raise CustomError('inconsistent type')
    return {}


def parse_assign(
    ctx: Context,
    sym_table: dict[str, Type],
    used_sym_table: dict[str, Type],
    node,
) -> tuple[dict[str, Type], dict[str, Type], bool]:
    """Type-check an assignment statement ``t1 = t2 = ... = value``.

    The value is typed once and every target is bound to that type.

    Returns:
        ``(bindings, bindings, False)``: the new variables enter both the
        scope and the used-name table, and an assignment never returns.
    """
    # Permitted assignment targets: variables, class fields, list elements.
    # Function evaluation is only allowed within a list index.
    # May be relaxed later, once we settle on lifetime handling.
    ty = parse_expr(ctx, sym_table, node.value)
    results = {}
    for target in node.targets:
        results |= parse_binding(ctx, sym_table, used_sym_table, target, ty)
    return results, results, False


def parse_if_stmt(
    ctx: Context,
    sym_table: dict[str, Type],
    used_sym_table: dict[str, Type],
    node,
) -> tuple[dict[str, Type], dict[str, Type], bool]:
    """Type-check an ``if`` statement.

    The condition must have type bool.  Both branches start from the same
    scope; the else-branch is checked against a used-name table that also
    contains the names bound by the body, so a name cannot get different
    types in the two branches.

    Returns:
        ``(defined, used, returned)``: ``defined`` holds the names that are
        bound on every path falling through the statement, ``used`` the
        names bound in either branch, and ``returned`` is True when both
        branches return.

    Raises:
        CustomError: if the condition is not bool, or a branch fails to check.
    """
    test = parse_expr(ctx, sym_table, node.test)
    if test != ctx.types['bool']:
        raise CustomError(f'condition must be bool instead of {test}')
    body_syms, body_used, body_returns = parse_stmts(
        ctx, sym_table, used_sym_table, node.body)
    # Pass a new table rather than updating the caller's dict in place.
    else_syms, else_used, else_returns = parse_stmts(
        ctx, sym_table, used_sym_table | body_used, node.orelse)
    if body_returns and not else_returns:
        # Only the else-branch falls through, so its bindings are definite.
        defined = else_syms
    elif else_returns and not body_returns:
        # Only the body falls through, so its bindings are definite.
        defined = body_syms
    else:
        # A name is definitely bound only if both branches bind it.
        defined = {k: t for k, t in body_syms.items() if k in else_syms}
    return defined, body_used | else_used, body_returns and else_returns