from __future__ import annotations

from enum import Enum
from itertools import chain
from typing import Dict, List, Mapping, Set


class UnificationError(Exception):
    """Raised when two types cannot be unified."""

    def __init__(self, msg):
        super().__init__(msg)


class Type:
    """Base class of every type term handled by the unifier."""

    def find(self):
        """Return the representative of this type (itself by default)."""
        return self

    def unify(self, _):
        """Unify this type with another one; subclasses must override."""
        raise NotImplementedError()

    def subst(self, _):
        """Substitute type variables by id; subclasses must override."""
        raise NotImplementedError()

    def check(self):
        """Validate the type after unification; a no-op by default."""
        pass


class TVarType(Enum):
    """Kind constraint carried by a type variable.

    The numeric values encode the ordering by divisibility: ``a <= b`` holds
    when ``a.value`` divides ``b.value``.  Hence UNDETERMINED (1) is below
    every kind, SEQUENCE (2) is below TUPLE (6) and LIST (8), and RECORD (5)
    is only above UNDETERMINED.
    """

    UNDETERMINED = 1
    SEQUENCE = 2
    RECORD = 5
    TUPLE = 6
    LIST = 8

    def __le__(self, other):
        if self.__class__ is other.__class__:
            return (other.value % self.value) == 0
        return NotImplemented

    def unifier(self, other):
        """Return the more specific of two comparable kinds.

        Raises:
            NotImplementedError: ``other`` is not a ``TVarType``.
            UnificationError: the two kinds are not comparable.
        """
        if self.__class__ is not other.__class__:
            raise NotImplementedError()
        if self <= other:
            return other
        elif other <= self:
            return self
        else:
            raise UnificationError(f'cannot unify {self} and {other}')


class TVar(Type):
    """Type variable, stored as a node of a union-find forest.

    Attributes set by ``__init__``:
        type:   kind constraint (``TVarType``).
        rank:   union-by-rank counter.
        parent: next node towards the representative; may be a non-variable
                type once the variable has been bound.
        fields: constraints on members, keyed by name or index.
        range:  optional collection of types the variable may resolve to.
        id:     unique id taken from the class-level counter ``next_id``.
    """

    next_id = 0

    def __init__(self, vrange=None):
        self.type = TVarType.UNDETERMINED
        self.rank = 0
        self.parent = self
        self.fields = {}
        self.range = vrange
        self.id = TVar.next_id
        TVar.next_id += 1

    def check(self):
        """Ensure a resolved variable with a range landed inside that range.

        Raises:
            UnificationError: the representative is not this variable and is
                not contained in ``self.range``.
        """
        if self.range is not None:
            ty = self.find()
            # maybe we should replace this with explicit eq
            if ty is not self and ty not in self.range:
                raise UnificationError(
                    f'{self.id} cannot be substituted by {ty}')

    def subst(self, mapping: Mapping[int, Type]):
        """Return ``mapping[self.id]`` when present, otherwise ``self``."""
        # user cannot specify fields...
        # so this is safe
        if self.id in mapping:
            return mapping[self.id]
        return self

    def __str__(self):
        s = self.find()
        if isinstance(s, TVar):
            if len(s.fields) > 0:
                fields = '{' + ', '.join(
                    [f'{k}: {v}' for k, v in s.fields.items()]) + '}'
            else:
                fields = ''
            return str(s.id) + fields
        else:
            return str(s)

    def find(self):
        """Return the representative of this variable.

        Walks the parent chain until it reaches a self-parented variable or a
        non-variable type.  A ``TCall`` at the end of the chain is resolved
        through its own ``find``.  The walk does not rewrite parent links.
        """
        root = self
        parent = self.parent
        while root is not parent and isinstance(parent, TVar):
            root, parent = parent, parent.parent
        if isinstance(parent, TCall):
            parent = parent.find()
        return parent

    def unify(self, other):
        """Unify this variable with ``other``.

        Raises:
            UnificationError: the kind or field constraints of the variable
                are incompatible with ``other``.
        """
        x = other.find()
        y = self.find()
        if x is y:
            return
        if isinstance(y, TVar) and isinstance(x, TVar):
            # Pick the surviving root first (union by rank) so the merged
            # kind and fields are stored on the node that remains the
            # representative.
            if x.rank < y.rank:
                x, y = y, x
            # unify field type
            x.type = x.type.unifier(y.type)
            # unify fields
            for k, v in y.fields.items():
                if k in x.fields:
                    x.fields[k].unify(v)
                else:
                    x.fields[k] = v
            # standard union find
            y.parent = x
            if x.rank == y.rank:
                x.rank += 1
        elif isinstance(y, TVar):
            # y is a free variable, x is a non-variable type: check that the
            # constraints recorded on y are satisfied by x, then bind y to x.
            if isinstance(x, (TVirtual, TObj)):
                obj_fields = (x.obj.fields if isinstance(x, TVirtual)
                              else x.fields)
                if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]:
                    raise UnificationError(f'Cannot unify {y} with {x}')
                for k, v in y.fields.items():
                    if k not in obj_fields:
                        raise UnificationError(
                            f'Cannot unify {y} with {x}')
                    v.unify(obj_fields[k])
            elif isinstance(x, TList):
                if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE,
                                  TVarType.LIST]:
                    raise UnificationError(f'Cannot unify {y} with {x}')
                for k, v in y.fields.items():
                    assert isinstance(k, int)
                    v.unify(x.param)
            elif isinstance(x, TTuple):
                if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE,
                                  TVarType.TUPLE]:
                    raise UnificationError(f'Cannot unify {y} with {x}')
                for k, v in y.fields.items():
                    assert isinstance(k, int)
                    if k >= len(x.params):
                        raise UnificationError(
                            f'Cannot unify {y} with {x}')
                    v.unify(x.params[k])
            y.parent = x
        else:
            # self is already bound to a non-variable type: let it decide.
            y.unify(x)

    def __eq__(self, other):
        """Compare representatives; free variables are equal by id.

        Because ``__eq__`` is defined without ``__hash__``, instances are
        unhashable.
        """
        if not isinstance(other, Type):
            return NotImplemented
        s = self.find()
        o = other.find()
        if not isinstance(s, TVar):
            return s == o
        if isinstance(o, TVar):
            return s.id == o.id
        return False


