allows recursive type, implementing primitives
This commit is contained in:
parent
66df55b3d7
commit
b1020352ce
@ -1,7 +1,7 @@
|
|||||||
import ast
|
import ast
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from nac3_types import *
|
from nac3_types import *
|
||||||
|
from primitives import *
|
||||||
|
|
||||||
class Visitor(ast.NodeVisitor):
|
class Visitor(ast.NodeVisitor):
|
||||||
def __init__(self, src, assignments, type_parser):
|
def __init__(self, src, assignments, type_parser):
|
||||||
@ -35,7 +35,7 @@ class Visitor(ast.NodeVisitor):
|
|||||||
self.visit(node.args)
|
self.visit(node.args)
|
||||||
self.visit(node.body)
|
self.visit(node.body)
|
||||||
self.assignments = old
|
self.assignments = old
|
||||||
node.type = TFunc(node.args.type, node.body.type, set())
|
node.type = TFunc(node.args.type, node.body.type, [])
|
||||||
|
|
||||||
def visit_arguments(self, node):
|
def visit_arguments(self, node):
|
||||||
for arg in node.args:
|
for arg in node.args:
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Dict, Mapping, List, Set
|
from typing import Dict, Mapping, List, Set
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from itertools import chain
|
|
||||||
|
|
||||||
|
|
||||||
class UnificationError(Exception):
|
class UnificationError(Exception):
|
||||||
@ -53,6 +52,7 @@ class TVar(Type):
|
|||||||
self.type = TVarType.UNDETERMINED
|
self.type = TVarType.UNDETERMINED
|
||||||
self.rank = 0
|
self.rank = 0
|
||||||
self.parent = self
|
self.parent = self
|
||||||
|
self.checked = False
|
||||||
|
|
||||||
self.fields = {}
|
self.fields = {}
|
||||||
self.range = vrange
|
self.range = vrange
|
||||||
@ -60,6 +60,9 @@ class TVar(Type):
|
|||||||
TVar.next_id += 1
|
TVar.next_id += 1
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
|
if self.checked:
|
||||||
|
return
|
||||||
|
self.checked = True
|
||||||
if self.range is not None:
|
if self.range is not None:
|
||||||
ty = self.find()
|
ty = self.find()
|
||||||
# maybe we should replace this with explicit eq
|
# maybe we should replace this with explicit eq
|
||||||
@ -179,8 +182,12 @@ class TCall(Type):
|
|||||||
self.calls = [[posargs, kwargs, ret, None]]
|
self.calls = [[posargs, kwargs, ret, None]]
|
||||||
self.parent = self
|
self.parent = self
|
||||||
self.rank = 0
|
self.rank = 0
|
||||||
|
self.checked = False
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
|
if self.checked:
|
||||||
|
return
|
||||||
|
self.checked = True
|
||||||
self.calls[0][3].check()
|
self.calls[0][3].check()
|
||||||
|
|
||||||
def find(self):
|
def find(self):
|
||||||
@ -246,13 +253,17 @@ class TCall(Type):
|
|||||||
|
|
||||||
|
|
||||||
class TFunc(Type):
|
class TFunc(Type):
|
||||||
def __init__(self, args: List[FuncArg], ret: Type, vars: Set[TVar]):
|
def __init__(self, args: List[FuncArg], ret: Type, vars: List[TVar]):
|
||||||
self.args = args
|
self.args = args
|
||||||
self.ret = ret
|
self.ret = ret
|
||||||
self.vars = vars
|
self.vars = vars
|
||||||
self.instantiated = False
|
self.instantiated = False
|
||||||
|
self.checked = False
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
|
if self.checked:
|
||||||
|
return
|
||||||
|
self.checked = True
|
||||||
for arg in self.args:
|
for arg in self.args:
|
||||||
arg.typ.check()
|
arg.typ.check()
|
||||||
self.ret.check()
|
self.ret.check()
|
||||||
@ -303,12 +314,16 @@ class TObj(Type):
|
|||||||
self.name = name
|
self.name = name
|
||||||
self.fields = fields
|
self.fields = fields
|
||||||
self.params = params
|
self.params = params
|
||||||
|
self.checked = False
|
||||||
if parents is None:
|
if parents is None:
|
||||||
self.parents = []
|
self.parents = []
|
||||||
else:
|
else:
|
||||||
self.parents = parents
|
self.parents = parents
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
|
if self.checked:
|
||||||
|
return
|
||||||
|
self.checked = True
|
||||||
for arg in self.fields.values():
|
for arg in self.fields.values():
|
||||||
arg.check()
|
arg.check()
|
||||||
|
|
||||||
@ -361,6 +376,13 @@ class TObj(Type):
|
|||||||
class TVirtual(Type):
|
class TVirtual(Type):
|
||||||
def __init__(self, obj: TObj):
|
def __init__(self, obj: TObj):
|
||||||
self.obj = obj
|
self.obj = obj
|
||||||
|
self.checked = False
|
||||||
|
|
||||||
|
def check(self):
|
||||||
|
if self.checked:
|
||||||
|
return
|
||||||
|
self.checked = True
|
||||||
|
self.obj.check()
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
o = other.find()
|
o = other.find()
|
||||||
@ -382,8 +404,12 @@ class TVirtual(Type):
|
|||||||
class TList(Type):
|
class TList(Type):
|
||||||
def __init__(self, param: Type):
|
def __init__(self, param: Type):
|
||||||
self.param = param
|
self.param = param
|
||||||
|
self.checked = False
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
|
if self.checked:
|
||||||
|
return
|
||||||
|
self.checked = True
|
||||||
self.param.check()
|
self.param.check()
|
||||||
|
|
||||||
def unify(self, other):
|
def unify(self, other):
|
||||||
@ -408,8 +434,12 @@ class TList(Type):
|
|||||||
class TTuple(Type):
|
class TTuple(Type):
|
||||||
def __init__(self, params: List[Type]):
|
def __init__(self, params: List[Type]):
|
||||||
self.params = params
|
self.params = params
|
||||||
|
self.checked = False
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
|
if self.checked:
|
||||||
|
return
|
||||||
|
self.checked = True
|
||||||
for p in self.params:
|
for p in self.params:
|
||||||
p.check()
|
p.check()
|
||||||
|
|
||||||
@ -441,5 +471,3 @@ class TTuple(Type):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
TBool = TObj('bool', {}, [])
|
|
||||||
TInt = TObj('int', {}, [])
|
|
||||||
|
42
hm-inference/primitives.py
Normal file
42
hm-inference/primitives.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from nac3_types import *
|
||||||
|
|
||||||
|
|
||||||
|
TBool = TObj('bool', {}, [])
|
||||||
|
TInt = TObj('int', {}, [])
|
||||||
|
TFloat = TObj('float', {}, [])
|
||||||
|
|
||||||
|
TBool.fields['__eq__'] = TFunc([FuncArg('other', TBool, False)], TBool, [])
|
||||||
|
TBool.fields['__ne__'] = TFunc([FuncArg('other', TBool, False)], TBool, [])
|
||||||
|
|
||||||
|
|
||||||
|
def impl_cmp(ty):
|
||||||
|
ty.fields['__lt__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||||
|
ty.fields['__le__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||||
|
ty.fields['__eq__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||||
|
ty.fields['__ne__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||||
|
ty.fields['__gt__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||||
|
ty.fields['__ge__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||||
|
|
||||||
|
|
||||||
|
def impl_arithmetic(ty):
|
||||||
|
ty.fields['__add__'] = TFunc([FuncArg('other', ty, False)], ty, [])
|
||||||
|
ty.fields['__sub__'] = TFunc([FuncArg('other', ty, False)], ty, [])
|
||||||
|
ty.fields['__mul__'] = TFunc([FuncArg('other', ty, False)], ty, [])
|
||||||
|
|
||||||
|
|
||||||
|
impl_cmp(TInt)
|
||||||
|
impl_cmp(TFloat)
|
||||||
|
impl_arithmetic(TInt)
|
||||||
|
impl_arithmetic(TFloat)
|
||||||
|
|
||||||
|
TNum = TVar([TInt, TFloat])
|
||||||
|
|
||||||
|
TInt.fields['__truediv__'] = TFunc(
|
||||||
|
[FuncArg('other', TNum, False)], TFloat, [TNum])
|
||||||
|
TInt.fields['__floordiv__'] = TFunc(
|
||||||
|
[FuncArg('other', TNum, False)], TInt, [TNum])
|
||||||
|
TFloat.fields['__truediv__'] = TFunc(
|
||||||
|
[FuncArg('other', TNum, False)], TFloat, [TNum])
|
||||||
|
TFloat.fields['__floordiv__'] = TFunc(
|
||||||
|
[FuncArg('other', TNum, False)], TFloat, [TNum])
|
||||||
|
|
@ -2,11 +2,13 @@ from __future__ import annotations
|
|||||||
import ast
|
import ast
|
||||||
from ast_visitor import Visitor
|
from ast_visitor import Visitor
|
||||||
from nac3_types import *
|
from nac3_types import *
|
||||||
|
from primitives import *
|
||||||
|
|
||||||
src = """
|
src = """
|
||||||
|
a = 1
|
||||||
|
a = a.__add__(2)
|
||||||
b = test_virtual(virtual(bar, Foo))
|
b = test_virtual(virtual(bar, Foo))
|
||||||
b = test_virtual(virtual(foo, Foo))
|
b = test_virtual(virtual(foo, Foo))
|
||||||
b = test_virtual(virtual(foo2, Foo))
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
foo = TObj('Foo', {
|
foo = TObj('Foo', {
|
||||||
@ -32,7 +34,7 @@ prelude = {
|
|||||||
'foo': foo,
|
'foo': foo,
|
||||||
'foo2': foo2,
|
'foo2': foo2,
|
||||||
'bar': bar,
|
'bar': bar,
|
||||||
'test_virtual': TFunc([FuncArg('a', TVirtual(foo), False)], TInt, set())
|
'test_virtual': TFunc([FuncArg('a', TVirtual(foo), False)], TInt, [])
|
||||||
}
|
}
|
||||||
|
|
||||||
print('-----------')
|
print('-----------')
|
||||||
@ -64,5 +66,3 @@ for key, value in v.assignments.items():
|
|||||||
print(f'{key}: {value.find()}')
|
print(f'{key}: {value.find()}')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Occur check
|
|
||||||
|
Loading…
Reference in New Issue
Block a user