Compare commits

...

11 Commits
master ... pca

Author SHA1 Message Date
pca006132 39b3faba6e fixed for loop unification 2021-07-12 10:36:52 +08:00
pca006132 1ad21f0d67 optional virtual type annotation 2021-07-12 09:35:01 +08:00
pca006132 dbf9c17d9f added assumption 2021-07-10 16:23:07 +08:00
pca006132 0010e5852a added readme 2021-07-10 16:21:30 +08:00
pca006132 902e1a892c implemented basic operations 2021-07-10 15:45:53 +08:00
pca006132 b1020352ce allows recursive type, implementing primitives 2021-07-10 15:11:27 +08:00
pca006132 66df55b3d7 virtual type 2021-07-10 14:36:28 +08:00
pca006132 0c8029d7e1 fixed minor bug 2021-07-09 17:31:54 +08:00
pca006132 e15e29d673 fixed polymorphic methods 2021-07-09 17:11:34 +08:00
pca006132 1c17ed003e fixed type var check 2021-07-09 16:06:06 +08:00
pca006132 59628cfa38 init hm-inference 2021-07-09 15:27:02 +08:00
5 changed files with 873 additions and 0 deletions

50
hm-inference/README.md Normal file
View File

@ -0,0 +1,50 @@
# Type Inference
This is a prototype of the algorithm, written ahead of the Rust implementation.
## Implemented Features
- Primitive types: bool, float, integer. (limited magic methods are implemented,
no casting for now)
- Statements:
- For loop.
- While loop.
- If statement.
- Simple assignment. (without type annotation, allows pattern matching for
tuples)
- Expressions:
- Boolean operations.
- Binary operations. (+, -, *, /, //)
- Compare.
- Function call. (posargs/kwargs/optional are supported)
- Lambda.
- Object Attribute.
- List/tuple subscript. (tuple requires constant indexing)
- Virtual.
- Constraints for type variables.
User can define functions/types by adding them to the prelude.
Note that, for now, variables can be used before they are defined; a separate
pass after type checking will reject this. That pass will also check returns.
We assume that class types do not contain any unbound type variables. (methods
can contain unbound type variables, if the variable is bound to the class)
## TODO
- Parse class/function definition.
- Occurs check to prevent infinite function types.
- Pretty print for types. (especially type variables)
- Better error message. (we did not keep any context for now)
## Implementation Notes
- Type variables would now retain a lot of information, including fields and its
range for later checking.
- Function type is separated from call type. We maintain a list of calls to the
same function (potentially with different type signatures) to allow different
instantiations of the same function.
- We store the list of calls and the list of virtual types, and check the
constraints at the end of the type inference phase. E.g. check subtyping
relationship, check whether the type variable can be instantiated to a certain
type.

237
hm-inference/ast_visitor.py Normal file
View File

@ -0,0 +1,237 @@
import ast
from itertools import chain
from nac3_types import *
from primitives import *
# (operator node class, magic method name) pairs, tried in this order.
_MAGIC_METHODS = (
    (ast.Add, '__add__'),
    (ast.Sub, '__sub__'),
    (ast.Mult, '__mul__'),
    (ast.Div, '__truediv__'),
    (ast.FloorDiv, '__floordiv__'),
    (ast.Eq, '__eq__'),
    (ast.NotEq, '__ne__'),
    (ast.Lt, '__lt__'),
    (ast.LtE, '__le__'),
    (ast.Gt, '__gt__'),
    (ast.GtE, '__ge__'),
)


def get_magic_method(op):
    """Return the name of the magic method implementing operator node *op*.

    Args:
        op: an ``ast`` operator instance (binary or comparison operator).

    Returns:
        The dunder method name, e.g. ``'__add__'`` for ``ast.Add``.

    Raises:
        Exception: if *op* is not one of the supported operators. The message
            names the offending operator class.
    """
    for op_class, method in _MAGIC_METHODS:
        if isinstance(op, op_class):
            return method
    raise Exception(f'unsupported operator: {type(op).__name__}')
class Visitor(ast.NodeVisitor):
    """Walk a Python AST and attach an inferred type to each visited node.

    Every ``visit_*`` method stores its result in ``node.type`` and expresses
    constraints by unifying types. Calls and ``virtual(...)`` uses are also
    recorded in ``self.calls`` and ``self.virtuals`` so that the caller can
    check them once the walk is finished.
    """

    def __init__(self, src, assignments, type_parser):
        """Create a visitor.

        Args:
            src: source text of the program; used to recover source segments.
            assignments: mapping from name to type (the environment). It is
                mutated: unknown names are added as they are encountered.
            type_parser: callable taking the source text of a type annotation
                and returning the corresponding type.
        """
        super(Visitor, self).__init__()
        self.source = src
        self.assignments = assignments
        self.calls = []
        self.virtuals = []
        self.type_parser = type_parser

    def visit_Assign(self, node):
        """Unify every assignment target with the assigned value."""
        self.visit(node.value)
        for target in node.targets:
            self.visit(target)
            target.type.unify(node.value.type)

    def visit_Name(self, node):
        """Look a name up in the environment, creating it if unknown."""
        if node.id == '_':
            # `_` is a wildcard: never recorded, always a fresh variable.
            node.type = TVar()
            return
        if node.id not in self.assignments:
            self.assignments[node.id] = TVar()
        ty = self.assignments[node.id]
        if isinstance(ty, TFunc):
            # each use of a function gets its own instantiation
            ty = ty.instantiate()
        node.type = ty

    def visit_Lambda(self, node):
        """Type a lambda; its parameters live in a copy of the environment."""
        old = self.assignments
        self.assignments = old.copy()
        try:
            self.visit(node.args)
            self.visit(node.body)
        finally:
            # restore the outer environment even if the body fails to type
            self.assignments = old
        node.type = TFunc(node.args.type, node.body.type, [])

    def visit_arguments(self, node):
        """Bind each positional parameter to a fresh type variable.

        ``node.type`` becomes the list of non-optional ``FuncArg`` entries.
        Only ``node.args`` is considered.
        """
        for arg in node.args:
            self.assignments[arg.arg] = TVar()
            arg.type = self.assignments[arg.arg]
        node.type = [FuncArg(arg.arg, arg.type, False) for arg in node.args]

    def visit_Constant(self, node):
        """Type a literal as bool, int or float.

        Raises:
            UnificationError: for any other kind of constant.
        """
        # bool must be tested first: it is a subclass of int
        if isinstance(node.value, bool):
            node.type = TBool
        elif isinstance(node.value, int):
            node.type = TInt
        elif isinstance(node.value, float):
            node.type = TFloat
        else:
            # Leaving node.type unset would only fail later with an
            # AttributeError far from the cause.
            raise UnificationError(
                f'Unsupported constant of type {type(node.value).__name__}')

    def visit_Call(self, node):
        """Type a call; ``virtual(obj[, Type])`` is handled specially.

        Raises:
            UnificationError: if ``virtual`` gets a wrong number of arguments.
        """
        if ast.get_source_segment(self.source, node.func) == 'virtual':
            if not 1 <= len(node.args) <= 2:
                raise UnificationError('Incorrect argument number for virtual')
            self.visit(node.args[0])
            if len(node.args) == 2:
                # the second argument is a type annotation, not an expression
                ty = self.type_parser(ast.get_source_segment(self.source,
                                                             node.args[1]))
            else:
                ty = TVar()
            self.virtuals.append((node.args[0].type, ty))
            node.type = TVirtual(ty)
            return
        self.visit(node.func)
        for arg in node.args:
            self.visit(arg)
        for keyword in node.keywords:
            self.visit(keyword.value)
        node.type = TVar()
        call = TCall([arg.type for arg in node.args],
                     {keyword.arg: keyword.value.type for keyword
                      in node.keywords}, node.type)
        node.func.type.unify(call)
        self.calls.append(call)

    def visit_Attribute(self, node):
        """Require the value to be a record having the accessed field."""
        self.visit(node.value)
        node.type = TVar()
        v = TVar()
        v.type = TVarType.RECORD
        v.fields[node.attr] = node.type
        node.value.type.unify(v)

    def visit_Tuple(self, node):
        """Type a tuple from the types of its elements."""
        for e in node.elts:
            self.visit(e)
        node.type = TTuple([e.type for e in node.elts])

    def visit_List(self, node):
        """Type a list literal; all elements must unify to one type."""
        ty = TVar()
        for e in node.elts:
            self.visit(e)
            ty.unify(e.type)
        node.type = TList(ty)

    def visit_Subscript(self, node):
        """Type ``value[slice]`` or ``value[index]`` for lists and tuples.

        Raises:
            UnificationError: if the value cannot be subscripted this way.
            NotImplementedError: for extended slices.
        """
        self.visit(node.value)
        if isinstance(node.slice, ast.Slice):
            # Resolve first, as visit_For does: the variable may already be
            # bound to a concrete type.
            node.type = node.value.type.find()
            if isinstance(node.type, TVar):
                node.type.type = node.type.type.unifier(TVarType.LIST)
            elif not isinstance(node.type, TList):
                raise UnificationError(f'{node.type} should be a list')
            return
        # ast.ExtSlice / ast.Index are only produced by Python < 3.9; look
        # them up defensively so newer interpreters are handled as well.
        if isinstance(node.slice, getattr(ast, 'ExtSlice', ())):
            raise NotImplementedError()
        index_node = node.slice
        if isinstance(index_node, getattr(ast, 'Index', ())):
            index_node = index_node.value
        # complicated because we need to handle lists and tuples
        # differently...
        self.visit(index_node)
        index_node.type.unify(TInt)
        ty = node.value.type.find()
        node.type = TVar()
        if isinstance(ty, TVar):
            # A non-constant index can only address a list; all such accesses
            # share field 0.
            index = 0
            if isinstance(index_node, ast.Constant):
                seq_ty = TVarType.SEQUENCE
                if isinstance(node.ctx, (ast.AugStore, ast.Store)):
                    seq_ty = TVarType.LIST
                ty.type = ty.type.unifier(seq_ty)
                index = index_node.value
            else:
                ty.type = ty.type.unifier(TVarType.LIST)
            if index in ty.fields:
                ty.fields[index].unify(node.type)
            else:
                ty.fields[index] = node.type
        elif isinstance(ty, TList):
            ty.param.unify(node.type)
        elif isinstance(ty, TTuple):
            if isinstance(node.ctx, (ast.AugStore, ast.Store)):
                raise UnificationError(f'Cannot assign to tuple')
            if isinstance(index_node, ast.Constant):
                index = index_node.value
                # also reject negative indices that fall off the front
                if not -len(ty.params) <= index < len(ty.params):
                    raise UnificationError('Index out of range for tuple')
                ty.params[index].unify(node.type)
            else:
                raise UnificationError('Tuple index must be a constant')
        else:
            raise UnificationError(f'Cannot use subscript for {ty}')

    def visit_For(self, node):
        """Type a for loop; the iterable must be a list.

        Raises:
            UnificationError: if the iterable is neither a list nor a variable.
        """
        self.visit(node.target)
        self.visit(node.iter)
        ty = node.iter.type.find()
        if isinstance(ty, TVar):
            # we currently only support iterator over lists
            ty.type = ty.type.unifier(TVarType.LIST)
            if 0 in ty.fields:
                ty.fields[0].unify(node.target.type)
            else:
                ty.fields[0] = node.target.type
        elif isinstance(ty, TList):
            ty.param.unify(node.target.type)
        else:
            raise UnificationError(f'Cannot iterate over {ty}')
        for stmt in chain(node.body, node.orelse):
            self.visit(stmt)

    def visit_If(self, node):
        """Type an if statement; the test must be bool."""
        self.visit(node.test)
        node.test.type.unify(TBool)
        for stmt in chain(node.body, node.orelse):
            self.visit(stmt)

    def visit_While(self, node):
        """Type a while loop; the test must be bool."""
        self.visit(node.test)
        node.test.type.unify(TBool)
        for stmt in chain(node.body, node.orelse):
            self.visit(stmt)

    def visit_BoolOp(self, node):
        """Type ``and``/``or``; every operand must be bool."""
        # `a and b and c` is a single BoolOp with three values, so all of
        # them are visited, not only the first two.
        for value in node.values:
            self.visit(value)
        for value in node.values:
            value.type.unify(TBool)
        node.type = TBool

    def visit_BinOp(self, node):
        """Type a binary operation as a call to the left operand's method."""
        self.visit(node.left)
        self.visit(node.right)
        # call method...
        method = get_magic_method(node.op)
        ret = TVar()
        node.type = ret
        call = TCall([node.right.type], {}, ret)
        self.calls.append(call)
        v = TVar()
        v.type = TVarType.RECORD
        v.fields[method] = call
        node.left.type.unify(v)

    def visit_Compare(self, node):
        """Type a (possibly chained) comparison; the result is bool."""
        self.visit(node.left)
        for comparator in node.comparators:
            self.visit(comparator)
        # pair each operand with the one following it: `a < b < c` yields
        # (a, b, <) and (b, c, <)
        for lhs, rhs, op in zip(chain([node.left], node.comparators[:-1]),
                                node.comparators, node.ops):
            method = get_magic_method(op)
            call = TCall([rhs.type], {}, TBool)
            self.calls.append(call)
            v = TVar()
            v.type = TVarType.RECORD
            v.fields[method] = call
            lhs.type.unify(v)
        node.type = TBool

472
hm-inference/nac3_types.py Normal file
View File

@ -0,0 +1,472 @@
from __future__ import annotations
from typing import Dict, Mapping, List, Set
from enum import Enum
class UnificationError(Exception):
    """Error raised when types cannot be unified or a type check fails."""

    def __init__(self, msg):
        """Store *msg* as the exception message."""
        Exception.__init__(self, msg)
class Type:
    """Base class for types; provides the default behaviour of each hook."""

    def find(self):
        """Return the representative of this type: by default, itself."""
        return self

    def unify(self, _):
        """Unify with another type; must be provided by subclasses."""
        raise NotImplementedError

    def subst(self, _):
        """Apply a substitution; must be provided by subclasses."""
        raise NotImplementedError

    def check(self):
        """Run deferred checks; the default does nothing."""
        return None
class TVarType(Enum):
    """Structural kind of a type variable.

    The numeric values encode the ordering used by ``<=``: ``a <= b`` holds
    when ``a.value`` divides ``b.value``. Hence UNDETERMINED precedes every
    kind, SEQUENCE precedes TUPLE and LIST, and RECORD is only related to
    UNDETERMINED and itself.
    """
    UNDETERMINED = 1
    SEQUENCE = 2
    RECORD = 5
    TUPLE = 6
    LIST = 8

    def __le__(self, other):
        """Return whether *other* is at least as specific as this kind."""
        if type(self) is not type(other):
            return NotImplemented
        return other.value % self.value == 0

    def unifier(self, other):
        """Return the more specific of the two kinds.

        Raises:
            NotImplementedError: if *other* is not a ``TVarType``.
            UnificationError: if neither kind refines the other.
        """
        if type(self) is not type(other):
            raise NotImplementedError()
        # try both directions, preferring `other` when the kinds are equal
        for general, specific in ((self, other), (other, self)):
            if general <= specific:
                return specific
        raise UnificationError(f'cannot unify {self} and {other}')
class TVar(Type):
    """Type variable, organised as a union-find node.

    Besides the union-find links (``parent``/``rank``) a variable carries its
    structural kind (``type``), the fields required of it (``fields``) and an
    optional list of admissible types (``range``).

    Note: the class defines ``__eq__`` without ``__hash__``, so instances are
    not hashable.
    """
    # next identifier to hand out; shared by all variables
    next_id = 0

    def __init__(self, vrange=None):
        """Create a fresh, unconstrained variable.

        Args:
            vrange: optional collection of types the variable may be bound
                to; ``None`` means unrestricted.
        """
        self.type = TVarType.UNDETERMINED
        self.rank = 0
        self.parent = self
        self.checked = False
        self.fields = {}
        self.range = vrange
        self.id = TVar.next_id
        TVar.next_id += 1

    def check(self):
        """Check, once, that the resolved type lies within ``range``.

        Raises:
            UnificationError: if the variable resolved to something outside
                its range.
        """
        if self.checked:
            return
        self.checked = True
        if self.range is not None:
            ty = self.find()
            # maybe we should replace this with explicit eq
            if ty is not self and ty not in self.range:
                raise UnificationError(
                    f'{self.id} cannot be substituted by {ty}')

    def subst(self, mapping: Mapping[int, Type]):
        """Return the replacement for this variable's id, or the variable."""
        # user cannot specify fields...
        # so this is safe
        if self.id in mapping:
            return mapping[self.id]
        return self

    def __str__(self):
        s = self.find()
        if isinstance(s, TVar):
            if len(s.fields) > 0:
                fields = '{' + ', '.join([f'{k}: {v}' for k,
                                          v in s.fields.items()]) + '}'
            else:
                fields = ''
            return str(s.id) + fields
        else:
            return str(s)

    def find(self):
        """Return what this variable currently stands for.

        Follows ``parent`` links through type variables. The result is either
        the root variable or the non-variable type the chain ends in; a
        ``TCall`` at the end is resolved through its own ``find``. The walk
        does not rewrite any parent link.
        """
        root = self
        parent = self.parent
        while root is not parent and isinstance(parent, TVar):
            root, parent = parent, parent.parent
        if isinstance(parent, TCall):
            parent = parent.find()
        return parent

    def unify(self, other):
        """Unify this variable with *other*.

        Raises:
            UnificationError: if the kind or the required fields of the
                variable are incompatible with *other*.
        """
        x = other.find()
        y = self.find()
        if x is y:
            return
        if isinstance(y, TVar) and isinstance(x, TVar):
            # standard union find: the higher-ranked node becomes the root.
            # The root is chosen *before* merging so that the combined kind
            # and fields end up on the representative returned by find().
            if x.rank < y.rank:
                x, y = y, x
            # unify field type
            x.type = x.type.unifier(y.type)
            # unify fields (iterate over a snapshot: unifying field types
            # may touch the dictionaries involved)
            for k, v in list(y.fields.items()):
                if k in x.fields:
                    x.fields[k].unify(v)
                else:
                    x.fields[k] = v
            y.parent = x
            if x.rank == y.rank:
                x.rank += 1
        elif isinstance(y, TVar):
            # check fields
            if isinstance(x, TVirtual):
                obj = x.obj.find()
                self.unify(obj)
            elif isinstance(x, TObj):
                if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]:
                    raise UnificationError(f'Cannot unify {y} with {x}')
                for k, v in y.fields.items():
                    if k not in x.fields:
                        raise UnificationError(
                            f'Cannot unify {y} with {x}')
                    u = x.fields[k]
                    v.unify(u)
            elif isinstance(x, TList):
                if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.LIST]:
                    raise UnificationError(f'Cannot unify {y} with {x}')
                for k, v in y.fields.items():
                    assert isinstance(k, int)
                    v.unify(x.param)
            elif isinstance(x, TTuple):
                if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.TUPLE]:
                    raise UnificationError(f'Cannot unify {y} with {x}')
                for k, v in y.fields.items():
                    assert isinstance(k, int)
                    if k >= len(x.params):
                        raise UnificationError(f'Cannot unify {y} with {x}')
                    v.unify(x.params[k])
            # bind the variable to the concrete type
            y.parent = x
        else:
            # this variable is already bound: unify what it stands for
            y.unify(x)

    def __eq__(self, other):
        """Compare resolved types; unresolved variables compare by id."""
        s = self.find()
        o = other.find()
        if not isinstance(s, TVar):
            return s == o
        if isinstance(o, TVar):
            return s.id == o.id
        return False
