"""Statement-level type checking over Python ``ast`` nodes.

Every ``parse_*_stmt`` handler (and ``parse_assign``) returns a triple
``(defined, used, returned)``:

* ``defined`` -- name -> type entries that ``parse_stmts`` merges into its
  working symbol table;
* ``used`` -- name -> type entries that ``parse_stmts`` merges into its
  working "used" table, which ``parse_stmt_binding`` consults to reject
  rebinding a name at a different type;
* ``returned`` -- ``True`` when the statement is known to return.
"""
import ast
import copy
from helper import *
from type_def import *
from inference import *
from top_level import parse_type
from parse_expr import parse_expr, parse_simple_binding


def parse_stmts(ctx: Context,
                sym_table: dict[str, Type],
                used_sym_table: dict[str, Type],
                return_ty: Type,
                nodes):
    """Check a sequence of statements.

    Works on shallow copies of ``sym_table`` and ``used_sym_table`` so the
    dictionaries passed in are not rebound here (callees may still mutate
    the copies).

    Returns ``(sym_table, used_sym_table, returned)`` where the two tables
    are the updated copies and ``returned`` is ``True`` as soon as one of
    the statements reports that it returns; statements after that point
    are not analysed.

    Raises ``CustomError`` for statement kinds that are not handled.
    """
    sym_table2 = copy.copy(sym_table)
    used_sym_table2 = copy.copy(used_sym_table)
    for node in nodes:
        if isinstance(node, ast.Assign):
            handler = parse_assign
        elif isinstance(node, ast.If):
            handler = parse_if_stmt
        elif isinstance(node, ast.While):
            handler = parse_while_stmt
        elif isinstance(node, ast.For):
            handler = parse_for_stmt
        elif isinstance(node, ast.Return):
            handler = parse_return_stmt
        elif isinstance(node, (ast.Break, ast.Continue)):
            # break/continue bind nothing; skip to the next statement.
            continue
        else:
            raise CustomError(f'{node} is not supported yet', node)
        a, b, returned = handler(ctx, sym_table2, used_sym_table2,
                                 return_ty, node)
        sym_table2 |= a
        used_sym_table2 |= b
        if returned:
            return sym_table2, used_sym_table2, True
    return sym_table2, used_sym_table2, False


def get_target_type(ctx: Context,
                    sym_table: dict[str, Type],
                    used_sym_table: dict[str, Type],
                    target):
    """Return the type of an existing assignment target.

    Supported targets are names already present in ``sym_table``, attribute
    accesses on such targets, and ``int32``-indexed subscripts of
    ``ListType`` targets. Anything else raises ``CustomError``.
    """
    if isinstance(target, ast.Subscript):
        t = get_target_type(ctx, sym_table, used_sym_table, target.value)
        if not isinstance(t, ListType):
            raise CustomError(f'cannot index through type {t}', target)
        if isinstance(target.slice, ast.Slice):
            raise CustomError(f'assignment to slice is not supported', target)
        i = parse_expr(ctx, sym_table, target.slice)
        if i != ctx.types['int32']:
            raise CustomError(f'index must be int32', target.slice)
        # The element type is the list's first type parameter.
        return t.params[0]
    elif isinstance(target, ast.Attribute):
        t = get_target_type(ctx, sym_table, used_sym_table, target.value)
        if target.attr not in t.fields:
            raise CustomError(f'{t} has no field {target.attr}', target)
        return t.fields[target.attr]
    elif isinstance(target, ast.Name):
        if target.id not in sym_table:
            raise CustomError(f'unbounded {target.id}', target)
        return sym_table[target.id]
    else:
        raise CustomError(f'assignment to {target} is not supported', target)


