From 5f2eb1c10c73651d9e3cf07a8a2badeebd01ffd3 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Thu, 17 Dec 2020 17:01:31 +0800 Subject: [PATCH] added some inference --- toy-impl/inference.py | 58 ++++++++++++++++++++++++++++++++++++++++++ toy-impl/test_subst.py | 40 +++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+) create mode 100644 toy-impl/inference.py create mode 100644 toy-impl/test_subst.py diff --git a/toy-impl/inference.py b/toy-impl/inference.py new file mode 100644 index 0000000..8159201 --- /dev/null +++ b/toy-impl/inference.py @@ -0,0 +1,58 @@ +from type_def import * + +def find_subst(ctx: dict[str, Type], + sub: dict[str, Type], + a: Type, + b: Type): + """ + Find substitution s such that ctx(a) = s(sub(ctx(b))) + return error message if type mismatch + """ + # is error + if isinstance(sub, str): + return sub + + if isinstance(a, TypeVariable) and a.name in ctx: + a = ctx[a.name] + + if isinstance(b, TypeVariable): + if b.name in ctx: + b = ctx[b.name] + elif b.name in sub: + b = sub[b.name] + else: + if len(b.constraints) > 0: + if isinstance(a, TypeVariable): + if len(a.constraints) == 0: + return f"{b} cannot take value of an unconstrained variable {a}" + diff = [v for v in a.constraints if v not in b.constraints] + if len(diff) > 0: + over = ', '.join([str(v) for v in diff]) + return f"{b} cannot take value of {a} as {a} can range over [{over}]" + else: + if a not in b.constraints: + return f"{b} cannot take value of {a}" + sub[b.name] = a + return sub + + if isinstance(a, TypeVariable): + return f"{a} can take values other than {b}" + + if isinstance(a, BotType): + return sub + if type(a) == type(b): + if isinstance(a, ParametricType): + old = sub + for x, y in zip(a.params, b.params): + old = find_subst(ctx, old, x, y) + return old + elif isinstance(a, ClassType) or isinstance(a, PrimitiveType): + if a.name == b.name: + return sub + elif isinstance(a, VirtualClassType): + return find_subst(a.base, b.base) + else: + raise Exception() + return f"{a} != {b}" + + diff --git a/toy-impl/test_subst.py b/toy-impl/test_subst.py new file mode 100644 index 0000000..451d07f --- /dev/null +++ b/toy-impl/test_subst.py @@ -0,0 +1,40 @@ +from type_def import * +from inference import * + +types = { + 'int32': PrimitiveType('int32'), + 'int64': PrimitiveType('int64'), + 'str': PrimitiveType('str'), +} + +variables = { + 'X': TypeVariable('X', []), + 'Y': TypeVariable('Y', []), + 'I': TypeVariable('I', [types['int32'], types['int64']]), + 'A': TypeVariable('A', [types['int32'], types['int64'], types['str']]), +} + +def stringify_subst(subst): + if isinstance(subst, str): + return subst + elements = [f"{key}: {str(value)}" for key, value in subst.items()] + return "{" + ', '.join(elements) + "}" + +def try_case(a, b, ctx): + result = find_subst(ctx, {}, a, b) + print(f"{a} <- {b} w.r.t. {stringify_subst(ctx)}\n {stringify_subst(result)}\n") + + +try_case(types['int32'], types['int32'], {}) +try_case(types['int32'], types['int64'], {}) +try_case(types['int32'], variables['X'], {}) +try_case(types['int32'], variables['X'], {'X': types['int32']}) +try_case(types['int32'], variables['X'], {'X': types['int64']}) +try_case(variables['X'], variables['X'], {'X': types['int64']}) +try_case(variables['X'], variables['Y'], {'Y': types['int64']}) +try_case(variables['X'], variables['Y'], {'X': types['int64']}) +try_case(variables['I'], variables['X'], {}) +try_case(variables['I'], variables['A'], {}) +try_case(variables['A'], variables['I'], {}) +try_case(variables['X'], variables['I'], {}) +