class FuncArg:
    """One function parameter: its name, its type and whether it is optional."""

    def __init__(self, name, typ, is_optional):
        self.name, self.typ, self.is_optional = name, typ, is_optional

    def __str__(self):
        # optional parameters are marked with a trailing '?'
        suffix = '?' if self.is_optional else ''
        return f'{self.name}: {self.typ}{suffix}'
class TCall(Type):
    """The type of a call site, kept separate from the function's type.

    Call sites that get unified with each other are merged union-find style;
    the root accumulates every call record of its set.
    """

    def __init__(self, posargs: List[Type], kwargs: Dict[str, Type], ret: Type):
        """Create a call with positional/keyword argument types and result."""
        # Each record is a mutable [posargs, kwargs, ret, function] list; the
        # last slot is filled in once the call is unified with a function.
        self.calls = [[posargs, kwargs, ret, None]]
        self.parent = self
        self.rank = 0
        self.checked = False

    def check(self):
        """Check, once, the function bound to this call's first record."""
        if self.checked:
            return
        self.checked = True
        # NOTE(review): if the call was never unified with a function the slot
        # is still None and this raises AttributeError -- confirm whether that
        # should be reported as a UnificationError instead.
        self.calls[0][3].check()

    def find(self):
        """Return the bound function if there is one, else the root call."""
        root = self
        parent = self.parent
        while root is not parent and isinstance(parent, TCall):
            root, parent = parent, parent.parent
        if parent.calls[0][3] is None:
            return parent
        return parent.calls[0][3]

    def unify(self, other):
        """Unify this call with another call, a function or a variable.

        Raises:
            UnificationError: if *other* cannot be called, or the arguments
                do not match the function's parameters.
        """
        y = self.find()
        if y is not self:
            y.unify(other)
            return
        x = other.find()
        if x is y:
            return
        if isinstance(x, TCall):
            # standard union find
            if x.rank < y.rank:
                x, y = y, x
            y.parent = x
            if x.rank == y.rank:
                x.rank += 1
            x.calls += y.calls
        elif isinstance(x, TFunc):
            fn = x
            for call in y.calls:
                if not x.instantiated:
                    # every call site gets its own instantiation
                    fn = x.instantiate()
                posargs, kwargs, ret, _ = call
                call[3] = fn
                # parameters that may still be supplied / must be supplied
                available = set(arg.name for arg in fn.args)
                required = set(arg.name for arg in fn.args if not
                               arg.is_optional)
                fn.ret.unify(ret)
                if len(posargs) > len(fn.args):
                    raise UnificationError('Too many positional arguments')
                for arg, v in zip(fn.args, posargs):
                    arg.typ.unify(v)
                    required.discard(arg.name)
                    # a parameter filled positionally must not be passed
                    # again by keyword
                    available.discard(arg.name)
                for k, v in kwargs.items():
                    arg = next((arg for arg in fn.args if arg.name == k), None)
                    if arg is None:
                        raise UnificationError(f'Unknown kwarg {k}')
                    if k not in available:
                        raise UnificationError(f'Duplicated kwarg {k}')
                    arg.typ.unify(v)
                    required.discard(k)
                    available.remove(k)
                if len(required) > 0:
                    raise UnificationError(f'Missing arguments')
        elif isinstance(other, TVar):
            other.unify(self)
        else:
            raise UnificationError(f'Cannot unify a call with {other}')
