From 59628cfa38272f29f0a99e3514dbea29275d5322 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Fri, 9 Jul 2021 15:27:02 +0800 Subject: [PATCH] init hm-inference --- hm-inference/ast_visitor.py | 160 ++++++++++++++++ hm-inference/nac3_types.py | 365 ++++++++++++++++++++++++++++++++++++ hm-inference/test.py | 52 +++++ 3 files changed, 577 insertions(+) create mode 100644 hm-inference/ast_visitor.py create mode 100644 hm-inference/nac3_types.py create mode 100644 hm-inference/test.py diff --git a/hm-inference/ast_visitor.py b/hm-inference/ast_visitor.py new file mode 100644 index 0000000..b1b6cad --- /dev/null +++ b/hm-inference/ast_visitor.py @@ -0,0 +1,160 @@ +import ast +from itertools import chain +from nac3_types import * + + +class Visitor(ast.NodeVisitor): + def __init__(self): + super(Visitor, self).__init__() + self.assignments = {} + self.calls = [] + + def visit_Assign(self, node): + self.visit(node.value) + for target in node.targets: + self.visit(target) + target.type.unify(node.value.type) + + def visit_Name(self, node): + if node.id == '_': + node.type = TVar() + return + if node.id not in self.assignments: + self.assignments[node.id] = TVar() + ty = self.assignments[node.id] + if isinstance(ty, TFunc): + ty = ty.instantiate() + node.type = ty + + def visit_Lambda(self, node): + old = self.assignments + self.assignments = old.copy() + self.visit(node.args) + self.visit(node.body) + self.assignments = old + node.type = TFunc(node.args.type, node.body.type, set()) + + def visit_arguments(self, node): + for arg in node.args: + self.assignments[arg.arg] = TVar() + arg.type = self.assignments[arg.arg] + node.type = [FuncArg(arg.arg, arg.type, False) for arg in node.args] + + def visit_Constant(self, node): + if isinstance(node.value, bool): + node.type = TBool + elif isinstance(node.value, int): + node.type = TInt + + def visit_Call(self, node): + self.visit(node.func) + for arg in node.args: + self.visit(arg) + for keyword in node.keywords: + self.visit(keyword.value) + node.type = TVar() + call = TCall([arg.type for arg in node.args], + {keyword.arg: keyword.value.type for keyword + in node.keywords}, node.type) + node.func.type.unify(call) + self.calls.append(node.func.type) + + def visit_Attribute(self, node): + self.visit(node.value) + node.type = TVar() + v = TVar() + v.type = TVarType.RECORD + if node.attr in v.fields: + v.fields[node.attr].unify(node.type) + else: + v.fields[node.attr] = node.type + node.value.type.unify(v) + + def visit_Tuple(self, node): + for e in node.elts: + self.visit(e) + node.type = TTuple([e.type for e in node.elts]) + + def visit_List(self, node): + ty = TVar() + for e in node.elts: + self.visit(e) + ty.unify(e.type) + node.type = TList(ty) + + def visit_Subscript(self, node): + self.visit(node.value) + if isinstance(node.slice, ast.Slice): + node.type = node.value.type + if isinstance(node.type, TVar): + node.type.type = node.type.type.unifier(TVarType.LIST) + elif not isinstance(node.type, TList): + raise UnificationError(f'{node.type} should be a list') + elif isinstance(node.slice, ast.ExtSlice): + raise NotImplementedError() + else: + # complicated because we need to handle lists and tuples + # differently... + self.visit(node.slice.value) + node.slice.value.type.unify(TInt) + ty = node.value.type + node.type = TVar() + if isinstance(ty, TVar): + index = 0 + if isinstance(node.slice.value, ast.Constant): + seq_ty = TVarType.SEQUENCE + if isinstance(node.ctx, (ast.AugStore, ast.Store)): + seq_ty = TVarType.LIST + ty.type = ty.type.unifier(seq_ty) + index = node.slice.value.value + else: + ty.type = ty.type.unifier(TVarType.LIST) + if index in ty.fields: + ty.fields[index].unify(node.type) + else: + ty.fields[index] = node.type + elif isinstance(ty, TList): + ty.param.unify(node.type) + elif isinstance(ty, TTuple): + if isinstance(node.ctx, (ast.AugStore, ast.Store)): + raise UnificationError(f'Cannot assign to tuple') + if isinstance(node.slice.value, ast.Constant): + index = node.slice.value.value + if index >= len(ty.params): + raise UnificationError('Index out of range for tuple') + ty.params[index].unify(node.type) + else: + raise UnificationError('Tuple index must be a constant') + else: + raise UnificationError(f'Cannot use subscript for {ty}') + + def visit_For(self, node): + self.visit(node.target) + self.visit(node.iter) + ty = node.iter.type + if isinstance(ty, TVar): + # we currently only support iterator over lists + ty.type = ty.type.unifier(TVarType.LIST) + if 0 in ty.fields: + ty.fields[0].unify(node.target.type) + else: + ty.fields[0] = node.target.type + elif isinstance(ty, TList): + ty.param.unify(node.target.type) + else: + raise UnificationError(f'Cannot iterate over {ty}') + + for stmt in chain(node.body, node.orelse): + self.visit(stmt) + + def visit_If(self, node): + self.visit(node.test) + node.test.type.unify(TBool) + for stmt in chain(node.body, node.orelse): + self.visit(stmt) + + def visit_While(self, node): + self.visit(node.test) + node.test.type.unify(TBool) + for stmt in chain(node.body, node.orelse): + self.visit(stmt) diff --git a/hm-inference/nac3_types.py b/hm-inference/nac3_types.py new file mode 100644 index 0000000..b176b78 --- /dev/null +++ b/hm-inference/nac3_types.py @@ -0,0 +1,365 @@ +from __future__ import annotations +from typing import Dict, Mapping, List, Set +from enum import Enum +from itertools import chain + + +class UnificationError(Exception): + def __init__(self, msg): + super().__init__(msg) + + +class Type: + def find(self): + return self + + def unify(self, _): + raise NotImplementedError() + + def subst(self, _): + raise NotImplementedError() + + def check(self): + pass + + +class TVarType(Enum): + UNDETERMINED = 1 + SEQUENCE = 2 + RECORD = 5 + TUPLE = 6 + LIST = 8 + + def __le__(self, other): + if self.__class__ is other.__class__: + return (other.value % self.value) == 0 + return NotImplemented + + def unifier(self, other): + if self.__class__ is not other.__class__: + raise NotImplementedError() + if self <= other: + return other + elif other <= self: + return self + else: + raise UnificationError(f'cannot unify {self} and {other}') + + +class TVar(Type): + next_id = 0 + + def __init__(self, vrange=None): + self.type = TVarType.UNDETERMINED + self.rank = 0 + self.parent = self + + self.fields = {} + self.range = vrange + self.id = TVar.next_id + TVar.next_id += 1 + + def check(self): + if self.range is not None: + ty = self.find() + # maybe we should replace this with explicit eq + if ty not in self.range: + raise UnificationError( + f'{self.id} cannot be substituted by {ty}') + + def subst(self, mapping: Mapping[int, Type]): + # user cannot specify fields... + # so this is safe + if self.id in mapping: + return mapping[self.id] + return self + + def __str__(self): + s = self.find() + if isinstance(s, TVar): + if len(s.fields) > 0: + fields = '{' + ', '.join([f'{k}: {v}' for k, + v in s.fields.items()]) + '}' + else: + fields = '' + return str(s.id) + fields + else: + return str(s) + + def find(self): + root = self + parent = self.parent + while root is not parent and isinstance(parent, TVar): + _, parent = root, root.parent = parent, parent.parent + if isinstance(parent, TCall): + parent = parent.find() + return parent + + def unify(self, other): + x = other.find() + y = self.find() + + if x is y: + return + if isinstance(y, TVar) and isinstance(x, TVar): + # unify field type + x.type = x.type.unifier(y.type) + # unify fields + for k, v in y.fields.items(): + if k in x.fields: + x.fields[k].unify(v) + else: + x.fields[k] = v + # standard union find + if x.rank < y.rank: + x, y = y, x + y.parent = x + if x.rank == y.rank: + x.rank += 1 + elif isinstance(y, TVar): + # check fields + if isinstance(x, TObj): + if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]: + raise UnificationError(f'Cannot unify {y} with {x}') + for k, v in y.fields.items(): + if k not in x.fields: + raise UnificationError( + f'Cannot unify {y} with {x}') + if isinstance(v, TFunc): + v = v.instantiate() + u = x.fields[k] + if isinstance(u, TFunc): + u = u.instantiate() + v.unify(u) + if isinstance(x, TList): + if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.LIST]: + raise UnificationError(f'Cannot unify {y} with {x}') + for k, v in y.fields.items(): + assert isinstance(k, int) + v.unify(x.param) + if isinstance(x, TTuple): + if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.TUPLE]: + raise UnificationError(f'Cannot unify {y} with {x}') + for k, v in y.fields.items(): + assert isinstance(k, int) + if k >= len(x.params): + raise UnificationError(f'Cannot unify {y} with {x}') + v.unify(x.params[k]) + y.parent = x + else: + y.unify(x) + + +class FuncArg: + def __init__(self, name, typ, is_optional): + self.name = name + self.typ = typ + self.is_optional = is_optional + + def __str__(self): + return f'{self.name}: {self.typ}' + ('?' if self.is_optional else '') + + +class TCall(Type): + def __init__(self, posargs: List[Type], kwargs: Dict[str, Type], ret: Type): + self.posargs = posargs + self.kwargs = kwargs + self.ret = ret + self.fun = TVar() + + def check(self): + for arg in chain(self.posargs, self.kwargs.values()): + arg.check() + + def find(self): + if isinstance(self.fun.find(), TVar): + return self + return self.fun.find() + + def unify(self, other): + if not isinstance(self.fun.find(), TVar): + self.fun.unify(other) + return + + other = other.find() + if other is self: + return + + if isinstance(other, TCall): + for a, b in zip(self.posargs, other.posargs): + a.unify(b) + for k, v in self.kwargs.items(): + if k in other.kwargs: + other.kwargs[k].unify(v) + else: + other.kwargs[k] = v + for k, v in other.kwargs.items(): + if k not in self.kwargs: + self.kwargs[k] = v + self.fun.unify(other.fun) + elif isinstance(other, TFunc): + all_args = set(arg.name for arg in other.args) + required = set(arg.name for arg in other.args if not + arg.is_optional) + other.ret.unify(self.ret) + for i, v in enumerate(self.posargs): + arg = other.args[i] + arg.typ.unify(v) + if arg.name in required: + required.remove(arg.name) + for k, v in self.kwargs.items(): + arg = next((arg for arg in other.args if arg.name == k), None) + if arg is None: + raise UnificationError(f'Unknown kwarg {k}') + if k not in all_args: + raise UnificationError(f'Duplicated kwarg {k}') + arg.typ.unify(v) + if k in required: + required.remove(k) + all_args.remove(k) + if len(required) > 0: + raise UnificationError(f'Missing arguments') + self.fun.unify(other) + elif isinstance(other, TVar): + other.unify(self) + else: + raise UnificationError(f'Cannot unify a call with {other}') + + +class TFunc(Type): + def __init__(self, args: List[FuncArg], ret: Type, vars: Set[TVar]): + self.args = args + self.ret = ret + self.vars = vars + self.instantiated = False + + def check(self): + for arg in self.args: + arg.typ.check() + self.ret.check() + + def subst(self, mapping: Mapping[int, Type]): + if len(mapping) == 0: + return self + return TFunc( + [FuncArg(arg.name, arg.typ.subst(mapping), arg.is_optional) + for arg in self.args], + self.ret.subst(mapping), + self.vars) + + def instantiate(self): + mapping = {v.id: TVar(v.range) if isinstance(v, TVar) + else TVar() for v in self.vars} + result = self.subst(mapping) + result.instantiated = True + return result + + def __str__(self): + return f'({", ".join(str(arg) for arg in self.args)}) -> {self.ret}' + + def unify(self, other): + other = other.find() + if other is self: + return + + if isinstance(other, (TVar, TCall)): + other.unify(self) + elif isinstance(other, TFunc): + if len(self.args) != len(other.args): + raise UnificationError( + f'cannot unify functions with different parameters') + self.ret.unify(other.ret) + for a, b in zip(self.args, other.args): + if a.name != b.name or a.is_optional != b.is_optional: + raise UnificationError( + f'cannot unify functions with different parameters') + a.typ.unify(b.typ) + else: + raise UnificationError(f'Cannot unify a function with {other}') + + +class TObj(Type): + def __init__(self, name: str, fields: Dict[str, Type], params: List[Type]): + self.name = name + self.fields = fields + self.params = params + + def check(self): + for arg in self.fields.values(): + arg.check() + + def subst(self, mapping: Mapping[int, Type]): + if len(mapping) == 0: + return self + new_params = [] + for v in self.params: + if isinstance(v, TVar) and v.id in mapping: + new_params.append(mapping[v.id]) + else: + new_params.append(v) + + return TObj(self.name, {k: v.subst(mapping) for k, v in + self.fields.items()}, new_params) + + def unify(self, other): + other = other.find() + if other is self: + return + if isinstance(other, TObj): + if self.name != other.name: + raise UnificationError(f'Cannot unify {self} with {other}') + for k in self.fields: + self.fields[k].unify(other.fields[k]) + elif isinstance(other, TVar): + other.unify(self) + else: + raise UnificationError(f'Cannot unify an object with {other}') + + def __str__(self): + if len(self.params) > 0: + p = '[' + ', '.join(str(p) for p in self.params) + ']' + else: + p = '' + return self.name + p + + +class TList(Type): + def __init__(self, param: Type): + self.param = param + + def unify(self, other): + other = other.find() + if isinstance(other, TVar): + other.unify(self) + elif isinstance(other, TList): + self.param.unify(other.param) + else: + raise UnificationError(f'Cannot unify list with {other}') + + def __str__(self): + return f'List[{self.param}]' + + +class TTuple(Type): + def __init__(self, params: List[Type]): + self.params = params + + def unify(self, other): + other = other.find() + if isinstance(other, TVar): + other.unify(self) + elif isinstance(other, TTuple): + if len(self.params) != len(other.params): + raise UnificationError( + 'cannot unify tuples with different length') + for a, b in zip(self.params, other.params): + a.unify(b) + else: + raise UnificationError(f'Cannot unify {self} with {other}') + + def __str__(self): + return f'Tuple[{", ".join(str(p) for p in self.params)}]' + + +TBool = TObj('bool', {}, []) +TInt = TObj('int', {}, []) diff --git a/hm-inference/test.py b/hm-inference/test.py new file mode 100644 index 0000000..0e6a75b --- /dev/null +++ b/hm-inference/test.py @@ -0,0 +1,52 @@ +# from __future__ import annotations +import ast +from ast_visitor import Visitor +from nac3_types import * + + +var = TVar([TInt, TBool]) +var2 = TVar() + +foo = TObj('Foo', { + 'foo': TFunc([ + FuncArg('a', var, False), + FuncArg('b', var2, False) + ], var2, set([var2])) +}, [var]) + +v = Visitor() +v.assignments['get_x'] = TFunc([FuncArg('in', var, False)], TInt, set([var])) +v.assignments['Foo'] = TFunc([FuncArg('a', var, False)], foo, set([var])) + +prelude = set(v.assignments.keys()) +print('-----------') +print('prelude') +for key, value in v.assignments.items(): + print(f'{key}: {value}') +print('-----------') + +src = """ +a = Foo(1).foo(1, 2) +b = Foo(1).foo(1, True) +c = Foo(True).foo(True, 1) +d = Foo(True).foo(True, True) +""" + +print(src) +v.visit(ast.parse(src)) + +print('-----------') +print('assignments') +for key, value in v.assignments.items(): + if key not in prelude: + value.check() + print(f'{key}: {value.find()}') + +print('-----------') +print('calls') +for x in v.calls: + x.check() + print(f'{x.find()}') + + +# TODO: Occur check