class FuncArg:
    """Formal parameter of a function type: name, type, optional flag."""

    def __init__(self, name, typ, is_optional):
        self.name = name
        self.typ = typ
        self.is_optional = is_optional

    def __str__(self):
        return f'{self.name}: {self.typ}' + ('?' if self.is_optional else '')


class TCall(Type):
    """Pending call whose callee type is resolved during unification.

    ``calls`` is a list of ``[posargs, kwargs, ret, fn]`` entries, where
    ``fn`` is ``None`` until the call has been unified with a ``TFunc``.
    Several ``TCall`` nodes can be merged through ``parent``/``rank``.
    """

    def __init__(self, posargs: List[Type], kwargs: Dict[str, Type],
                 ret: Type):
        self.calls = [[posargs, kwargs, ret, None]]
        self.parent = self
        self.rank = 0

    def check(self):
        """Check the function recorded for the first call.

        NOTE(review): ``calls[0][3]`` is ``None`` until this node has been
        unified with a function, in which case this raises AttributeError.
        """
        self.calls[0][3].check()

    def find(self):
        """Return the resolved function of the root call, or the root itself
        when it has not been resolved yet."""
        root = self
        parent = self.parent
        while root is not parent and isinstance(parent, TCall):
            root, parent = parent, parent.parent
        if parent.calls[0][3] is None:
            return parent
        return parent.calls[0][3]

    def unify(self, other):
        """Unify this call with another call, a function or a variable.

        Raises:
            UnificationError: arguments do not match the function signature,
                or ``other`` is not something a call can be unified with.
        """
        y = self.find()
        if y is not self:
            y.unify(other)
            return
        x = other.find()
        if x is y:
            return
        if isinstance(x, TCall):
            # standard union find
            if x.rank < y.rank:
                x, y = y, x
            y.parent = x
            if x.rank == y.rank:
                x.rank += 1
            x.calls += y.calls
        elif isinstance(x, TFunc):
            fn = x
            for call in y.calls:
                # Every call site gets its own instance of a generic function.
                if not x.instantiated:
                    fn = x.instantiate()
                posargs, kwargs, ret, _ = call
                call[3] = fn
                # all_args: parameters not bound yet; required: mandatory
                # parameters not bound yet.
                all_args = set(arg.name for arg in fn.args)
                required = set(arg.name for arg in fn.args
                               if not arg.is_optional)
                fn.ret.unify(ret)
                if len(posargs) > len(fn.args):
                    raise UnificationError('Too many positional arguments')
                for arg, v in zip(fn.args, posargs):
                    arg.typ.unify(v)
                    required.discard(arg.name)
                    # Mark as bound so a later keyword for the same
                    # parameter is reported as a duplicate.
                    all_args.discard(arg.name)
                for k, v in kwargs.items():
                    arg = next((arg for arg in fn.args if arg.name == k),
                               None)
                    if arg is None:
                        raise UnificationError(f'Unknown kwarg {k}')
                    if k not in all_args:
                        raise UnificationError(f'Duplicated kwarg {k}')
                    arg.typ.unify(v)
                    required.discard(k)
                    all_args.remove(k)
                if len(required) > 0:
                    raise UnificationError('Missing arguments')
        elif isinstance(other, TVar):
            other.unify(self)
        else:
            raise UnificationError(f'Cannot unify a call with {other}')


