added some inference
This commit is contained in:
parent
dfc393064e
commit
99801cb1cb
@ -18,6 +18,7 @@ def parse_type(ctx: Context, ty):
|
||||
if ty is None:
|
||||
return None, set()
|
||||
elif isinstance(ty, ast.Name):
|
||||
# we should support string either, but no need for toy implementaiton
|
||||
if ty.id in ctx.types:
|
||||
return ctx.types[ty.id], set()
|
||||
elif ty.id in ctx.variables:
|
||||
|
@ -1,9 +1,11 @@
|
||||
class Type:
|
||||
pass
|
||||
def __eq__(self, other):
|
||||
return False
|
||||
|
||||
|
||||
class BotType:
|
||||
pass
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, BotType)
|
||||
|
||||
|
||||
class PrimitiveType(Type):
|
||||
@ -15,9 +17,15 @@ class PrimitiveType(Type):
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, PrimitiveType) and self.name == other.name
|
||||
|
||||
|
||||
class TypeVariable(Type):
|
||||
name: str
|
||||
# this may be a list of str, Type may not be determined when type variables
|
||||
# are instantiated...
|
||||
# and they cannot contain other type variables
|
||||
constraints: list[Type]
|
||||
|
||||
def __init__(self, name: str, constraints: list[Type]):
|
||||
@ -27,6 +35,9 @@ class TypeVariable(Type):
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TypeVariable) and self.name == other.name
|
||||
|
||||
|
||||
class ClassType(Type):
|
||||
name: str
|
||||
@ -43,6 +54,8 @@ class ClassType(Type):
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, ClassType) and self.name == other.name
|
||||
|
||||
class VirtualClassType(Type):
|
||||
base: ClassType
|
||||
@ -53,22 +66,36 @@ class VirtualClassType(Type):
|
||||
def __str__(self):
|
||||
return f"virtual[{self.base}]"
|
||||
|
||||
class ListType(Type):
|
||||
elements: Type
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, VirtualClassType) and self.base== other.base
|
||||
|
||||
def __init__(self, elements: Type):
|
||||
self.elements = elements
|
||||
class ParametricType(Type):
|
||||
params: list[Type]
|
||||
|
||||
def __init__(self, params: list[Type]):
|
||||
self.params = params
|
||||
|
||||
def __eq__(self, other):
|
||||
if type(self) != type(other) or len(self.params) != len(other.params):
|
||||
return False
|
||||
for x, y in zip(self.params, other.params):
|
||||
if x != y:
|
||||
return False
|
||||
return True
|
||||
|
||||
class ListType(ParametricType):
|
||||
def __init__(self, param: Type):
|
||||
super().__init__([param])
|
||||
|
||||
def __str__(self):
|
||||
return f"list[{self.elements}]"
|
||||
return f"list[{self.params[0]}]"
|
||||
|
||||
class TupleType(Type):
|
||||
elements: list[Type]
|
||||
|
||||
def __init__(self, elements: Type):
|
||||
self.elements = elements
|
||||
class TupleType(ParametricType):
|
||||
def __init__(self, params: list[Type]):
|
||||
super().__init__(params)
|
||||
|
||||
def __str__(self):
|
||||
return f"tuple[{', '.join([str(v) for v in self.elements])}]"
|
||||
return f"tuple[{', '.join([str(v) for v in self.params])}]"
|
||||
|
||||
|
||||
|
@ -1,51 +0,0 @@
|
||||
def is_variable(b):
|
||||
return isinstance(b, str) and b.isupper()
|
||||
|
||||
def unify(ctx, a, b):
|
||||
"""
|
||||
a is the more specific type
|
||||
b is the type with parameter
|
||||
lower case means primitive type
|
||||
upper case means type variable
|
||||
list and tuples are just list and tuples
|
||||
"""
|
||||
if isinstance(ctx, str):
|
||||
return ctx
|
||||
|
||||
if is_variable(b):
|
||||
if b in ctx:
|
||||
b = ctx[b]
|
||||
else:
|
||||
ctx[b] = a
|
||||
return ctx
|
||||
else:
|
||||
if is_variable(a):
|
||||
return f"{a} is less specific then {b}"
|
||||
|
||||
if isinstance(a, list) and isinstance(b, list):
|
||||
return unify(ctx, a[0], b[0])
|
||||
elif isinstance(a, tuple) and isinstance(b, tuple) and len(a) == len(b):
|
||||
old = ctx
|
||||
for x, y in zip(a, b):
|
||||
old = unify(old, x, y)
|
||||
return old
|
||||
else:
|
||||
if a == b:
|
||||
return ctx
|
||||
else:
|
||||
return f"{a} != {b}"
|
||||
|
||||
def check_eq(a, b):
|
||||
unifier = unify({}, a, b)
|
||||
print(f"{a} <- {b}\n{unifier}\n")
|
||||
|
||||
check_eq('a', 'A')
|
||||
check_eq('A', 'B')
|
||||
check_eq('A', 'a')
|
||||
check_eq(['a'], 'A')
|
||||
check_eq(['a'], ['A'])
|
||||
check_eq(['a'], ['b'])
|
||||
check_eq([('a', 'a', 'b')], ['A'])
|
||||
check_eq([('a', 'a', 'b')], [('A', 'A', 'B')])
|
||||
check_eq([('a', 'a', 'b')], [('A', 'A', 'A')])
|
||||
|
Loading…
Reference in New Issue
Block a user