class CustomError(Exception):
    """Exception carrying a human-readable message on the ``msg`` attribute."""

    def __init__(self, msg):
        # Handlers read the text back through ``err.msg``.
        self.msg = msg


def stringify_subst(subst):
    """Render a substitution for display.

    A ``str`` argument is handed back untouched (presumably it is already an
    error message -- confirm against the callers).  Any other argument is
    treated as a mapping and rendered as ``{key: value, ...}``, with each
    value converted through ``str``.
    """
    if isinstance(subst, str):
        return subst
    rendered = ', '.join(f"{name}: {ty!s}" for name, ty in subst.items())
    return "{" + rendered + "}"
def _all_same(values):
    """Return True when no element of *values* differs from the first one."""
    first = values[0]
    return not any(v != first for v in values[1:])


def _lookup_method(owner, obj, fn: str, accept_exact_receiver: bool):
    """Fetch method *fn* from ``owner.methods`` and strip its receiver.

    Entries of the method table are indexed as ``entry[0]`` (parameter types)
    and ``entry[1]`` (result type).  The first parameter must be a
    ``SelfType``; when *accept_exact_receiver* is true, a first parameter
    equal to *obj* is accepted as well.

    Returns ``(TupleType of the remaining parameters, result type)``.
    Raises ``CustomError`` if the method is missing or has no valid receiver.
    """
    if fn not in owner.methods:
        raise CustomError(f"No such method {fn} in {obj}")
    sig = owner.methods[fn]
    params = sig[0]
    receiver_ok = len(params) > 0 and (
        isinstance(params[0], SelfType)
        or (accept_exact_receiver and params[0] == obj)
    )
    if not receiver_ok:
        raise CustomError(f'{fn} is not a method of {obj}')
    return TupleType(params[1:]), sig[1]


def resolve_call(obj,
                 fn: str,
                 args: list[Type],
                 assumptions: dict[str, Type],
                 ctx: Context) -> 'Type | None':
    """Type-check the call ``obj.fn(*args)`` (or ``fn(*args)``) and return its type.

    Parameters:
        obj: receiver type, or ``None`` for a free function / constructor
            call.  A receiver is first rewritten with ``obj.subst(assumptions)``.
        fn: function, type (constructor) or method name.
        args: types of the actual arguments, excluding the receiver.
        assumptions: type-variable bindings; also passed to ``find_subst``
            as its first argument.
        ctx: supplies ``ctx.functions`` and ``ctx.types``.

    Returns:
        The declared result type with the substitution found by
        ``find_subst`` applied.  A ``SelfType`` result is replaced by the
        receiver.  ``None`` is returned when the declared result is ``None``.

    Raises:
        CustomError: unknown function/method, malformed receiver parameter,
            unsupported or unconstrained receiver, argument mismatch, or
            divergent results across the constraints of a type variable.
    """
    # TODO: we may want to return the substitution, for monomorphization...
    if obj is not None:
        obj = obj.subst(assumptions)

    if obj is None:
        if fn in ctx.functions:
            sig = ctx.functions[fn]
            f_args, f_result = TupleType(sig[0]), sig[1]
        elif fn in ctx.types:
            # Calling a type name is a constructor call.
            cls = ctx.types[fn]
            if '__init__' in cls.methods:
                sig = cls.methods['__init__']
                # The emptiness test keeps a parameterless __init__ from
                # raising IndexError instead of the intended CustomError.
                if (len(sig[0]) == 0
                        or not isinstance(sig[0][0], SelfType)
                        or sig[1] is not None):
                    raise CustomError(
                        f'__init__ of {cls} should accept self and return None'
                    )
                f_args, f_result = TupleType(sig[0][1:]), cls
            else:
                f_args, f_result = TupleType([]), cls
        else:
            raise CustomError(f"No such function {fn}")
    elif isinstance(obj, (PrimitiveType, ClassType)):
        f_args, f_result = _lookup_method(obj, obj, fn, True)
    elif isinstance(obj, VirtualClassType):
        # may need to emit special annotation that this is a virtual method
        # call?
        f_args, f_result = _lookup_method(obj.base, obj, fn, False)
    elif isinstance(obj, TypeVariable):
        # if not constrained, error. otherwise, try all values, and only allow
        # if the results are the same or if they are the same modulo the
        # substitution.
        # expensive operation, but cache should be applicable
        # in order to cache this, our cache must be able to compare equality
        # modulo variable naming... probably not easy either
        if len(obj.constraints) == 0:
            raise CustomError("no methods for unconstrained object")
        results = [
            resolve_call(obj, fn, args, assumptions | {obj.name: candidate}, ctx)
            for candidate in obj.constraints
        ]
        if _all_same(results):
            return results[0]
        # Map each concrete candidate back to the variable and compare again.
        # inv_subst iterates (type, variable) pairs, hence the 1-element list
        # of a tuple.
        results = [
            r if r is None else r.inv_subst([(candidate, obj)])
            for r, candidate in zip(results, obj.constraints)
        ]
        if _all_same(results):
            return results[0]
        raise CustomError("Divergent type after constraints substitution")
    else:
        # Receiver kinds without a lookup rule used to fall through with no
        # signature at all; report them explicitly instead.
        raise CustomError(f"No such method {fn} in {obj}")

    subst = find_subst(assumptions, {}, TupleType(args), f_args)
    if isinstance(subst, str):
        # find_subst reports a mismatch as an error string.
        raise CustomError(f"type check failed: {subst}")
    if f_result is None:
        # A declared result of None has no subst() to call.
        return None
    result = f_result.subst(subst)
    return obj if isinstance(result, SelfType) else result