class TFunc(Type):
    """Function type with named arguments, a return type and the collection
    of type variables replaced on instantiation."""

    def __init__(self, args: List[FuncArg], ret: Type, vars: Set[TVar]):
        self.args = args
        self.ret = ret
        self.vars = vars
        self.instantiated = False

    def check(self):
        for arg in self.args:
            arg.typ.check()
        self.ret.check()

    def subst(self, mapping: Mapping[int, Type]):
        """Return a copy with argument and return types substituted.

        Returns ``self`` unchanged when ``mapping`` is empty.
        """
        if len(mapping) == 0:
            return self
        return TFunc(
            [FuncArg(arg.name, arg.typ.subst(mapping), arg.is_optional)
             for arg in self.args],
            self.ret.subst(mapping),
            self.vars)

    def instantiate(self):
        """Replace every variable in ``vars`` by a fresh one.

        The result is flagged ``instantiated``.  When ``vars`` is empty the
        substitution returns ``self``, so the flag is set on this object.
        """
        mapping = {v.id: TVar(v.range) if isinstance(v, TVar) else TVar()
                   for v in self.vars}
        result = self.subst(mapping)
        result.instantiated = True
        return result

    def __str__(self):
        return f'({", ".join(str(arg) for arg in self.args)}) -> {self.ret}'

    def unify(self, other):
        """Unify with a variable, a call or a function of identical shape.

        Raises:
            UnificationError: parameter lists differ or ``other`` is not a
                function, call or variable.
        """
        other = other.find()
        if other is self:
            return
        if isinstance(other, (TVar, TCall)):
            other.unify(self)
        elif isinstance(other, TFunc):
            if len(self.args) != len(other.args):
                raise UnificationError(
                    'cannot unify functions with different parameters')
            self.ret.unify(other.ret)
            for a, b in zip(self.args, other.args):
                if a.name != b.name or a.is_optional != b.is_optional:
                    raise UnificationError(
                        'cannot unify functions with different parameters')
                a.typ.unify(b.typ)
        else:
            raise UnificationError(f'Cannot unify a function with {other}')