def parse_stmt_binding(ctx: Context,
                       sym_table: dict[str, Type],
                       used_sym_table: dict[str, Type],
                       target,
                       ty):
    """Bind ``target`` to a value of type ``ty``.

    * ``ast.Name``: returns ``{name: ty}`` (or ``{}`` for ``_``) after
      checking that any earlier entry in ``used_sym_table`` has the same
      type.
    * ``ast.Tuple``: ``ty`` must be a ``TupleType`` of the same length; the
      elements are bound pairwise and the merged bindings are returned.
      NOTE: this branch updates the caller's ``used_sym_table`` in place.
    * Any other target is treated as an assignment to an existing location
      (see ``get_target_type``); its type must equal ``ty`` and no new
      binding is produced.

    Raises ``CustomError`` on any mismatch.
    """
    if isinstance(target, ast.Name):
        if target.id in used_sym_table and used_sym_table[target.id] != ty:
            raise CustomError(
                f'inconsistent type, {target.id} was '
                f'defined to be {used_sym_table[target.id]}, '
                f'but is now {ty}',
                target)
        if target.id == '_':
            return {}
        return {target.id: ty}
    elif isinstance(target, ast.Tuple):
        if not isinstance(ty, TupleType):
            raise CustomError(f'cannot pattern match over {ty}', target)
        if len(target.elts) != len(ty.params):
            raise CustomError(f'pattern matching length mismatch', target)
        result = {}
        for x, y in zip(target.elts, ty.params):
            new = parse_stmt_binding(ctx, sym_table, used_sym_table, x, y)
            old_len = len(result)
            result |= new
            # Record the bindings so later elements of this pattern are
            # checked against them.
            used_sym_table |= new
            # If the merge did not grow by len(new), a name was repeated.
            if len(result) != old_len + len(new):
                raise CustomError(f'variable name clash', target)
        return result
    else:
        t = get_target_type(ctx, sym_table, used_sym_table, target)
        if ty != t:
            raise CustomError(f'type mismatch', target)
        return {}


def parse_assign(ctx: Context,
                 sym_table: dict[str, Type],
                 used_sym_table: dict[str, Type],
                 return_ty: Type,
                 node):
    """Check an assignment statement.

    The right-hand side is typed once and every target is bound to that
    type. Returns ``(bindings, bindings, False)``; ``return_ty`` is unused.
    """
    # Permitted assignment targets:
    #   variables, class fields, list elements.
    # Function evaluation is only allowed within a list index.
    # This may be relaxed later, once lifetime handling is settled.
    ty = parse_expr(ctx, sym_table, node.value)
    results = {}
    for target in node.targets:
        results |= parse_stmt_binding(ctx, sym_table, used_sym_table,
                                      target, ty)
    return results, results, False


def _is_type_guard(test):
    """Return True if ``test`` has the shape ``type(<expr>) ==/!= <type>``."""
    return (isinstance(test, ast.Compare)
            and isinstance(test.left, ast.Call)
            and isinstance(test.left.func, ast.Name)
            and test.left.func.id == 'type'
            and len(test.left.args) == 1
            and len(test.ops) == 1
            and isinstance(test.ops[0], (ast.Eq, ast.NotEq)))


def parse_if_stmt(ctx: Context,
                  sym_table: dict[str, Type],
                  used_sym_table: dict[str, Type],
                  return_ty: Type,
                  node):
    """Check an ``if`` statement, including ``type(x) == T`` type guards.

    For a type guard, the guarded expression must be a ``TypeVariable``
    with at least two constraints, one of which is ``T``. The branch in
    which the variable is known to be ``T`` is checked with its
    constraints narrowed to ``[T]`` and the other branch with the
    remaining constraints; the original constraints are restored
    afterwards, even if checking fails.

    Otherwise the condition must have type ``bool``.

    Returns ``(defined, used, returned)``: ``defined`` holds the names
    present after both branches, ``used`` is the union of both branches'
    used tables, and ``returned`` is true only if both branches return.
    NOTE: the caller's ``used_sym_table`` is updated in place with the
    body's used names before the else branch is checked.
    """
    if _is_type_guard(node.test):
        t = parse_expr(ctx, sym_table, node.test.left.args[0])
        if not isinstance(t, TypeVariable) or len(t.constraints) < 2:
            raise CustomError(
                'type guard only support basic type variables with constraints',
                node.test)
        t1, _ = parse_type(ctx, node.test.comparators[0])
        if t1 not in t.constraints:
            raise CustomError(
                f'{t1} is not in constraints of {t}',
                node.test)
        t2 = [v for v in t.constraints if v != t1]
        body_constraints, else_constraints = [t1], t2
        if isinstance(node.test.ops[0], ast.NotEq):
            # `type(x) != T`: the body runs when x is NOT T, so the
            # narrowing of the two branches is swapped.
            body_constraints, else_constraints = \
                else_constraints, body_constraints
        # Saved so the temporary narrowing below can always be undone.
        original_constraints = t.constraints
        try:
            t.constraints = body_constraints
            a, b, r = parse_stmts(ctx, sym_table, used_sym_table,
                                  return_ty, node.body)
            used_sym_table |= b
            t.constraints = else_constraints
            a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table,
                                     return_ty, node.orelse)
        finally:
            t.constraints = original_constraints
    else:
        test = parse_expr(ctx, sym_table, node.test)
        if test != ctx.types['bool']:
            raise CustomError(f'condition must be bool instead of {test}',
                              node.test)
        a, b, r = parse_stmts(ctx, sym_table, used_sym_table,
                              return_ty, node.body)
        used_sym_table |= b
        a1, b1, r1 = parse_stmts(ctx, sym_table, used_sym_table,
                                 return_ty, node.orelse)
    # Only names available after both branches stay defined.
    defined = {k: a[k] for k in a if k in a1}
    return defined, b | b1, r and r1