class TFunc(Type):
    """Function type: parameters, result and the type variables it binds."""

    def __init__(self, args: List[FuncArg], ret: Type, vars: List[TVar]):
        self.args = args
        self.ret = ret
        self.vars = vars
        # set by instantiate() on the type it returns
        self.instantiated = False
        self.checked = False

    def check(self):
        """Check, once, every parameter type and then the result type."""
        if not self.checked:
            self.checked = True
            for param in self.args:
                param.typ.check()
            self.ret.check()

    def subst(self, mapping: Mapping[int, Type]):
        """Return a copy with *mapping* applied; ``self`` if it is empty."""
        if len(mapping) == 0:
            return self
        new_args = [
            FuncArg(param.name, param.typ.subst(mapping), param.is_optional)
            for param in self.args
        ]
        new_ret = self.ret.subst(mapping)
        return TFunc(new_args, new_ret, self.vars)

    def instantiate(self):
        """Replace each bound variable with a fresh one and mark the result.

        A fresh variable inherits the range of the variable it replaces.
        """
        mapping = {}
        for var in self.vars:
            mapping[var.id] = TVar(var.range) if isinstance(var, TVar) else TVar()
        result = self.subst(mapping)
        result.instantiated = True
        return result

    def __str__(self):
        params = ", ".join(str(param) for param in self.args)
        return f'({params}) -> {self.ret}'

    def unify(self, other):
        """Unify with a variable, a call or a function of the same shape.

        Raises:
            UnificationError: if *other* is another kind of type or the
                parameter lists differ in length, names or optionality.
        """
        other = other.find()
        if other is self:
            return
        if isinstance(other, (TVar, TCall)):
            # variables and calls know how to absorb a function
            other.unify(self)
            return
        if not isinstance(other, TFunc):
            raise UnificationError(f'Cannot unify a function with {other}')
        if len(self.args) != len(other.args):
            raise UnificationError(
                f'cannot unify functions with different parameters')
        self.ret.unify(other.ret)
        for mine, theirs in zip(self.args, other.args):
            same_shape = (mine.name == theirs.name
                          and mine.is_optional == theirs.is_optional)
            if not same_shape:
                raise UnificationError(
                    f'cannot unify functions with different parameters')
            mine.typ.unify(theirs.typ)