class TObj(Type):
    """Named object type with fields, type parameters and parent types."""

    def __init__(self, name: str, fields: Dict[str, Type],
                 params: List[Type], parents=None):
        self.name = name
        self.fields = fields
        self.params = params
        if parents is None:
            self.parents = []
        else:
            self.parents = parents

    def check(self):
        for arg in self.fields.values():
            arg.check()

    def subst(self, mapping: Mapping[int, Type]):
        """Return a copy with fields and variable parameters substituted.

        Returns ``self`` unchanged when ``mapping`` is empty.  Only
        parameters that are themselves variables present in ``mapping`` are
        replaced.  The parent list is carried over to the copy.
        """
        if len(mapping) == 0:
            return self
        new_params = [
            mapping[v.id] if isinstance(v, TVar) and v.id in mapping else v
            for v in self.params
        ]
        return TObj(self.name,
                    {k: v.subst(mapping) for k, v in self.fields.items()},
                    new_params,
                    list(self.parents))

    def unify(self, other):
        """Unify with a variable or an object of the same name.

        Fields are unified pairwise by key; type parameters are not touched.

        Raises:
            UnificationError: names differ or ``other`` is neither an object
                nor a variable.
        """
        other = other.find()
        if other is self:
            return
        if isinstance(other, TObj):
            if self.name != other.name:
                raise UnificationError(f'Cannot unify {self} with {other}')
            for k in self.fields:
                self.fields[k].unify(other.fields[k])
        elif isinstance(other, TVar):
            other.unify(self)
        else:
            raise UnificationError(f'Cannot unify an object with {other}')

    def __str__(self):
        if len(self.params) > 0:
            suffix = '[' + ', '.join(str(p) for p in self.params) + ']'
        else:
            suffix = ''
        return self.name + suffix

    def __eq__(self, other):
        """Equal when names match and zipped parameters are pairwise equal."""
        if not isinstance(other, Type):
            return NotImplemented
        o = other.find()
        if isinstance(o, TObj):
            if self.name != o.name:
                return False
            for a, b in zip(self.params, o.params):
                if a != b:
                    return False
            return True
        return False


class TVirtual(Type):
    """``virtual[obj]`` wrapper around an object type."""

    def __init__(self, obj: TObj):
        self.obj = obj

    def __eq__(self, other):
        """Equal when the other side is virtual and wraps an equal object."""
        if not isinstance(other, Type):
            return NotImplemented
        o = other.find()
        if isinstance(o, TVirtual):
            return self.obj == o.obj
        return False

    def unify(self, other):
        """Unify with a variable or another virtual type.

        Raises:
            UnificationError: ``other`` is neither virtual nor a variable, or
                the wrapped objects cannot be unified.
        """
        o = other.find()
        if isinstance(o, TVar):
            # Delegate so the variable's constraints are checked and it gets
            # bound, exactly as the other concrete types do.
            o.unify(self)
        elif isinstance(o, TVirtual):
            self.obj.unify(o.obj)
        else:
            raise UnificationError(f'Cannot unify {self} with {o}')

    def __str__(self):
        return f'virtual[{self.obj}]'


class TList(Type):
    """List type with a single element type."""

    def __init__(self, param: Type):
        self.param = param

    def check(self):
        self.param.check()

    def unify(self, other):
        """Unify with a variable or another list (element-wise).

        Raises:
            UnificationError: ``other`` is neither a list nor a variable.
        """
        other = other.find()
        if isinstance(other, TVar):
            other.unify(self)
        elif isinstance(other, TList):
            self.param.unify(other.param)
        else:
            raise UnificationError(f'Cannot unify list with {other}')

    def __str__(self):
        return f'List[{self.param}]'

    def __eq__(self, other):
        if not isinstance(other, Type):
            return NotImplemented
        o = other.find()
        if isinstance(o, TList):
            return self.param == o.param
        return False


class TTuple(Type):
    """Fixed-length tuple type."""

    def __init__(self, params: List[Type]):
        self.params = params

    def check(self):
        for p in self.params:
            p.check()

    def unify(self, other):
        """Unify with a variable or a tuple of the same length.

        Raises:
            UnificationError: lengths differ or ``other`` is neither a tuple
                nor a variable.
        """
        other = other.find()
        if isinstance(other, TVar):
            other.unify(self)
        elif isinstance(other, TTuple):
            if len(self.params) != len(other.params):
                raise UnificationError(
                    'cannot unify tuples with different length')
            for a, b in zip(self.params, other.params):
                a.unify(b)
        else:
            raise UnificationError(f'Cannot unify {self} with {other}')

    def __str__(self):
        return f'Tuple[{", ".join(str(p) for p in self.params)}]'

    def __eq__(self, other):
        if not isinstance(other, Type):
            return NotImplemented
        o = other.find()
        if isinstance(o, TTuple):
            if len(self.params) != len(o.params):
                return False
            for a, b in zip(self.params, o.params):
                if a != b:
                    return False
            return True
        return False


TBool = TObj('bool', {}, [])
TInt = TObj('int', {}, [])