import ast
import copy
from helper import *
from type_def import *
from inference import *
from parse_expr import parse_expr, parse_simple_binding


def parse_stmts(ctx: Context, sym_table: dict[str, Type], used_sym_table: dict[str, Type],
                return_ty: Type, nodes):
    """Type-check a sequence of statements.

    Args:
        ctx: type-checking context; ``ctx.types`` supplies builtin types such
            as ``'int32'`` and ``'bool'``.
        sym_table: names definitely bound before ``nodes`` run, mapped to types.
        used_sym_table: every name bound so far in the enclosing scope, mapped
            to its type; used to reject rebinding a name with a different type.
        return_ty: required type of returned values, or ``None`` when the
            function must not return a value.
        nodes: the statements to check.

    Returns:
        ``(sym_table, used_sym_table, returned)`` -- extended copies of the two
        input tables (the inputs themselves are not mutated) and whether a
        ``return`` was reached unconditionally.

    Raises:
        CustomError: on an unsupported statement or a type error.
    """
    sym_table2 = copy.copy(sym_table)
    used_sym_table2 = copy.copy(used_sym_table)
    for node in nodes:
        if isinstance(node, (ast.Break, ast.Continue, ast.Pass)):
            # These bind nothing and carry no expression to check.
            # NOTE(review): statements after break/continue are still checked
            # as if reachable, so a later `return` still counts as returning.
            continue
        if isinstance(node, ast.Assign):
            parser = parse_assign
        elif isinstance(node, ast.If):
            parser = parse_if_stmt
        elif isinstance(node, ast.While):
            parser = parse_while_stmt
        elif isinstance(node, ast.For):
            parser = parse_for_stmt
        elif isinstance(node, ast.Return):
            parser = parse_return_stmt
        else:
            raise CustomError(f'{node} is not supported yet')
        defined, used, returned = parser(ctx, sym_table2, used_sym_table2, return_ty, node)
        sym_table2 |= defined
        used_sym_table2 |= used
        if returned:
            # Remaining statements are unreachable and are not checked.
            return sym_table2, used_sym_table2, True
    return sym_table2, used_sym_table2, False


def get_target_type(ctx: Context, sym_table: dict[str, Type], used_sym_table: dict[str, Type],
                    target):
    """Return the type of an existing assignment place.

    Supported places are names already in ``sym_table``, list elements
    (``xs[i]`` with an ``int32`` index) and fields (``obj.f``), nested
    arbitrarily. ``used_sym_table`` is threaded through but not read.

    Raises:
        CustomError: if the place is unsupported, unbound, or ill-typed.
    """
    if isinstance(target, ast.Subscript):
        container = get_target_type(ctx, sym_table, used_sym_table, target.value)
        if not isinstance(container, ListType):
            raise CustomError(f'cannot index through type {container}')
        if isinstance(target.slice, ast.Slice):
            raise CustomError('assignment to slice is not supported')
        index_ty = parse_expr(ctx, sym_table, target.slice)
        if index_ty != ctx.types['int32']:
            raise CustomError('index must be int32')
        return container.params[0]
    if isinstance(target, ast.Attribute):
        owner = get_target_type(ctx, sym_table, used_sym_table, target.value)
        # Presumably not every Type defines `fields` (e.g. ListType -- verify);
        # report that as a type error rather than crashing with AttributeError.
        fields = getattr(owner, 'fields', None)
        if fields is None or target.attr not in fields:
            raise CustomError(f'{owner} has no field {target.attr}')
        return fields[target.attr]
    if isinstance(target, ast.Name):
        if target.id not in sym_table:
            raise CustomError(f'unbounded {target.id}')
        return sym_table[target.id]
    raise CustomError(f'assignment to {target} is not supported')


def parse_stmt_binding(ctx: Context, sym_table: dict[str, Type], used_sym_table: dict[str, Type],
                       target, ty):
    """Bind assignment ``target`` to a value of type ``ty``.

    - ``ast.Name``: binds the name (``_`` is discarded) after checking that any
      earlier binding of it in ``used_sym_table`` has the same type.
    - ``ast.Tuple``: destructures a ``TupleType`` element-wise; the names bound
      must be distinct. Each element's bindings are also merged into
      ``used_sym_table`` in place, so later elements are checked against them.
    - anything else: must be an existing place (see ``get_target_type``) of
      exactly type ``ty``; binds nothing.

    Returns:
        dict of names newly bound by the target, mapped to their types.

    Raises:
        CustomError: on a type mismatch, arity mismatch, or name clash.
    """
    if isinstance(target, ast.Name):
        if target.id == '_':
            # The wildcard never binds, so it is not type-constrained either.
            return {}
        if target.id in used_sym_table and used_sym_table[target.id] != ty:
            raise CustomError('inconsistent type')
        return {target.id: ty}
    if isinstance(target, ast.Tuple):
        if not isinstance(ty, TupleType):
            raise CustomError(f'cannot pattern match over {ty}')
        if len(target.elts) != len(ty.params):
            raise CustomError('pattern matching length mismatch')
        result = {}
        for elt, elt_ty in zip(target.elts, ty.params):
            new = parse_stmt_binding(ctx, sym_table, used_sym_table, elt, elt_ty)
            old_len = len(result)
            result |= new
            used_sym_table |= new
            # A size shortfall means some name was already bound by this pattern.
            if len(result) != old_len + len(new):
                raise CustomError('variable name clash')
        return result
    place_ty = get_target_type(ctx, sym_table, used_sym_table, target)
    if ty != place_ty:
        raise CustomError('inconsistent type')
    return {}