class TObj(Type):
    """Named object type with fields, type parameters and parent types."""

    def __init__(self, name: str, fields: Dict[str, Type], params: List[Type],
                 parents=None):
        """Create an object type.

        Args:
            name: name of the type; two object types unify only if it matches.
            fields: mapping from field/method name to its type.
            params: type parameters of the object.
            parents: optional list of parent types; defaults to a new empty
                list.
        """
        self.name = name
        self.fields = fields
        self.params = params
        self.checked = False
        if parents is None:
            self.parents = []
        else:
            self.parents = parents

    def check(self):
        """Check, once, the type of every field."""
        if self.checked:
            return
        self.checked = True
        for arg in self.fields.values():
            arg.check()

    def subst(self, mapping: Mapping[int, Type]):
        """Apply *mapping* to the type parameters and the fields.

        Returns ``self`` when no parameter is a type variable present in
        *mapping*; otherwise a new object type with the same name and parents.
        """
        if len(mapping) == 0:
            return self
        new_params = []
        changed = False
        for v in self.params:
            if isinstance(v, TVar) and v.id in mapping:
                new_params.append(mapping[v.id])
                changed = True
            else:
                new_params.append(v)
        if not changed:
            return self
        # keep the parents: an instantiated type has the same ancestors
        return TObj(self.name, {k: v.subst(mapping) for k, v in
                                self.fields.items()}, new_params,
                    self.parents)

    def unify(self, other):
        """Unify with a variable or an object type of the same name.

        Raises:
            UnificationError: if *other* is an object with a different name
                or another kind of type.
        """
        other = other.find()
        if other is self:
            return
        if isinstance(other, TObj):
            if self.name != other.name:
                raise UnificationError(f'Cannot unify {self} with {other}')
            for k in self.fields:
                self.fields[k].unify(other.fields[k])
        elif isinstance(other, TVar):
            other.unify(self)
        else:
            raise UnificationError(f'Cannot unify an object with {other}')

    def __str__(self):
        if len(self.params) > 0:
            p = '[' + ', '.join(str(p) for p in self.params) + ']'
        else:
            p = ''
        return self.name + p

    def __eq__(self, other):
        """Equal when *other* resolves to an object type with the same name
        and pairwise-equal parameters."""
        o = other.find()
        if isinstance(o, TObj):
            if self.name != o.name:
                return False
            for a, b in zip(self.params, o.params):
                if a != b:
                    return False
            return True
        return False