def parse_for_stmt(ctx: Context,
                   sym_table: dict[str, Type],
                   used_sym_table: dict[str, Type],
                   return_ty: Type,
                   node):
    """Check a ``for`` loop over a ``ListType`` value.

    The loop target is bound to the list's first type parameter for the
    body only; the ``else`` clause is checked without that binding in its
    symbol table. Returns ``(defined, used, False)`` where ``defined``
    holds the names present after both the body and the else clause.

    Raises ``CustomError`` if the iterable is not a list or the loop
    target was previously bound at a different type.
    """
    ty = parse_expr(ctx, sym_table, node.iter)
    if not isinstance(ty, ListType):
        raise CustomError('only iteration over list is supported', node.iter)
    binding = parse_simple_binding(node.target, ty.params[0])
    for key, value in binding.items():
        if key in used_sym_table and value != used_sym_table[key]:
            raise CustomError('inconsistent type', node)
    # more sophisticated return analysis is needed...
    a, b, _ = parse_stmts(ctx, sym_table | binding, used_sym_table | binding,
                          return_ty, node.body)
    a1, b1, _ = parse_stmts(ctx, sym_table, used_sym_table | b,
                            return_ty, node.orelse)
    defined = {k: a[k] for k in a if k in a1}
    return defined, b | b1, False


def parse_while_stmt(ctx: Context,
                     sym_table: dict[str, Type],
                     used_sym_table: dict[str, Type],
                     return_ty: Type,
                     node):
    """Check a ``while`` loop whose condition must have type ``bool``.

    Returns ``(defined, used, False)`` where ``defined`` holds the names
    present after both the body and the else clause.
    """
    ty = parse_expr(ctx, sym_table, node.test)
    if ty != ctx.types['bool']:
        raise CustomError('condition must be bool', node.test)
    # more sophisticated return analysis is needed...
    a, b, _ = parse_stmts(ctx, sym_table, used_sym_table,
                          return_ty, node.body)
    a1, b1, _ = parse_stmts(ctx, sym_table, used_sym_table | b,
                            return_ty, node.orelse)
    defined = {k: a[k] for k in a if k in a1}
    return defined, b | b1, False


def parse_return_stmt(ctx: Context,
                      sym_table: dict[str, Type],
                      used_sym_table: dict[str, Type],
                      return_ty: Type,
                      node):
    """Check a ``return`` statement against the expected ``return_ty``.

    * ``return_ty is None``: only a bare ``return`` is accepted.
    * ``return self`` is accepted when ``return_ty`` is a ``SelfType`` and
      ``self`` is in ``sym_table``.
    * Otherwise the value's type must equal ``return_ty``, or
      ``return_ty`` must be a ``TypeVariable`` whose single constraint
      equals it.

    Always returns ``({}, {}, True)`` on success; raises ``CustomError``
    on mismatch.
    """
    if return_ty is None:
        if node.value is not None:
            raise CustomError('no return value is allowed', node)
        return {}, {}, True
    if node.value is None:
        # A bare `return` has no expression to type; report it here rather
        # than handing None to parse_expr.
        raise CustomError(f'expected returning {return_ty} but got nothing',
                          node)
    ty = parse_expr(ctx, sym_table, node.value)
    if isinstance(node.value, ast.Name) and \
            node.value.id == 'self' and \
            'self' in sym_table and \
            isinstance(return_ty, SelfType):
        return {}, {}, True
    if ty != return_ty:
        if isinstance(return_ty, TypeVariable):
            if len(return_ty.constraints) == 1 and \
                    return_ty.constraints[0] == ty:
                return {}, {}, True
        raise CustomError(f'expected returning {return_ty} but got {ty}', node)
    return {}, {}, True