obj is None else str(obj) + '.' + print(f'Testing {obj_str}{fn}({args_str}) w.r.t. {stringify_subst(assumptions)}') + try: + result = resolve_call(obj, fn, args, assumptions, ctx) + print(result) + except CustomError as err: + print(f'error: {err.msg}') + +test_call(None, 'int32', []) +test_call(None, 'int32', [i32]) +test_call(None, 'int32', [i64]) +test_call(None, 'int32', [I]) +test_call(None, 'int32', [A]) +test_call(None, 'int32', [i32, i64]) +test_call(i32, '__add__', []) +test_call(i32, '__add__', [i32]) +test_call(i32, '__add__', [i64]) +test_call(i32, '__add__', [i32, i32]) + + diff --git a/toy-impl/test_subst.py b/toy-impl/test_subst.py index c852ec9..61164b1 100644 --- a/toy-impl/test_subst.py +++ b/toy-impl/test_subst.py @@ -1,5 +1,6 @@ from type_def import * from inference import * +from helper import * types = { 'int32': PrimitiveType('int32'), @@ -49,6 +50,7 @@ try_case(A, I, {}) try_case(X, I, {}) try_case(ListType(i32), TupleType([i32]), {}) try_case(TupleType([i32]), ListType(i32), {}) +try_case(TupleType([i32, i32]), TupleType([i32]), {}) try_case(ListType(i32), ListType(i32), {}) try_case(TupleType([X, X]), TupleType([X, Y]), {}) try_case(TupleType([X, X]), TupleType([Y, Y]), {}) @@ -57,4 +59,3 @@ try_case(TupleType([X, X]), TupleType([X, X]), {}) try_case(TupleType([X, Y]), X, {}) try_case(TupleType([i32, Y]), X, {}) - diff --git a/toy-impl/test_top_level.py b/toy-impl/test_top_level.py index f30dd30..bd7bbf4 100644 --- a/toy-impl/test_top_level.py +++ b/toy-impl/test_top_level.py @@ -10,7 +10,7 @@ class A: class B(A): a: str - def bar(a: list[list[virtual[A]]]) -> A: + def bar(self, a: list[list[virtual[A]]]) -> self: pass """ @@ -18,7 +18,7 @@ variables = {'X': TypeVariable('X', []), 'Y': TypeVariable('Y', [])} types = {'int': PrimitiveType('int'), 'str': PrimitiveType('str')} ctx = Context(variables, types) -ctx, functions, _ = parse_top_level(ctx, ast.parse(test)) +ctx, _ = parse_top_level(ctx, ast.parse(test)) for name, t in 
ctx.types.items(): if isinstance(t, ClassType): diff --git a/toy-impl/top_level.py b/toy-impl/top_level.py index 615dd9b..de77e61 100644 --- a/toy-impl/top_level.py +++ b/toy-impl/top_level.py @@ -1,17 +1,6 @@ import ast from type_def import * - -class CustomError(Exception): - def __init__(self, msg): - self.msg = msg - -class Context: - variables: dict[str, TypeVariable] - types: dict[Type] - - def __init__(self, variables, types): - self.variables = variables - self.types = types +from helper import * def parse_type(ctx: Context, ty): @@ -65,9 +54,13 @@ def parse_function(ctx: Context, base, fn: ast.FunctionDef): ty, v = parse_type(ctx, arg.annotation) var |= v if name == 'self' and ty is None and base is not None: - ty = base + ty = SelfType() args.append(ty) - result, v = parse_type(ctx, fn.returns) + if isinstance(fn.returns, ast.Name) and fn.returns.id == 'self'\ + and base is not None: + result, v = SelfType(), set() + else: + result, v = parse_type(ctx, fn.returns) if len(v - var) > 0: raise CustomError(f"Unbounded variable in return type of {fn.name}") return args, result, var @@ -126,7 +119,6 @@ def parse_top_level(ctx: Context, module: ast.Module): to_be_processed.append(element) # second pass, obtain all function types - functions = {} function_stmts = [] for element in to_be_processed: if isinstance(element, ast.ClassDef): @@ -135,10 +127,12 @@ def parse_top_level(ctx: Context, module: ast.Module): name = element.name if name in functions: raise CustomError(f"Duplicated function name {name}") + if name in ctx.types: + raise CustomError(f"Function name {name} clashed with type name") args, result, var = parse_function(ctx, None, element) - functions[name] = (args, result, var) + ctx.functions[name] = (args, result, var) function_stmts += element - return ctx, functions, function_stmts + return ctx, function_stmts diff --git a/toy-impl/type_def.py b/toy-impl/type_def.py index 262d408..49b37fb 100644 --- a/toy-impl/type_def.py +++ b/toy-impl/type_def.py 
class Type:
    """Base class of the toy type representation.

    Every instance owns a ``methods`` table and a ``fields`` table, both
    starting empty.  The base implementation compares unequal to everything
    (including itself), contains no type variables, and is left unchanged by
    substitution.
    """
    methods: dict[str, tuple[list['Type'], 'Type', set[str]]]
    fields: dict[str, 'Type']

    def __init__(self):
        self.methods, self.fields = {}, {}

    def __eq__(self, other):
        # Subclasses are expected to provide a meaningful comparison.
        return False

    def get_vars(self):
        """Return the type variables occurring in this type (none here)."""
        return list()

    def subst(self, subst: dict[str, 'Type']):
        """Apply a name -> type substitution; the base type is unaffected."""
        return self

    def inv_subst(self, subst: list[tuple['Type', 'TypeVariable']]):
        """Return the variable of the first ``(type, variable)`` pair whose
        type equals this one, or this type itself when no pair matches."""
        return next((var for ty, var in subst if self == ty), self)
class Context:
    """Holds the type variables, types and functions known so far.

    ``functions`` always starts empty; ``variables`` and ``types`` are stored
    exactly as passed in (no copy is made).
    """
    variables: dict[str, TypeVariable]
    types: dict[str, Type]
    functions: dict[str, tuple[list[Type], Type, set[str]]]

    def __init__(self, variables, types):
        self.variables, self.types = variables, types
        self.functions = dict()