class TVirtual(Type):
    """``virtual[obj]``: a wrapper type around another type."""

    def __init__(self, obj: Type):
        self.obj = obj
        self.checked = False

    def check(self):
        """Check, once, the wrapped type."""
        if self.checked:
            return
        self.checked = True
        self.obj.check()

    def __eq__(self, other):
        """Equal when *other* resolves to a virtual of an equal type."""
        o = other.find()
        if isinstance(o, TVirtual):
            # compare the wrapped types; comparing `self == o` here would
            # call this method again and never terminate
            return self.obj == o.obj
        return False

    def unify(self, other):
        """Unify with another virtual type or with a type variable.

        Raises:
            UnificationError: if *other* is any other kind of type.
        """
        o = other.find()
        if isinstance(o, TVirtual):
            self.obj.unify(o.obj)
        elif isinstance(o, TVar):
            o.unify(self)
        else:
            raise UnificationError(f'Cannot unify {self} with {o}')

    def __str__(self):
        return f'virtual[{self.obj}]'
class TList(Type):
    """List type with a single element type, ``param``."""

    def __init__(self, param: Type):
        self.param = param
        self.checked = False

    def check(self):
        """Check, once, the element type."""
        if not self.checked:
            self.checked = True
            self.param.check()

    def unify(self, other):
        """Unify with a variable or with another list (element-wise).

        Raises:
            UnificationError: if *other* is any other kind of type.
        """
        other = other.find()
        if isinstance(other, TVar):
            other.unify(self)
            return
        if not isinstance(other, TList):
            raise UnificationError(f'Cannot unify list with {other}')
        self.param.unify(other.param)

    def __str__(self):
        return f'List[{self.param}]'

    def __eq__(self, other):
        """Equal when *other* resolves to a list of an equal element type."""
        resolved = other.find()
        if not isinstance(resolved, TList):
            return False
        return self.param == resolved.param