def parse_assign(ctx: Context, sym_table: dict[str, Type], used_sym_table: dict[str, Type],
                 return_ty: Type, node):
    """Type-check an ``ast.Assign`` (possibly chained: ``a = b = expr``).

    Returns:
        ``(defined, used, False)`` where ``defined`` and ``used`` are the same
        dict of newly bound names; assignment never returns.
    """
    # permitted assignment targets:
    # variables, class fields, list elements
    # function evaluation is only allowed within list index
    # may relax later after when we settle on lifetime handling
    ty = parse_expr(ctx, sym_table, node.value)
    # Private copy: names bound by earlier targets of a chained assignment must
    # be checked against later targets, without mutating the caller's table.
    used = copy.copy(used_sym_table)
    results = {}
    for target in node.targets:
        binding = parse_stmt_binding(ctx, sym_table, used, target, ty)
        results |= binding
        used |= binding
    return results, results, False


def parse_if_stmt(ctx: Context, sym_table: dict[str, Type], used_sym_table: dict[str, Type],
                  return_ty: Type, node):
    """Type-check an ``if``/``else`` statement.

    Returns:
        ``(defined, used, returned)``: names definitely bound after the
        statement, names bound in either branch, and whether both branches
        return.

    Raises:
        CustomError: if the condition is not ``bool`` or a branch fails.
    """
    test = parse_expr(ctx, sym_table, node.test)
    if test != ctx.types['bool']:
        raise CustomError(f'condition must be bool instead of {test}')
    body_defs, body_used, body_returns = parse_stmts(ctx, sym_table, used_sym_table,
                                                     return_ty, node.body)
    # The else branch must agree on the types of names the then-branch binds.
    else_defs, else_used, else_returns = parse_stmts(ctx, sym_table, used_sym_table | body_used,
                                                     return_ty, node.orelse)
    if body_returns and not else_returns:
        # Only the else branch can fall through to the next statement.
        defined = else_defs
    elif else_returns and not body_returns:
        defined = body_defs
    else:
        # Names bound on both paths are definitely bound afterwards.
        defined = {k: t for k, t in body_defs.items() if k in else_defs}
    return defined, body_used | else_used, body_returns and else_returns


def parse_for_stmt(ctx: Context, sym_table: dict[str, Type], used_sym_table: dict[str, Type],
                   return_ty: Type, node):
    """Type-check a ``for`` loop, which may only iterate over a list.

    The loop target is bound to the list's element type and is visible in the
    body only. A name is considered bound after the loop only if both the body
    and the ``else`` clause bind it.
    NOTE(review): ``break`` is not modelled; it skips the ``else`` clause, so
    the ``defined``/``returned`` results can be optimistic for loops with break.

    Raises:
        CustomError: on a non-list iterable, a type clash, or a body error.
    """
    iter_ty = parse_expr(ctx, sym_table, node.iter)
    if not isinstance(iter_ty, ListType):
        raise CustomError('only iteration over list is supported')
    binding = parse_simple_binding(node.target, iter_ty.params[0])
    for name, name_ty in binding.items():
        if name in used_sym_table and name_ty != used_sym_table[name]:
            raise CustomError('inconsistent type')
    body_defs, body_used, body_returns = parse_stmts(ctx, sym_table | binding,
                                                     used_sym_table | binding,
                                                     return_ty, node.body)
    # The else clause runs without the loop binding (e.g. on zero iterations).
    else_defs, else_used, else_returns = parse_stmts(ctx, sym_table, used_sym_table | body_used,
                                                     return_ty, node.orelse)
    defined = {k: t for k, t in body_defs.items() if k in else_defs}
    return defined, body_used | else_used, body_returns and else_returns


def parse_while_stmt(ctx: Context, sym_table: dict[str, Type], used_sym_table: dict[str, Type],
                     return_ty: Type, node):
    """Type-check a ``while`` loop with a ``bool`` condition.

    A name is considered bound after the loop only if both the body and the
    ``else`` clause bind it.
    NOTE(review): ``break`` is not modelled (see parse_for_stmt).

    Raises:
        CustomError: if the condition is not ``bool`` or the body fails.
    """
    cond_ty = parse_expr(ctx, sym_table, node.test)
    if cond_ty != ctx.types['bool']:
        raise CustomError('condition must be bool')
    # more sophisticated return analysis is needed...
    body_defs, body_used, body_returns = parse_stmts(ctx, sym_table, used_sym_table,
                                                     return_ty, node.body)
    else_defs, else_used, else_returns = parse_stmts(ctx, sym_table, used_sym_table | body_used,
                                                     return_ty, node.orelse)
    defined = {k: t for k, t in body_defs.items() if k in else_defs}
    return defined, body_used | else_used, body_returns and else_returns


def parse_return_stmt(ctx: Context, sym_table: dict[str, Type], used_sym_table: dict[str, Type],
                      return_ty: Type, node):
    """Type-check a ``return`` statement against ``return_ty``.

    Returns:
        ``({}, {}, True)``: binds nothing and always returns.

    Raises:
        CustomError: if the returned value (or its absence) does not match
            ``return_ty``.
    """
    if return_ty is None:
        if node.value is not None:
            raise CustomError('no return value is allowed')
        return {}, {}, True
    if node.value is None:
        # Bare `return` where a value is required; don't hand None to parse_expr.
        raise CustomError(f'expected returning {return_ty} but got no value')
    ty = parse_expr(ctx, sym_table, node.value)
    if ty != return_ty:
        raise CustomError(f'expected returning {return_ty} but got {ty}')
    return {}, {}, True