nac3-spec/hm-inference/nac3_types.py

372 lines
11 KiB
Python
Raw Normal View History

2021-07-09 15:27:02 +08:00
from __future__ import annotations
from typing import Dict, Mapping, List, Set
from enum import Enum
from itertools import chain
class UnificationError(Exception):
    """Signal that two types could not be unified.

    The single ``msg`` argument is handed to ``Exception`` unchanged, so
    ``str(err)`` yields the message verbatim.
    """

    def __init__(self, msg):
        super(UnificationError, self).__init__(msg)
class Type:
    """Base class of every type node.

    Subclasses override ``unify`` and ``subst``; the defaults here only
    cover the trivial behaviour shared by all nodes.
    """

    def find(self):
        """Return the representative of this type (itself by default)."""
        return self

    def unify(self, _):
        """Unify with another type; must be provided by subclasses."""
        raise NotImplementedError

    def subst(self, _):
        """Apply a substitution mapping; must be provided by subclasses."""
        raise NotImplementedError

    def check(self):
        """Validate the type; the base implementation has nothing to do."""
        return None
class TVarType(Enum):
    """Kind constraint carried by a type variable.

    The numeric values encode a partial order through divisibility:
    ``a <= b`` holds exactly when ``a.value`` divides ``b.value``.  With
    the values below, UNDETERMINED (1) is below everything, SEQUENCE (2)
    is below TUPLE (6) and LIST (8), and RECORD (5) is comparable only
    with UNDETERMINED and itself.
    """

    UNDETERMINED = 1
    SEQUENCE = 2
    RECORD = 5
    TUPLE = 6
    LIST = 8

    def __le__(self, other):
        """Return True when ``self`` is at most as specific as ``other``."""
        if self.__class__ is not other.__class__:
            return NotImplemented
        return (other.value % self.value) == 0

    def unifier(self, other):
        """Return the more specific of two comparable kinds.

        Raises NotImplementedError when ``other`` is not a ``TVarType`` and
        UnificationError when the two kinds are incomparable.
        """
        if self.__class__ is not other.__class__:
            raise NotImplementedError()
        # Try (self, other) first, then the reverse, returning the upper one.
        for lower, upper in ((self, other), (other, self)):
            if lower <= upper:
                return upper
        raise UnificationError(f'cannot unify {self} and {other}')
