diff --git a/type_check.py b/type_check.py new file mode 100644 index 0000000..f70cdc8 --- /dev/null +++ b/type_check.py @@ -0,0 +1,46 @@ +def is_variable(b): + return isinstance(b, str) and b.isupper() + +def unify(ctx, a, b): + """ + a is the concrete type + b is the type with parameter + lower case means primitive type + upper case means type variable + list and tuples are just list and tuples + """ + if isinstance(ctx, str): + return ctx + + if is_variable(b): + if b in ctx: + b = ctx[b] + else: + ctx[b] = a + return ctx + + if isinstance(a, list) and isinstance(b, list): + return unify(ctx, a[0], b[0]) + elif isinstance(a, tuple) and isinstance(b, tuple) and len(a) == len(b): + old = ctx + for x, y in zip(a, b): + old = unify(old, x, y) + return old + else: + if a == b: + return ctx + else: + return f"{a} != {b}" + +def check_eq(a, b): + unifier = unify({}, a, b) + print(f"{a} <- {b}\n{unifier}\n") + +check_eq('a', 'A') +check_eq(['a'], 'A') +check_eq(['a'], ['A']) +check_eq(['a'], ['b']) +check_eq([('a', 'a', 'b')], ['A']) +check_eq([('a', 'a', 'b')], [('A', 'A', 'B')]) +check_eq([('a', 'a', 'b')], [('A', 'A', 'A')]) +