class TTuple(Type):
    """Tuple type with one element type per position."""

    def __init__(self, params: List[Type]):
        self.params = params
        self.checked = False

    def check(self):
        """Check, once, every element type."""
        if not self.checked:
            self.checked = True
            for element in self.params:
                element.check()

    def unify(self, other):
        """Unify with a variable or position-wise with a same-length tuple.

        Raises:
            UnificationError: if the lengths differ or *other* is another
                kind of type.
        """
        other = other.find()
        if isinstance(other, TVar):
            other.unify(self)
            return
        if not isinstance(other, TTuple):
            raise UnificationError(f'Cannot unify {self} with {other}')
        if len(self.params) != len(other.params):
            raise UnificationError(
                'cannot unify tuples with different length')
        for mine, theirs in zip(self.params, other.params):
            mine.unify(theirs)

    def __str__(self):
        elements = ", ".join(str(p) for p in self.params)
        return f'Tuple[{elements}]'

    def __eq__(self, other):
        """Equal when *other* resolves to a same-length tuple whose elements
        are pairwise equal."""
        resolved = other.find()
        if not isinstance(resolved, TTuple):
            return False
        if len(self.params) != len(resolved.params):
            return False
        return not any(a != b for a, b in zip(self.params, resolved.params))

View File

@ -0,0 +1,42 @@
from nac3_types import *
# Primitive object types. They start without fields; methods are attached
# afterwards because their signatures refer back to the types themselves.
TBool = TObj('bool', {}, [])
TInt = TObj('int', {}, [])
TFloat = TObj('float', {}, [])
# bool only gets (in)equality, both taking and returning bool
TBool.fields.update({
    method: TFunc([FuncArg('other', TBool, False)], TBool, [])
    for method in ('__eq__', '__ne__')
})
def impl_cmp(ty):
    """Attach the six comparison methods to *ty*.

    Each method takes one required argument ``other`` of type *ty* and
    returns ``TBool``.
    """
    for method in ('__lt__', '__le__', '__eq__', '__ne__', '__gt__', '__ge__'):
        ty.fields[method] = TFunc([FuncArg('other', ty, False)], TBool, [])