class TVar(Type):
    """Type variable, stored as a node of a union-find forest.

    Attributes:
        type: ``TVarType`` kind constraint accumulated so far.
        rank: union-by-rank counter, meaningful on root variables.
        parent: the variable itself while it is a root, another ``TVar``
            after a union, or a non-variable type once it has been bound.
        fields: constraints on members of the eventual type, keyed by
            field name (matched against ``TObj.fields``) or by integer
            index (for lists and tuples).
        range: optional collection of types the variable may resolve to,
            or ``None`` for no restriction.
        id: unique number taken from the class-wide ``next_id`` counter.
    """

    next_id = 0

    def __init__(self, vrange=None):
        self.type = TVarType.UNDETERMINED
        self.rank = 0
        self.parent = self
        self.fields = {}
        self.range = vrange
        self.id = TVar.next_id
        TVar.next_id += 1

    def check(self):
        """Raise UnificationError if the resolved type violates ``range``.

        A variable that still resolves to itself is always accepted.
        """
        if self.range is not None:
            ty = self.find()
            # maybe we should replace this with explicit eq
            if ty is not self and ty not in self.range:
                raise UnificationError(
                    f'{self.id} cannot be substituted by {ty}')

    def subst(self, mapping: Mapping[int, Type]):
        """Return ``mapping[self.id]`` when present, otherwise ``self``."""
        # user cannot specify fields...
        # so this is safe
        if self.id in mapping:
            return mapping[self.id]
        return self

    def __str__(self):
        s = self.find()
        if isinstance(s, TVar):
            if len(s.fields) > 0:
                fields = '{' + ', '.join([f'{k}: {v}' for k,
                                          v in s.fields.items()]) + '}'
            else:
                fields = ''
            return str(s.id) + fields
        else:
            return str(s)

    def find(self):
        """Return the representative of this variable.

        The result is a root ``TVar``, or the non-variable type the chain
        is bound to.  A ``TCall`` at the end of the chain is asked for its
        own representative.
        """
        node = self
        parent = self.parent
        while node is not parent and isinstance(parent, TVar):
            # Path halving: re-point the current node at its grandparent.
            # The original chained assignment rebound ``root`` before
            # writing ``root.parent``, so it never shortened any chain.
            grandparent = parent.parent
            node.parent = grandparent
            node, parent = parent, grandparent
        if isinstance(parent, TCall):
            parent = parent.find()
        return parent

    def unify(self, other):
        """Unify this variable with ``other``.

        Raises UnificationError when the kind or field constraints held by
        the variable cannot be satisfied by the other side.
        """
        x = other.find()
        y = self.find()
        if x is y:
            return
        if not isinstance(y, TVar):
            # This variable is already bound; let the bound type decide.
            y.unify(x)
            return

        if isinstance(x, TVar):
            # Standard union by rank.  The root must be chosen BEFORE
            # merging: constraints stored on the child are never read
            # again, so merging into the eventual child would lose them.
            if x.rank < y.rank:
                x, y = y, x
            # unify field type
            x.type = x.type.unifier(y.type)
            # unify fields
            for k, v in y.fields.items():
                if k in x.fields:
                    x.fields[k].unify(v)
                else:
                    x.fields[k] = v
            y.parent = x
            if x.rank == y.rank:
                x.rank += 1
            return

        # y is a root variable, x is a non-variable type: check that the
        # constraints collected on y are compatible with x, then bind y.
        if isinstance(x, TObj):
            if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]:
                raise UnificationError(f'Cannot unify {y} with {x}')
            for k, v in y.fields.items():
                if k not in x.fields:
                    raise UnificationError(
                        f'Cannot unify {y} with {x}')
                # Functions that have not been instantiated yet are
                # instantiated before being unified.
                if isinstance(v, TFunc) and not v.instantiated:
                    v = v.instantiate()
                u = x.fields[k]
                if isinstance(u, TFunc) and not u.instantiated:
                    u = u.instantiate()
                v.unify(u)
        if isinstance(x, TList):
            if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE,
                              TVarType.LIST]:
                raise UnificationError(f'Cannot unify {y} with {x}')
            # Every indexed constraint must match the element type.
            for k, v in y.fields.items():
                assert isinstance(k, int)
                v.unify(x.param)
        if isinstance(x, TTuple):
            if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE,
                              TVarType.TUPLE]:
                raise UnificationError(f'Cannot unify {y} with {x}')
            # Indexed constraints must match the element at that position.
            for k, v in y.fields.items():
                assert isinstance(k, int)
                if k >= len(x.params):
                    raise UnificationError(f'Cannot unify {y} with {x}')
                v.unify(x.params[k])
        # Any other type is bound without inspecting the fields.
        y.parent = x
class FuncArg:
    """One formal parameter of a function type: name, type, optionality."""

    def __init__(self, name, typ, is_optional):
        self.name, self.typ, self.is_optional = name, typ, is_optional

    def __str__(self):
        # Optional parameters are rendered with a trailing '?'.
        marker = '?' if self.is_optional else ''
        return f'{self.name}: {self.typ}{marker}'
