diff --git a/toy-impl/test_top_level.py b/toy-impl/test_top_level.py new file mode 100644 index 0000000..c5b72f5 --- /dev/null +++ b/toy-impl/test_top_level.py @@ -0,0 +1,31 @@ +import ast +from type_def import * +from top_level import * + +test = """ +class A: + a: int + def foo(a: B) -> int: + pass + +class B(A): + a: str + def bar(a: list[list[virtual[A]]]) -> A: + pass +""" + +variables = {'X': TypeVariable('X', []), 'Y': TypeVariable('Y', [])} +types = {'int': PrimitiveType('int'), 'str': PrimitiveType('str')} +ctx = Context(variables, types) + +ctx, functions, _ = parse_top_level(ctx, ast.parse(test)) + +for name, t in ctx.types.items(): + if isinstance(t, ClassType): + print(f"class {t.name}") + for name, ty in t.fields.items(): + print(f" {name}: {ty}") + for name, (args, result, _) in t.methods.items(): + print(f" {name}: ({', '.join([str(v) for v in args])}) -> {result}") + + diff --git a/toy-impl/top_level.py b/toy-impl/top_level.py new file mode 100644 index 0000000..0f92c77 --- /dev/null +++ b/toy-impl/top_level.py @@ -0,0 +1,141 @@ +import ast +from type_def import * + +class CustomError(Exception): + def __init__(self, msg): + self.msg = msg + +class Context: + variables: dict[str, TypeVariable] + types: dict[Type] + + def __init__(self, variables, types): + self.variables = variables + self.types = types + + +def parse_type(ctx: Context, ty): + if ty is None: + return None + elif isinstance(ty, ast.Name): + if ty.id in ctx.types: + return ctx.types[ty.id], set() + elif ty.id in ctx.variables: + return ctx.variables[ty.id], {ty.id} + else: + raise CustomError(f"Unbounded Type {ty.id}") + elif isinstance(ty, ast.Subscript): + if isinstance(ty.value, ast.Name): + generic = ty.value.id + else: + raise CustomError(f"Unknown Generic Type {ty.value}") + if not isinstance(ty.slice, ast.Name) and not isinstance(ty.slice, ast.Tuple) \ + and not isinstance(ty.slice, ast.Subscript): + raise CustomError(f"Generic Type of the form {ty.slice} is not supported") + if generic == 'tuple': + if not isinstance(ty.slice, ast.Tuple): + raise CustomError(f"Generic Type of the form {ty} is not supported") + param = [] + var = set() + for t in ty.slice.elts: + p, v = parse_type(ctx, t) + param.append(p) + var |= v + return TupleType(param), var + elif generic == 'list': + param, var = parse_type(ctx, ty.slice) + return ListType(param), var + elif generic == 'virtual': + param, var = parse_type(ctx, ty.slice) + if not isinstance(param, ClassType): + raise CustomError(f"Parameter of virtual must be a class instead of {param}") + return VirtualClassType(param), var + else: + raise CustomError(f"Unknown Generic Type {ty.value}") + + +def parse_function(ctx: Context, base, fn: ast.FunctionDef): + args = [] + var = set() + for arg in fn.args.args: + name = arg.arg + ty, v = parse_type(ctx, arg.annotation) + var |= v + if name == 'self' and ty is None and base is not None: + ty = base + args.append(ty) + result, v = parse_type(ctx, fn.returns) + if len(v - var) > 0: + raise CustomError(f"Unbounded variable in return type of {fn.name}") + return args, result, var + + +def parse_class(ctx, c: ast.ClassDef): + node = ctx.types[c.name] + functions = [] + + for base in c.bases: + if not isinstance(base, ast.Name): + raise CustomError(f"Base class of the form {base} is not supported") + name = base.id + if name not in ctx.types: + raise CustomError(f"Unbounded base class name {base}") + if not isinstance(ctx.types[name], ClassType): + raise CustomError(f"Base class must be a class instead of {base}") + node.parents.append(ctx.types[name]) + + for stmt in c.body: + if isinstance(stmt, ast.AnnAssign): + if not isinstance(stmt.target, ast.Name): + raise CustomError(f"Assignment of the form {stmt.target} is not supported") + field = stmt.target.id + if field in node.fields: + raise CustomError(f"Duplicated fields {field} in {c.name}") + ty, var = parse_type(ctx, stmt.annotation) + if len(var) > 0: + raise CustomError(f"Type variable is not allowed in class fields") + if ty == None: + raise CustomError(f"{field} of {c.name} is not annotated") + node.fields[field] = ty + elif isinstance(stmt, ast.FunctionDef): + name = stmt.name + if name in node.methods: + raise CustomError(f"Duplicated method {name} in {c.name}") + args, result, var = parse_function(ctx, node, stmt) + node.methods[name] = (args, result, var) + functions.append((c.name, name, stmt)) + else: + raise CustomError(f"{stmt} is not supported") + return functions + + +def parse_top_level(ctx: Context, module: ast.Module): + to_be_processed = [] + # first pass, obtain all type names + for element in module.body: + if isinstance(element, ast.ClassDef): + name = element.name + if name in ctx.types or name in ctx.variables: + raise CustomError(f"Duplicated class name: {name}") + ctx.types[name] = ClassType(name) + to_be_processed.append(element) + elif isinstance(element, ast.FunctionDef): + to_be_processed.append(element) + + # second pass, obtain all function types + functions = {} + function_stmts = [] + for element in to_be_processed: + if isinstance(element, ast.ClassDef): + function_stmts += parse_class(ctx, element) + elif isinstance(element, ast.FunctionDef): + name = element.name + if name in functions: + raise CustomError(f"Duplicated function name {name}") + args, result, var = parse_function(ctx, None, element) + functions[name] = (args, result, var) + function_stmts += element + + return ctx, functions, function_stmts + + diff --git a/toy-impl/type_def.py b/toy-impl/type_def.py new file mode 100644 index 0000000..21c981e --- /dev/null +++ b/toy-impl/type_def.py @@ -0,0 +1,74 @@ +class Type: + pass + + +class BotType: + pass + + +class PrimitiveType(Type): + name: str + + def __init__(self, name: str): + self.name = name + + def __str__(self): + return self.name + + +class TypeVariable(Type): + name: str + constraints: list[Type] + + def __init__(self, name: str, constraints: list[Type]): + self.name = name + self.constraints = constraints + + def __str__(self): + return self.name + + +class ClassType(Type): + name: str + parents: list['ClassType'] + methods: dict[str, tuple[list[Type], Type, set[str]]] + fields: dict[str, Type] + + def __init__(self, name: str): + self.name = name + self.parents = [] + self.methods = {} + self.fields = {} + + def __str__(self): + return self.name + + +class VirtualClassType(Type): + base: ClassType + + def __init__(self, base: ClassType): + self.base = base + + def __str__(self): + return f"virtual[{self.base}]" + +class ListType(Type): + elements: Type + + def __init__(self, elements: Type): + self.elements = elements + + def __str__(self): + return f"list[{self.elements}]" + +class TupleType(Type): + elements: list[Type] + + def __init__(self, elements: Type): + self.elements = elements + + def __str__(self): + return f"tuple[{', '.join([str(v) for v in self.elements])}]" + +