diff --git a/toy-impl/examples/a.py b/toy-impl/examples/a.py index 11616ff..3b65f29 100644 --- a/toy-impl/examples/a.py +++ b/toy-impl/examples/a.py @@ -13,6 +13,9 @@ class Vec: else: return Vec([self.v[i] + other.v[i] for i in range(len(self.v))]) + def get(self, index: int32) -> int32: + return self.v.head() + T = TypeVar('T', int32, list[int32]) diff --git a/toy-impl/inference.py b/toy-impl/inference.py index 0195f80..02af224 100644 --- a/toy-impl/inference.py +++ b/toy-impl/inference.py @@ -88,8 +88,10 @@ def resolve_call(obj, # TODO: we may want to return the substitution, for monomorphization... f_args = None f_result = None + subst = {} if obj is not None: obj = obj.subst(assumptions) + subst = obj.get_subst() if obj is None: if fn in ctx.functions: f = ctx.functions[fn] @@ -107,7 +109,8 @@ def resolve_call(obj, f_args, f_result = TupleType([]), c else: raise CustomError(f"No such function {fn}") - elif isinstance(obj, PrimitiveType) or isinstance(obj, ClassType): + elif isinstance(obj, PrimitiveType) or isinstance(obj, ClassType) \ + or isinstance(obj, ParametricType): if fn in obj.methods: f = obj.methods[fn] if len(f[0]) == 0 or (not isinstance(f[0][0], SelfType) and \ @@ -154,7 +157,7 @@ def resolve_call(obj, raise CustomError("Divergent type after constraints substitution") a = TupleType(args) - subst = find_subst(assumptions, {}, a, f_args) + subst = find_subst(assumptions, subst, a, f_args) if isinstance(subst, str): raise CustomError(f"type check failed: {subst}") result = f_result.subst(subst) diff --git a/toy-impl/primitives.py b/toy-impl/primitives.py index 5b3edd6..e609a7f 100644 --- a/toy-impl/primitives.py +++ b/toy-impl/primitives.py @@ -66,6 +66,11 @@ i32.methods['__init__'] = ([SelfType(), I], None, {'I'}) i64.methods['__init__'] = ([SelfType(), I], None, {'I'}) f32.methods['__init__'] = ([SelfType(), I], None, {'I'}) +ParametricType.method_table['list'] = { + 'get': ([SelfType(), i32], TypeVariable('T1', []), set()), + 'head': ([SelfType()], TypeVariable('T1', []), set()) +} + simplest_ctx = Context({}, types) simplest_ctx.functions['len'] = ([ListType(A)], i32, {'A'}) simplest_ctx.functions['range'] = ([i32], ListType(i32), set()) diff --git a/toy-impl/type_def.py b/toy-impl/type_def.py index 024000b..47b2e58 100644 --- a/toy-impl/type_def.py +++ b/toy-impl/type_def.py @@ -23,6 +23,9 @@ class Type: return tv return self + def get_subst(self): + return {} + class BotType: def __eq__(self, other): @@ -113,10 +116,16 @@ class VirtualClassType(Type): class ParametricType(Type): params: list[Type] + name: str - def __init__(self, params: list[Type]): + method_table = {} + + def __init__(self, name, params: list[Type]): super().__init__() self.params = params + self.name = name + if name in ParametricType.method_table: + self.methods = ParametricType.method_table[name] def __eq__(self, other): if type(self) != type(other) or len(self.params) != len(other.params): @@ -148,9 +157,12 @@ class ParametricType(Type): s.params = [v.inv_subst(subst) for v in self.params] return s + def get_subst(self): + return {f'T{i + 1}': v for i, v in enumerate(self.params)} + class ListType(ParametricType): def __init__(self, param: Type): - super().__init__([param]) + super().__init__('list', [param]) def __str__(self): return f"list[{self.params[0]}]" @@ -158,7 +170,7 @@ class ListType(ParametricType): class TupleType(ParametricType): def __init__(self, params: list[Type]): - super().__init__(params) + super().__init__('tuple', params) def __str__(self): return f"tuple[{', '.join([str(v) for v in self.params])}]"