from type_def import *


def _constraint_error(a: Type, b: TypeVariable) -> "str | None":
    """Return why type variable ``b`` may not be bound to ``a``, or None if it may.

    An unconstrained ``b`` accepts any ``a``. A constrained ``b`` accepts:
    - a non-variable ``a`` only if ``a`` is one of ``b.constraints``;
    - a type variable ``a`` only if ``a`` is itself constrained and every one
      of its constraints is also a constraint of ``b``.
    """
    if not b.constraints:
        return None
    if isinstance(a, TypeVariable):
        if not a.constraints:
            return f"{b} cannot take value of an unconstrained variable {a}"
        diff = [v for v in a.constraints if v not in b.constraints]
        if diff:
            over = ', '.join(str(v) for v in diff)
            return f"{b} cannot take value of {a} as {a} can range over [{over}]"
        return None
    if a not in b.constraints:
        return f"{b} cannot take value of {a}"
    return None


def find_subst(ctx: dict[str, Type], sub: "dict[str, Type] | str", a: Type, b: Type) -> "dict[str, Type] | str":
    """
    Find substitution s such that ctx(a) = s(sub(ctx(b))).

    :param ctx: maps type-variable names to the types they stand for; a
        variable ``a`` or ``b`` whose name is in ``ctx`` is replaced by its
        entry before matching (one level, not transitively).
    :param sub: substitution built so far, mapping names of ``b``'s type
        variables to the types they were bound to. A ``str`` here is an error
        message from an earlier step and is returned unchanged, so calls can be
        chained without checking in between.
    :param a: the type being matched.
    :param b: the type ``a`` is matched against; its unbound type variables
        get bound to the corresponding parts of ``a``.
    :return: the (possibly extended) substitution on success, or an error
        message string on type mismatch. NOTE: new bindings are written into
        ``sub`` in place, so the caller's dict is mutated even when a later
        step fails.
    :raises Exception: if ``a`` and ``b`` share a type kind this function does
        not know how to compare.
    """
    # Propagate an error produced by an earlier step.
    if isinstance(sub, str):
        return sub
    if isinstance(a, TypeVariable) and a.name in ctx:
        a = ctx[a.name]
    if isinstance(b, TypeVariable):
        if b.name in ctx:
            b = ctx[b.name]
        elif b.name in sub:
            b = sub[b.name]
        else:
            # Unbound variable of b: bind it to a if b's constraints allow it.
            err = _constraint_error(a, b)
            if err is not None:
                return err
            sub[b.name] = a
            return sub
    if isinstance(a, TypeVariable):
        # b may have resolved (via an earlier binding) to this very variable;
        # a variable trivially matches itself.
        if isinstance(b, TypeVariable) and a.name == b.name:
            return sub
        return f"{a} can take values other than {b}"
    if isinstance(a, BotType):
        return sub
    if type(a) is type(b):
        if isinstance(a, ParametricType):
            # zip() would silently drop extra parameters; differing arity is a mismatch.
            if len(a.params) != len(b.params):
                return f"{a} != {b}"
            for x, y in zip(a.params, b.params):
                sub = find_subst(ctx, sub, x, y)
                if isinstance(sub, str):
                    return sub
            return sub
        if isinstance(a, (ClassType, PrimitiveType)):
            if a.name == b.name:
                return sub
        elif isinstance(a, VirtualClassType):
            return find_subst(ctx, sub, a.base, b.base)
        else:
            raise Exception(f"find_subst: unsupported type kind {type(a).__name__}")
    return f"{a} != {b}"