class TCall(Type):
    """Constraint describing one call site.

    Attributes:
        posargs: types of the positional arguments, in call order.
        kwargs: types of the keyword arguments, keyed by name.
        ret: type expected for the call result.
        fun: fresh ``TVar`` standing for the callee; once it resolves to a
            non-variable type, the call delegates to that type.
    """

    def __init__(self, posargs: List[Type], kwargs: Dict[str, Type], ret: Type):
        self.posargs = posargs
        self.kwargs = kwargs
        self.ret = ret
        self.fun = TVar()

    def check(self):
        self.fun.find().check()

    def find(self):
        """Return the resolved callee, or ``self`` while it is unknown."""
        resolved = self.fun.find()
        if isinstance(resolved, TVar):
            return self
        return resolved

    def unify(self, other):
        """Unify this call with another call, a function or a variable.

        Raises UnificationError for unknown, duplicated, missing or surplus
        arguments, and for any other kind of type.
        """
        if not isinstance(self.fun.find(), TVar):
            # Callee already known: unify against it instead.
            self.fun.unify(other)
            return
        other = other.find()
        if other is self:
            return
        if isinstance(other, TCall):
            # Positional arguments are matched pairwise; zip stops at the
            # shorter of the two lists.
            for a, b in zip(self.posargs, other.posargs):
                a.unify(b)
            # Keyword arguments known to both are unified; the rest are
            # copied so both calls end up with the same key set.
            for k, v in self.kwargs.items():
                if k in other.kwargs:
                    other.kwargs[k].unify(v)
                else:
                    other.kwargs[k] = v
            for k, v in other.kwargs.items():
                if k not in self.kwargs:
                    self.kwargs[k] = v
            self.fun.unify(other.fun)
        elif isinstance(other, TFunc):
            self._unify_with_function(other)
        elif isinstance(other, TVar):
            other.unify(self)
        else:
            raise UnificationError(f'Cannot unify a call with {other}')

    def _unify_with_function(self, func):
        """Match the actual arguments of this call against ``func``."""
        # Reject surplus positional arguments up front; indexing past the
        # end of func.args used to escape as an IndexError.
        if len(self.posargs) > len(func.args):
            raise UnificationError('Too many positional arguments')
        # Parameters not yet bound by any argument.
        unbound = {arg.name for arg in func.args}
        required = {arg.name for arg in func.args if not arg.is_optional}
        func.ret.unify(self.ret)
        for param, actual in zip(func.args, self.posargs):
            param.typ.unify(actual)
            required.discard(param.name)
            # Record the positional binding so that a keyword argument for
            # the same parameter is reported as a duplicate below.
            unbound.discard(param.name)
        for k, v in self.kwargs.items():
            param = next((arg for arg in func.args if arg.name == k), None)
            if param is None:
                raise UnificationError(f'Unknown kwarg {k}')
            if k not in unbound:
                raise UnificationError(f'Duplicated kwarg {k}')
            param.typ.unify(v)
            required.discard(k)
            unbound.discard(k)
        if required:
            raise UnificationError('Missing arguments')
        self.fun.unify(func)
class TFunc(Type):
    """Function type: ordered parameters, a return type and quantified vars.

    ``instantiated`` starts out False and is set by ``instantiate`` on the
    object that method returns.
    """

    def __init__(self, args: List[FuncArg], ret: Type, vars: Set[TVar]):
        self.args = args
        self.ret = ret
        self.vars = vars
        self.instantiated = False

    def check(self):
        """Check every parameter type, then the return type."""
        for param in self.args:
            param.typ.check()
        self.ret.check()

    def subst(self, mapping: Mapping[int, Type]):
        """Return a copy with ``mapping`` applied; ``self`` if it is empty."""
        if len(mapping) == 0:
            return self
        new_args = [
            FuncArg(param.name, param.typ.subst(mapping), param.is_optional)
            for param in self.args
        ]
        new_ret = self.ret.subst(mapping)
        return TFunc(new_args, new_ret, self.vars)

    def instantiate(self):
        """Replace each quantified variable by a fresh one.

        The result is flagged as instantiated.  When ``vars`` is empty the
        substitution returns ``self``, so the flag is set on this object.
        """
        mapping = {}
        for var in self.vars:
            key = var.id
            # A fresh variable inherits the range of the one it replaces.
            fresh = TVar(var.range) if isinstance(var, TVar) else TVar()
            mapping[key] = fresh
        result = self.subst(mapping)
        result.instantiated = True
        return result

    def __str__(self):
        rendered_args = ", ".join(map(str, self.args))
        return f'({rendered_args}) -> {self.ret}'

    def unify(self, other):
        """Unify with a variable, a call, or a function of the same shape."""
        other = other.find()
        if other is self:
            return
        if isinstance(other, (TVar, TCall)):
            # Variables and calls implement this direction themselves.
            other.unify(self)
            return
        if not isinstance(other, TFunc):
            raise UnificationError(f'Cannot unify a function with {other}')
        if len(self.args) != len(other.args):
            raise UnificationError(
                'cannot unify functions with different parameters')
        self.ret.unify(other.ret)
        for mine, theirs in zip(self.args, other.args):
            if mine.name != theirs.name or mine.is_optional != theirs.is_optional:
                raise UnificationError(
                    'cannot unify functions with different parameters')
            mine.typ.unify(theirs.typ)