def impl_arithmetic(ty):
    """Attach ``__add__``, ``__sub__`` and ``__mul__`` to *ty*.

    Each method takes one required argument ``other`` of type *ty* and
    returns *ty*.
    """
    for method in ('__add__', '__sub__', '__mul__'):
        ty.fields[method] = TFunc([FuncArg('other', ty, False)], ty, [])
# int and float both get comparisons and same-type arithmetic
impl_cmp(TInt)
impl_cmp(TFloat)
impl_arithmetic(TInt)
impl_arithmetic(TFloat)
# Type variable restricted to int or float; the division methods list it as
# their bound variable.
TNum = TVar([TInt, TFloat])
# int: true division yields float, floor division yields int
TInt.fields.update({
    '__truediv__': TFunc([FuncArg('other', TNum, False)], TFloat, [TNum]),
    '__floordiv__': TFunc([FuncArg('other', TNum, False)], TInt, [TNum]),
})
# float: both divisions yield float
TFloat.fields.update({
    '__truediv__': TFunc([FuncArg('other', TNum, False)], TFloat, [TNum]),
    '__floordiv__': TFunc([FuncArg('other', TNum, False)], TFloat, [TNum]),
})

72
hm-inference/test.py Normal file
View File

@ -0,0 +1,72 @@
from __future__ import annotations
import ast
from ast_visitor import Visitor
from nac3_types import *
from primitives import *
# Program under test. The block structure inside the literal is significant:
# the list elements and the loop body must be indented for it to parse.
src = """
a = [
    virtual(bar),
    virtual(foo),
]
for x in a:
    test_virtual(x)
"""

