diff --git a/hm-inference/ast_visitor.py b/hm-inference/ast_visitor.py index 634caff..bda045e 100644 --- a/hm-inference/ast_visitor.py +++ b/hm-inference/ast_visitor.py @@ -1,7 +1,7 @@ import ast from itertools import chain from nac3_types import * - +from primitives import * class Visitor(ast.NodeVisitor): def __init__(self, src, assignments, type_parser): @@ -35,7 +35,7 @@ class Visitor(ast.NodeVisitor): self.visit(node.args) self.visit(node.body) self.assignments = old - node.type = TFunc(node.args.type, node.body.type, set()) + node.type = TFunc(node.args.type, node.body.type, []) def visit_arguments(self, node): for arg in node.args: diff --git a/hm-inference/nac3_types.py b/hm-inference/nac3_types.py index e7d3e18..c8d579e 100644 --- a/hm-inference/nac3_types.py +++ b/hm-inference/nac3_types.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import Dict, Mapping, List, Set from enum import Enum -from itertools import chain class UnificationError(Exception): @@ -53,6 +52,7 @@ class TVar(Type): self.type = TVarType.UNDETERMINED self.rank = 0 self.parent = self + self.checked = False self.fields = {} self.range = vrange @@ -60,6 +60,9 @@ class TVar(Type): TVar.next_id += 1 def check(self): + if self.checked: + return + self.checked = True if self.range is not None: ty = self.find() # maybe we should replace this with explicit eq @@ -179,8 +182,12 @@ class TCall(Type): self.calls = [[posargs, kwargs, ret, None]] self.parent = self self.rank = 0 + self.checked = False def check(self): + if self.checked: + return + self.checked = True self.calls[0][3].check() def find(self): @@ -246,13 +253,17 @@ class TCall(Type): class TFunc(Type): - def __init__(self, args: List[FuncArg], ret: Type, vars: Set[TVar]): + def __init__(self, args: List[FuncArg], ret: Type, vars: List[TVar]): self.args = args self.ret = ret self.vars = vars self.instantiated = False + self.checked = False def check(self): + if self.checked: + return + self.checked = True for arg in self.args: arg.typ.check() self.ret.check() @@ -303,12 +314,16 @@ class TObj(Type): self.name = name self.fields = fields self.params = params + self.checked = False if parents is None: self.parents = [] else: self.parents = parents def check(self): + if self.checked: + return + self.checked = True for arg in self.fields.values(): arg.check() @@ -361,6 +376,13 @@ class TObj(Type): class TVirtual(Type): def __init__(self, obj: TObj): self.obj = obj + self.checked = False + + def check(self): + if self.checked: + return + self.checked = True + self.obj.check() def __eq__(self, other): o = other.find() @@ -382,8 +404,12 @@ class TVirtual(Type): class TList(Type): def __init__(self, param: Type): self.param = param + self.checked = False def check(self): + if self.checked: + return + self.checked = True self.param.check() def unify(self, other): @@ -408,8 +434,12 @@ class TList(Type): class TTuple(Type): def __init__(self, params: List[Type]): self.params = params + self.checked = False def check(self): + if self.checked: + return + self.checked = True for p in self.params: p.check() @@ -441,5 +471,3 @@ class TTuple(Type): return False -TBool = TObj('bool', {}, []) -TInt = TObj('int', {}, []) diff --git a/hm-inference/primitives.py b/hm-inference/primitives.py new file mode 100644 index 0000000..43c4e47 --- /dev/null +++ b/hm-inference/primitives.py @@ -0,0 +1,42 @@ +from nac3_types import * + + +TBool = TObj('bool', {}, []) +TInt = TObj('int', {}, []) +TFloat = TObj('float', {}, []) + +TBool.fields['__eq__'] = TFunc([FuncArg('other', TBool, False)], TBool, []) +TBool.fields['__ne__'] = TFunc([FuncArg('other', TBool, False)], TBool, []) + + +def impl_cmp(ty): + ty.fields['__lt__'] = TFunc([FuncArg('other', ty, False)], TBool, []) + ty.fields['__le__'] = TFunc([FuncArg('other', ty, False)], TBool, []) + ty.fields['__eq__'] = TFunc([FuncArg('other', ty, False)], TBool, []) + ty.fields['__ne__'] = TFunc([FuncArg('other', ty, False)], TBool, []) + ty.fields['__gt__'] = TFunc([FuncArg('other', ty, False)], TBool, []) + ty.fields['__ge__'] = TFunc([FuncArg('other', ty, False)], TBool, []) + + +def impl_arithmetic(ty): + ty.fields['__add__'] = TFunc([FuncArg('other', ty, False)], ty, []) + ty.fields['__sub__'] = TFunc([FuncArg('other', ty, False)], ty, []) + ty.fields['__mul__'] = TFunc([FuncArg('other', ty, False)], ty, []) + + +impl_cmp(TInt) +impl_cmp(TFloat) +impl_arithmetic(TInt) +impl_arithmetic(TFloat) + +TNum = TVar([TInt, TFloat]) + +TInt.fields['__truediv__'] = TFunc( + [FuncArg('other', TNum, False)], TFloat, [TNum]) +TInt.fields['__floordiv__'] = TFunc( + [FuncArg('other', TNum, False)], TInt, [TNum]) +TFloat.fields['__truediv__'] = TFunc( + [FuncArg('other', TNum, False)], TFloat, [TNum]) +TFloat.fields['__floordiv__'] = TFunc( + [FuncArg('other', TNum, False)], TFloat, [TNum]) + diff --git a/hm-inference/test.py b/hm-inference/test.py index 6ded82d..6e8adea 100644 --- a/hm-inference/test.py +++ b/hm-inference/test.py @@ -2,11 +2,13 @@ from __future__ import annotations import ast from ast_visitor import Visitor from nac3_types import * +from primitives import * src = """ +a = 1 +a = a.__add__(2) b = test_virtual(virtual(bar, Foo)) b = test_virtual(virtual(foo, Foo)) -b = test_virtual(virtual(foo2, Foo)) """ foo = TObj('Foo', { @@ -32,7 +34,7 @@ prelude = { 'foo': foo, 'foo2': foo2, 'bar': bar, - 'test_virtual': TFunc([FuncArg('a', TVirtual(foo), False)], TInt, set()) + 'test_virtual': TFunc([FuncArg('a', TVirtual(foo), False)], TInt, []) } print('-----------') @@ -64,5 +66,3 @@ for key, value in v.assignments.items(): print(f'{key}: {value.find()}') - -# TODO: Occur check