import ast

from type_def import *


class CustomError(Exception):
    """Error raised when the parsed source uses an invalid or unsupported form.

    The human-readable message is also stored on ``msg``.
    """

    def __init__(self, msg):
        super().__init__(msg)
        self.msg = msg


class Context:
    """Name-resolution environment consulted while parsing annotations."""

    # Type variables in scope, keyed by their name.
    variables: dict[str, TypeVariable]
    # Known types in scope, keyed by their name.
    types: dict[str, Type]

    def __init__(self, variables, types):
        self.variables = variables
        self.types = types


def parse_type(ctx: Context, ty):
    """Resolve an annotation AST node into a type.

    Args:
        ctx: Context whose ``types`` and ``variables`` are used for lookup.
        ty: Annotation node (``ast.Name`` or ``ast.Subscript``), or ``None``
            when the annotation is absent.

    Returns:
        A ``(type, names)`` pair where ``names`` is the set of type-variable
        names that occur in the annotation. A missing annotation yields
        ``(None, set())``.

    Raises:
        CustomError: If a name is unknown, or the annotation has an
            unsupported shape or an unknown generic.
    """
    if ty is None:
        return None, set()

    if isinstance(ty, ast.Name):
        # String annotations should be supported as well, but that is not
        # needed for this toy implementation.
        if ty.id in ctx.types:
            return ctx.types[ty.id], set()
        if ty.id in ctx.variables:
            return ctx.variables[ty.id], {ty.id}
        raise CustomError(f"Unbounded Type {ty.id}")

    if isinstance(ty, ast.Subscript):
        if not isinstance(ty.value, ast.Name):
            raise CustomError(f"Unknown Generic Type {ty.value}")
        generic = ty.value.id

        if not isinstance(ty.slice, (ast.Name, ast.Tuple, ast.Subscript)):
            raise CustomError(f"Generic Type of the form {ty.slice} is not supported")

        if generic == 'tuple':
            if not isinstance(ty.slice, ast.Tuple):
                raise CustomError(f"Generic Type of the form {ty} is not supported")
            param = []
            var = set()
            for t in ty.slice.elts:
                p, v = parse_type(ctx, t)
                param.append(p)
                var |= v
            return TupleType(param), var

        if generic == 'list':
            param, var = parse_type(ctx, ty.slice)
            return ListType(param), var

        if generic == 'virtual':
            param, var = parse_type(ctx, ty.slice)
            if not isinstance(param, ClassType):
                raise CustomError(f"Parameter of virtual must be a class instead of {param}")
            return VirtualClassType(param), var

        raise CustomError(f"Unknown Generic Type {ty.value}")

    raise CustomError(f"Unrecognized Type {ty}")


def parse_function(ctx: Context, base, fn: ast.FunctionDef):
    """Compute the signature of a function definition.

    Args:
        ctx: Context used to resolve the annotations.
        base: Type substituted for an unannotated ``self`` parameter, or
            ``None`` when there is no such type (top-level functions).
        fn: The function definition to inspect. Only ``fn.args.args`` is
            examined.

    Returns:
        ``(args, result, var)``: the list of parameter types (``None`` for an
        unannotated parameter), the return type, and the set of type-variable
        names used by the parameters.

    Raises:
        CustomError: If the return type mentions a type variable that no
            parameter mentions, or if an annotation cannot be parsed.
    """
    args = []
    var = set()
    for arg in fn.args.args:
        name = arg.arg
        ty, v = parse_type(ctx, arg.annotation)
        var |= v
        # An unannotated `self` takes the enclosing class type.
        if name == 'self' and ty is None and base is not None:
            ty = base
        args.append(ty)

    result, v = parse_type(ctx, fn.returns)
    # Every type variable of the return type must be bound by some parameter.
    if v - var:
        raise CustomError(f"Unbounded variable in return type of {fn.name}")
    return args, result, var


