added some inference
This commit is contained in:
parent
99801cb1cb
commit
5f2eb1c10c
58
toy-impl/inference.py
Normal file
58
toy-impl/inference.py
Normal file
@ -0,0 +1,58 @@
|
||||
from type_def import *
|
||||
|
||||
def find_subst(ctx: dict[str, Type],
|
||||
sub: dict[str, Type],
|
||||
a: Type,
|
||||
b: Type):
|
||||
"""
|
||||
Find substitution s such that ctx(a) = s(sub(ctx(b)))
|
||||
return error message if type mismatch
|
||||
"""
|
||||
# is error
|
||||
if isinstance(sub, str):
|
||||
return sub
|
||||
|
||||
if isinstance(a, TypeVariable) and a.name in ctx:
|
||||
a = ctx[a.name]
|
||||
|
||||
if isinstance(b, TypeVariable):
|
||||
if b.name in ctx:
|
||||
b = ctx[b.name]
|
||||
elif b.name in sub:
|
||||
b = sub[b.name]
|
||||
else:
|
||||
if len(b.constraints) > 0:
|
||||
if isinstance(a, TypeVariable):
|
||||
if len(a.constraints) == 0:
|
||||
return f"{b} cannot take value of an unconstrained variable {a}"
|
||||
diff = [v for v in a.constraints if v not in b.constraints]
|
||||
if len(diff) > 0:
|
||||
over = ', '.join([str(v) for v in diff])
|
||||
return f"{b} cannot take value of {a} as {a} can range over [{over}]"
|
||||
else:
|
||||
if a not in b.constraints:
|
||||
return f"{b} cannot take value of {a}"
|
||||
sub[b.name] = a
|
||||
return sub
|
||||
|
||||
if isinstance(a, TypeVariable):
|
||||
return f"{a} can take values other than {b}"
|
||||
|
||||
if isinstance(a, BotType):
|
||||
return sub
|
||||
if type(a) == type(b):
|
||||
if isinstance(a, ParametricType):
|
||||
old = sub
|
||||
for x, y in zip(a.params, b.params):
|
||||
old = find_subst(ctx, old, x, y)
|
||||
return old
|
||||
elif isinstance(a, ClassType) or isinstance(a, PrimitiveType):
|
||||
if a.name == b.name:
|
||||
return sub
|
||||
elif isinstance(a, VirtualClassType):
|
||||
return find_subst(a.base, b.base)
|
||||
else:
|
||||
raise Exception()
|
||||
return f"{a} != {b}"
|
||||
|
||||
|
40
toy-impl/test_subst.py
Normal file
40
toy-impl/test_subst.py
Normal file
@ -0,0 +1,40 @@
|
||||
from type_def import *
|
||||
from inference import *
|
||||
|
||||
types = {
|
||||
'int32': PrimitiveType('int32'),
|
||||
'int64': PrimitiveType('int64'),
|
||||
'str': PrimitiveType('str'),
|
||||
}
|
||||
|
||||
variables = {
|
||||
'X': TypeVariable('X', []),
|
||||
'Y': TypeVariable('Y', []),
|
||||
'I': TypeVariable('I', [types['int32'], types['int64']]),
|
||||
'A': TypeVariable('A', [types['int32'], types['int64'], types['str']]),
|
||||
}
|
||||
|
||||
def stringify_subst(subst):
|
||||
if isinstance(subst, str):
|
||||
return subst
|
||||
elements = [f"{key}: {str(value)}" for key, value in subst.items()]
|
||||
return "{" + ', '.join(elements) + "}"
|
||||
|
||||
def try_case(a, b, ctx):
|
||||
result = find_subst(ctx, {}, a, b)
|
||||
print(f"{a} <- {b} w.r.t. {stringify_subst(ctx)}\n {stringify_subst(result)}\n")
|
||||
|
||||
|
||||
try_case(types['int32'], types['int32'], {})
|
||||
try_case(types['int32'], types['int64'], {})
|
||||
try_case(types['int32'], variables['X'], {})
|
||||
try_case(types['int32'], variables['X'], {'X': types['int32']})
|
||||
try_case(types['int32'], variables['X'], {'X': types['int64']})
|
||||
try_case(variables['X'], variables['X'], {'X': types['int64']})
|
||||
try_case(variables['X'], variables['Y'], {'Y': types['int64']})
|
||||
try_case(variables['X'], variables['Y'], {'X': types['int64']})
|
||||
try_case(variables['I'], variables['X'], {})
|
||||
try_case(variables['I'], variables['A'], {})
|
||||
try_case(variables['A'], variables['I'], {})
|
||||
try_case(variables['X'], variables['I'], {})
|
||||
|
Loading…
Reference in New Issue
Block a user