class TObj(Type):
    """Named object type with a field table and a list of type parameters."""

    def __init__(self, name: str, fields: Dict[str, Type], params: List[Type]):
        self.name = name
        self.fields = fields
        self.params = params

    def check(self):
        """Check the type of every field."""
        for field_type in self.fields.values():
            field_type.check()

    def subst(self, mapping: Mapping[int, Type]):
        """Return a copy with ``mapping`` applied; ``self`` if it is empty.

        Only parameters that are themselves mapped ``TVar`` objects are
        replaced; field types are substituted recursively.
        """
        if len(mapping) == 0:
            return self
        new_params = [
            mapping[p.id] if isinstance(p, TVar) and p.id in mapping else p
            for p in self.params
        ]
        new_fields = {
            field: typ.subst(mapping) for field, typ in self.fields.items()
        }
        return TObj(self.name, new_fields, new_params)

    def unify(self, other):
        """Unify with a variable or with an object of the same name."""
        other = other.find()
        if other is self:
            return
        if isinstance(other, TVar):
            other.unify(self)
            return
        if not isinstance(other, TObj):
            raise UnificationError(f'Cannot unify an object with {other}')
        if self.name != other.name:
            raise UnificationError(f'Cannot unify {self} with {other}')
        # Same name: unify field by field, driven by this object's fields.
        for field in self.fields:
            self.fields[field].unify(other.fields[field])

    def __str__(self):
        if len(self.params) > 0:
            suffix = '[' + ', '.join(map(str, self.params)) + ']'
        else:
            suffix = ''
        return self.name + suffix
class TList(Type):
    """List type with a single element type ``param``."""

    def __init__(self, param: Type):
        self.param = param

    def check(self):
        self.param.check()

    def subst(self, mapping):
        """Return a list type with ``mapping`` applied to the element type.

        Without this override the base ``Type.subst`` raises
        NotImplementedError, which makes ``TFunc.subst`` and ``TObj.subst``
        fail for any signature or field that mentions a list.  As in those
        classes, an empty mapping returns ``self`` unchanged.
        """
        if len(mapping) == 0:
            return self
        return TList(self.param.subst(mapping))

    def unify(self, other):
        """Unify with a variable or with another list type.

        Raises UnificationError for any other kind of type.
        """
        other = other.find()
        if other is self:
            # Nothing to do, same short-circuit as TObj/TFunc.
            return
        if isinstance(other, TVar):
            other.unify(self)
        elif isinstance(other, TList):
            self.param.unify(other.param)
        else:
            raise UnificationError(f'Cannot unify list with {other}')

    def __str__(self):
        return f'List[{self.param}]'
class TTuple(Type):
    """Tuple type with one element type per position in ``params``."""

    def __init__(self, params: List[Type]):
        self.params = params

    def check(self):
        for p in self.params:
            p.check()

    def subst(self, mapping):
        """Return a tuple type with ``mapping`` applied to every element.

        Without this override the base ``Type.subst`` raises
        NotImplementedError, which makes ``TFunc.subst`` and ``TObj.subst``
        fail for any signature or field that mentions a tuple.  As in those
        classes, an empty mapping returns ``self`` unchanged.
        """
        if len(mapping) == 0:
            return self
        return TTuple([p.subst(mapping) for p in self.params])

    def unify(self, other):
        """Unify with a variable or with a tuple of the same length.

        Raises UnificationError on a length mismatch or any other type.
        """
        other = other.find()
        if other is self:
            # Nothing to do, same short-circuit as TObj/TFunc.
            return
        if isinstance(other, TVar):
            other.unify(self)
        elif isinstance(other, TTuple):
            if len(self.params) != len(other.params):
                raise UnificationError(
                    'cannot unify tuples with different length')
            for a, b in zip(self.params, other.params):
                a.unify(b)
        else:
            raise UnificationError(f'Cannot unify {self} with {other}')

    def __str__(self):
        return f'Tuple[{", ".join(str(p) for p in self.params)}]'
# Predefined object types with no fields and no type parameters; TObj.unify
# tells them apart by name.
TBool = TObj(name='bool', fields={}, params=[])
TInt = TObj(name='int', fields={}, params=[])