added some inference

This commit is contained in:
pca006132 2020-12-17 17:00:43 +08:00 committed by pca006132
parent dfc393064e
commit 99801cb1cb
3 changed files with 40 additions and 63 deletions

View File

@ -18,6 +18,7 @@ def parse_type(ctx: Context, ty):
if ty is None:
return None, set()
elif isinstance(ty, ast.Name):
# we should support string either, but no need for toy implementaiton
if ty.id in ctx.types:
return ctx.types[ty.id], set()
elif ty.id in ctx.variables:

View File

@ -1,9 +1,11 @@
class Type:
pass
def __eq__(self, other):
return False
class BotType:
pass
def __eq__(self, other):
return isinstance(other, BotType)
class PrimitiveType(Type):
@ -15,9 +17,15 @@ class PrimitiveType(Type):
def __str__(self):
return self.name
def __eq__(self, other):
return isinstance(other, PrimitiveType) and self.name == other.name
class TypeVariable(Type):
name: str
# this may be a list of str, Type may not be determined when type variables
# are instantiated...
# and they cannot contain other type variables
constraints: list[Type]
def __init__(self, name: str, constraints: list[Type]):
@ -27,6 +35,9 @@ class TypeVariable(Type):
def __str__(self):
return self.name
def __eq__(self, other):
return isinstance(other, TypeVariable) and self.name == other.name
class ClassType(Type):
name: str
@ -43,6 +54,8 @@ class ClassType(Type):
def __str__(self):
return self.name
def __eq__(self, other):
return isinstance(other, ClassType) and self.name == other.name
class VirtualClassType(Type):
base: ClassType
@ -53,22 +66,36 @@ class VirtualClassType(Type):
def __str__(self):
return f"virtual[{self.base}]"
class ListType(Type):
elements: Type
def __eq__(self, other):
return isinstance(other, VirtualClassType) and self.base== other.base
def __init__(self, elements: Type):
self.elements = elements
class ParametricType(Type):
params: list[Type]
def __init__(self, params: list[Type]):
self.params = params
def __eq__(self, other):
if type(self) != type(other) or len(self.params) != len(other.params):
return False
for x, y in zip(self.params, other.params):
if x != y:
return False
return True
class ListType(ParametricType):
def __init__(self, param: Type):
super().__init__([param])
def __str__(self):
return f"list[{self.elements}]"
return f"list[{self.params[0]}]"
class TupleType(Type):
elements: list[Type]
def __init__(self, elements: Type):
self.elements = elements
class TupleType(ParametricType):
def __init__(self, params: list[Type]):
super().__init__(params)
def __str__(self):
return f"tuple[{', '.join([str(v) for v in self.elements])}]"
return f"tuple[{', '.join([str(v) for v in self.params])}]"

View File

@ -1,51 +0,0 @@
def is_variable(b):
return isinstance(b, str) and b.isupper()
def unify(ctx, a, b):
"""
a is the more specific type
b is the type with parameter
lower case means primitive type
upper case means type variable
list and tuples are just list and tuples
"""
if isinstance(ctx, str):
return ctx
if is_variable(b):
if b in ctx:
b = ctx[b]
else:
ctx[b] = a
return ctx
else:
if is_variable(a):
return f"{a} is less specific then {b}"
if isinstance(a, list) and isinstance(b, list):
return unify(ctx, a[0], b[0])
elif isinstance(a, tuple) and isinstance(b, tuple) and len(a) == len(b):
old = ctx
for x, y in zip(a, b):
old = unify(old, x, y)
return old
else:
if a == b:
return ctx
else:
return f"{a} != {b}"
def check_eq(a, b):
unifier = unify({}, a, b)
print(f"{a} <- {b}\n{unifier}\n")
check_eq('a', 'A')
check_eq('A', 'B')
check_eq('A', 'a')
check_eq(['a'], 'A')
check_eq(['a'], ['A'])
check_eq(['a'], ['b'])
check_eq([('a', 'a', 'b')], ['A'])
check_eq([('a', 'a', 'b')], [('A', 'A', 'B')])
check_eq([('a', 'a', 'b')], [('A', 'A', 'A')])