added some inference
This commit is contained in:
parent
dfc393064e
commit
99801cb1cb
@ -18,6 +18,7 @@ def parse_type(ctx: Context, ty):
|
|||||||
if ty is None:
|
if ty is None:
|
||||||
return None, set()
|
return None, set()
|
||||||
elif isinstance(ty, ast.Name):
|
elif isinstance(ty, ast.Name):
|
||||||
|
# we should support string either, but no need for toy implementaiton
|
||||||
if ty.id in ctx.types:
|
if ty.id in ctx.types:
|
||||||
return ctx.types[ty.id], set()
|
return ctx.types[ty.id], set()
|
||||||
elif ty.id in ctx.variables:
|
elif ty.id in ctx.variables:
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
class Type:
|
class Type:
|
||||||
pass
|
def __eq__(self, other):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class BotType:
|
class BotType:
|
||||||
pass
|
def __eq__(self, other):
|
||||||
|
return isinstance(other, BotType)
|
||||||
|
|
||||||
|
|
||||||
class PrimitiveType(Type):
|
class PrimitiveType(Type):
|
||||||
@ -15,9 +17,15 @@ class PrimitiveType(Type):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return isinstance(other, PrimitiveType) and self.name == other.name
|
||||||
|
|
||||||
|
|
||||||
class TypeVariable(Type):
|
class TypeVariable(Type):
|
||||||
name: str
|
name: str
|
||||||
|
# this may be a list of str, Type may not be determined when type variables
|
||||||
|
# are instantiated...
|
||||||
|
# and they cannot contain other type variables
|
||||||
constraints: list[Type]
|
constraints: list[Type]
|
||||||
|
|
||||||
def __init__(self, name: str, constraints: list[Type]):
|
def __init__(self, name: str, constraints: list[Type]):
|
||||||
@ -27,6 +35,9 @@ class TypeVariable(Type):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return isinstance(other, TypeVariable) and self.name == other.name
|
||||||
|
|
||||||
|
|
||||||
class ClassType(Type):
|
class ClassType(Type):
|
||||||
name: str
|
name: str
|
||||||
@ -43,6 +54,8 @@ class ClassType(Type):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return isinstance(other, ClassType) and self.name == other.name
|
||||||
|
|
||||||
class VirtualClassType(Type):
|
class VirtualClassType(Type):
|
||||||
base: ClassType
|
base: ClassType
|
||||||
@ -53,22 +66,36 @@ class VirtualClassType(Type):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"virtual[{self.base}]"
|
return f"virtual[{self.base}]"
|
||||||
|
|
||||||
class ListType(Type):
|
def __eq__(self, other):
|
||||||
elements: Type
|
return isinstance(other, VirtualClassType) and self.base== other.base
|
||||||
|
|
||||||
def __init__(self, elements: Type):
|
class ParametricType(Type):
|
||||||
self.elements = elements
|
params: list[Type]
|
||||||
|
|
||||||
|
def __init__(self, params: list[Type]):
|
||||||
|
self.params = params
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if type(self) != type(other) or len(self.params) != len(other.params):
|
||||||
|
return False
|
||||||
|
for x, y in zip(self.params, other.params):
|
||||||
|
if x != y:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
class ListType(ParametricType):
|
||||||
|
def __init__(self, param: Type):
|
||||||
|
super().__init__([param])
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"list[{self.elements}]"
|
return f"list[{self.params[0]}]"
|
||||||
|
|
||||||
class TupleType(Type):
|
|
||||||
elements: list[Type]
|
|
||||||
|
|
||||||
def __init__(self, elements: Type):
|
class TupleType(ParametricType):
|
||||||
self.elements = elements
|
def __init__(self, params: list[Type]):
|
||||||
|
super().__init__(params)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"tuple[{', '.join([str(v) for v in self.elements])}]"
|
return f"tuple[{', '.join([str(v) for v in self.params])}]"
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,51 +0,0 @@
|
|||||||
def is_variable(b):
|
|
||||||
return isinstance(b, str) and b.isupper()
|
|
||||||
|
|
||||||
def unify(ctx, a, b):
|
|
||||||
"""
|
|
||||||
a is the more specific type
|
|
||||||
b is the type with parameter
|
|
||||||
lower case means primitive type
|
|
||||||
upper case means type variable
|
|
||||||
list and tuples are just list and tuples
|
|
||||||
"""
|
|
||||||
if isinstance(ctx, str):
|
|
||||||
return ctx
|
|
||||||
|
|
||||||
if is_variable(b):
|
|
||||||
if b in ctx:
|
|
||||||
b = ctx[b]
|
|
||||||
else:
|
|
||||||
ctx[b] = a
|
|
||||||
return ctx
|
|
||||||
else:
|
|
||||||
if is_variable(a):
|
|
||||||
return f"{a} is less specific then {b}"
|
|
||||||
|
|
||||||
if isinstance(a, list) and isinstance(b, list):
|
|
||||||
return unify(ctx, a[0], b[0])
|
|
||||||
elif isinstance(a, tuple) and isinstance(b, tuple) and len(a) == len(b):
|
|
||||||
old = ctx
|
|
||||||
for x, y in zip(a, b):
|
|
||||||
old = unify(old, x, y)
|
|
||||||
return old
|
|
||||||
else:
|
|
||||||
if a == b:
|
|
||||||
return ctx
|
|
||||||
else:
|
|
||||||
return f"{a} != {b}"
|
|
||||||
|
|
||||||
def check_eq(a, b):
|
|
||||||
unifier = unify({}, a, b)
|
|
||||||
print(f"{a} <- {b}\n{unifier}\n")
|
|
||||||
|
|
||||||
check_eq('a', 'A')
|
|
||||||
check_eq('A', 'B')
|
|
||||||
check_eq('A', 'a')
|
|
||||||
check_eq(['a'], 'A')
|
|
||||||
check_eq(['a'], ['A'])
|
|
||||||
check_eq(['a'], ['b'])
|
|
||||||
check_eq([('a', 'a', 'b')], ['A'])
|
|
||||||
check_eq([('a', 'a', 'b')], [('A', 'A', 'B')])
|
|
||||||
check_eq([('a', 'a', 'b')], [('A', 'A', 'A')])
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user