# nac3-spec/type_check.py
#
# 47 lines · 1.1 KiB · Python
# (exported from the repository web view: Raw / Normal / View History)
#
# 2020-12-16 17:13:43 +08:00
def is_variable(b):
    """Return True if *b* is a type variable.

    Type variables are spelled as upper-case strings (e.g. ``'A'``);
    anything else -- lower-case primitive names, lists, tuples -- is not.
    """
    return isinstance(b, str) and b.isupper()


def unify(ctx, a, b):
    """Match a concrete type *a* against a parameterised type *b*.

    Type encoding:
      * lower-case strings are primitive types (``'a'``, ``'b'``);
      * upper-case strings are type variables (``'A'``, ``'B'``);
      * a list ``[T]`` is a list type whose element type is ``T``
        (only the first item of each list is inspected);
      * a tuple ``(T1, ..., Tn)`` is an n-ary tuple type.

    Args:
        ctx: Either the current substitution -- a dict mapping type-variable
            names to the concrete types they are bound to -- or an error
            string produced by an earlier failed step.  A dict is updated
            in place with any new bindings.
        a: The concrete type.
        b: The type pattern, which may contain type variables.

    Returns:
        The substitution dict (the same object as *ctx*) on success, or a
        string of the form ``"<a> != <b>"`` describing the first mismatch.
        An error string passed in as *ctx* is returned unchanged.
    """
    # A previous step already failed: propagate its message untouched.
    if isinstance(ctx, str):
        return ctx
    if is_variable(b):
        if b not in ctx:
            # First occurrence: bind the variable to the concrete type.
            ctx[b] = a
            return ctx
        # Already bound: the concrete type must agree with the earlier
        # binding, so compare against that instead of the variable itself.
        b = ctx[b]
    # List types carry their element type as the first item.  Both lists
    # must be non-empty before indexing; an empty list falls through to the
    # plain equality check below instead of raising IndexError.
    if isinstance(a, list) and isinstance(b, list) and a and b:
        return unify(ctx, a[0], b[0])
    if isinstance(a, tuple) and isinstance(b, tuple) and len(a) == len(b):
        for x, y in zip(a, b):
            ctx = unify(ctx, x, y)
            # Stop at the first failing component; the error is final.
            if isinstance(ctx, str):
                break
        return ctx
    if a == b:
        return ctx
    return f"{a} != {b}"
def check_eq(a, b):
    """Print how concrete type *a* unifies against pattern *b*.

    Runs ``unify`` with a fresh, empty substitution and prints the pair
    followed by the outcome (a bindings dict or a mismatch message).
    """
    outcome = unify(ctx={}, a=a, b=b)
    print(f"{a} <- {b}\n{outcome}\n")
# Demonstration cases as (concrete type, parameterised type) pairs; each is
# unified and its outcome printed by check_eq, in this order.
for _concrete, _pattern in [
    ('a', 'A'),
    (['a'], 'A'),
    (['a'], ['A']),
    (['a'], ['b']),
    ([('a', 'a', 'b')], ['A']),
    ([('a', 'a', 'b')], [('A', 'A', 'B')]),
    ([('a', 'a', 'b')], [('A', 'A', 'A')]),
]:
    check_eq(_concrete, _pattern)