# Object types known to the program; Bar lists Foo as its parent.
foo = TObj('Foo', {
    'a': TInt,
}, [])
foo2 = TObj('Foo2', {
    'a': TInt,
}, [])
bar = TObj('Bar', {
    'a': TInt,
    'b': TInt
}, [], [foo])

# Names accepted in type annotations, e.g. the second argument of virtual().
type_mapping = {
    'Foo': foo,
    'Foo2': foo2,
    'Bar': bar,
}

# Initial environment handed to the visitor.
prelude = {
    'foo': foo,
    'foo2': foo2,
    'bar': bar,
    'test_virtual': TFunc([FuncArg('a', TVirtual(foo), False)], TInt, [])
}

print('-----------')
print('prelude')
for key, value in prelude.items():
    print(f'{key}: {value}')
print('-----------')

# The visitor mutates its environment, so give it a copy of the prelude.
v = Visitor(src, prelude.copy(), lambda x: type_mapping[x])
print(src)
v.visit(ast.parse(src))

# Deferred check for virtual(obj, T): T must be obj's type or one of its
# direct parents. Both sides are resolved first, since either may be a type
# variable that has been bound during inference.
for a, b in v.virtuals:
    a = a.find()
    b = b.find()
    assert isinstance(a, TObj)
    assert b is a or b in a.parents

print('-----------')
print('calls')
for x in v.calls:
    x.check()
    print(f'{x.find()}')

print('-----------')
print('assignments')
for key, value in v.assignments.items():
    # prelude entries are given, not inferred
    if key not in prelude:
        value.check()
        print(f'{key}: {value.find()}')