def parse_class(ctx, c: ast.ClassDef):
    """Populate the type registered for a class with parents, fields and methods.

    The type object is looked up as ``ctx.types[c.name]``, so the class must
    already be registered there.

    Args:
        ctx: Context used for lookups.
        c: The class definition to process.

    Returns:
        A list of ``(class_name, method_name, ast.FunctionDef)`` tuples, one
        per method found in the class body.

    Raises:
        CustomError: On an unsupported base class or statement, a duplicated
            field or method, or a field annotation that uses type variables.
    """
    node = ctx.types[c.name]
    functions = []

    for base in c.bases:
        if not isinstance(base, ast.Name):
            raise CustomError(f"Base class of the form {base} is not supported")
        name = base.id
        if name not in ctx.types:
            raise CustomError(f"Unbounded base class name {base}")
        if not isinstance(ctx.types[name], ClassType):
            raise CustomError(f"Base class must be a class instead of {base}")
        node.parents.append(ctx.types[name])

    for stmt in c.body:
        if isinstance(stmt, ast.AnnAssign):
            if not isinstance(stmt.target, ast.Name):
                raise CustomError(f"Assignment of the form {stmt.target} is not supported")
            field = stmt.target.id
            if field in node.fields:
                raise CustomError(f"Duplicated fields {field} in {c.name}")
            ty, var = parse_type(ctx, stmt.annotation)
            if var:
                raise CustomError("Type variable is not allowed in class fields")
            if ty is None:
                raise CustomError(f"{field} of {c.name} is not annotated")
            node.fields[field] = ty
        elif isinstance(stmt, ast.FunctionDef):
            name = stmt.name
            if name in node.methods:
                raise CustomError(f"Duplicated method {name} in {c.name}")
            args, result, var = parse_function(ctx, node, stmt)
            node.methods[name] = (args, result, var)
            functions.append((c.name, name, stmt))
        else:
            raise CustomError(f"{stmt} is not supported")

    return functions


def parse_top_level(ctx: Context, module: ast.Module):
    """Collect class and function signatures from a module.

    Runs in two passes so that classes may refer to each other regardless of
    their order in the file: the first pass registers every class name in
    ``ctx.types``; the second parses class bodies and top-level functions.
    Module-level statements that are neither classes nor functions are
    ignored.

    Args:
        ctx: Context to resolve names against; ``ctx.types`` is extended in
            place with one ``ClassType`` per class.
        module: The parsed module.

    Returns:
        ``(ctx, functions, function_stmts)`` where ``functions`` maps each
        top-level function name to its ``(args, result, var)`` signature and
        ``function_stmts`` is a list of ``(class_name, name, ast.FunctionDef)``
        tuples; ``class_name`` is ``None`` for top-level functions.

    Raises:
        CustomError: On a duplicated class or function name, or any error
            raised while parsing a class or function.
    """
    to_be_processed = []

    # First pass: obtain all type names.
    for element in module.body:
        if isinstance(element, ast.ClassDef):
            name = element.name
            if name in ctx.types or name in ctx.variables:
                raise CustomError(f"Duplicated class name: {name}")
            ctx.types[name] = ClassType(name)
            to_be_processed.append(element)
        elif isinstance(element, ast.FunctionDef):
            to_be_processed.append(element)

    # Second pass: obtain all function types.
    functions = {}
    function_stmts = []
    for element in to_be_processed:
        if isinstance(element, ast.ClassDef):
            function_stmts += parse_class(ctx, element)
        elif isinstance(element, ast.FunctionDef):
            name = element.name
            if name in functions:
                raise CustomError(f"Duplicated function name {name}")
            args, result, var = parse_function(ctx, None, element)
            functions[name] = (args, result, var)
            # `function_stmts += element` would try to iterate the FunctionDef
            # node and raise TypeError; append an entry shaped like the ones
            # produced by parse_class, with no owning class.
            function_stmts.append((None, name, element))

    return ctx, functions, function_stmts