allows recursive type, implementing primitives
This commit is contained in:
parent
66df55b3d7
commit
b1020352ce
@ -1,7 +1,7 @@
|
||||
import ast
|
||||
from itertools import chain
|
||||
from nac3_types import *
|
||||
|
||||
from primitives import *
|
||||
|
||||
class Visitor(ast.NodeVisitor):
|
||||
def __init__(self, src, assignments, type_parser):
|
||||
@ -35,7 +35,7 @@ class Visitor(ast.NodeVisitor):
|
||||
self.visit(node.args)
|
||||
self.visit(node.body)
|
||||
self.assignments = old
|
||||
node.type = TFunc(node.args.type, node.body.type, set())
|
||||
node.type = TFunc(node.args.type, node.body.type, [])
|
||||
|
||||
def visit_arguments(self, node):
|
||||
for arg in node.args:
|
||||
|
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import Dict, Mapping, List, Set
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
|
||||
|
||||
class UnificationError(Exception):
|
||||
@ -53,6 +52,7 @@ class TVar(Type):
|
||||
self.type = TVarType.UNDETERMINED
|
||||
self.rank = 0
|
||||
self.parent = self
|
||||
self.checked = False
|
||||
|
||||
self.fields = {}
|
||||
self.range = vrange
|
||||
@ -60,6 +60,9 @@ class TVar(Type):
|
||||
TVar.next_id += 1
|
||||
|
||||
def check(self):
|
||||
if self.checked:
|
||||
return
|
||||
self.checked = True
|
||||
if self.range is not None:
|
||||
ty = self.find()
|
||||
# maybe we should replace this with explicit eq
|
||||
@ -179,8 +182,12 @@ class TCall(Type):
|
||||
self.calls = [[posargs, kwargs, ret, None]]
|
||||
self.parent = self
|
||||
self.rank = 0
|
||||
self.checked = False
|
||||
|
||||
def check(self):
|
||||
if self.checked:
|
||||
return
|
||||
self.checked = True
|
||||
self.calls[0][3].check()
|
||||
|
||||
def find(self):
|
||||
@ -246,13 +253,17 @@ class TCall(Type):
|
||||
|
||||
|
||||
class TFunc(Type):
|
||||
def __init__(self, args: List[FuncArg], ret: Type, vars: Set[TVar]):
|
||||
def __init__(self, args: List[FuncArg], ret: Type, vars: List[TVar]):
|
||||
self.args = args
|
||||
self.ret = ret
|
||||
self.vars = vars
|
||||
self.instantiated = False
|
||||
self.checked = False
|
||||
|
||||
def check(self):
|
||||
if self.checked:
|
||||
return
|
||||
self.checked = True
|
||||
for arg in self.args:
|
||||
arg.typ.check()
|
||||
self.ret.check()
|
||||
@ -303,12 +314,16 @@ class TObj(Type):
|
||||
self.name = name
|
||||
self.fields = fields
|
||||
self.params = params
|
||||
self.checked = False
|
||||
if parents is None:
|
||||
self.parents = []
|
||||
else:
|
||||
self.parents = parents
|
||||
|
||||
def check(self):
|
||||
if self.checked:
|
||||
return
|
||||
self.checked = True
|
||||
for arg in self.fields.values():
|
||||
arg.check()
|
||||
|
||||
@ -361,6 +376,13 @@ class TObj(Type):
|
||||
class TVirtual(Type):
|
||||
def __init__(self, obj: TObj):
|
||||
self.obj = obj
|
||||
self.checked = False
|
||||
|
||||
def check(self):
|
||||
if self.checked:
|
||||
return
|
||||
self.checked = True
|
||||
self.obj.check()
|
||||
|
||||
def __eq__(self, other):
|
||||
o = other.find()
|
||||
@ -382,8 +404,12 @@ class TVirtual(Type):
|
||||
class TList(Type):
|
||||
def __init__(self, param: Type):
|
||||
self.param = param
|
||||
self.checked = False
|
||||
|
||||
def check(self):
|
||||
if self.checked:
|
||||
return
|
||||
self.checked = True
|
||||
self.param.check()
|
||||
|
||||
def unify(self, other):
|
||||
@ -408,8 +434,12 @@ class TList(Type):
|
||||
class TTuple(Type):
|
||||
def __init__(self, params: List[Type]):
|
||||
self.params = params
|
||||
self.checked = False
|
||||
|
||||
def check(self):
|
||||
if self.checked:
|
||||
return
|
||||
self.checked = True
|
||||
for p in self.params:
|
||||
p.check()
|
||||
|
||||
@ -441,5 +471,3 @@ class TTuple(Type):
|
||||
return False
|
||||
|
||||
|
||||
TBool = TObj('bool', {}, [])
|
||||
TInt = TObj('int', {}, [])
|
||||
|
42
hm-inference/primitives.py
Normal file
42
hm-inference/primitives.py
Normal file
@ -0,0 +1,42 @@
|
||||
from nac3_types import *
|
||||
|
||||
|
||||
TBool = TObj('bool', {}, [])
|
||||
TInt = TObj('int', {}, [])
|
||||
TFloat = TObj('float', {}, [])
|
||||
|
||||
TBool.fields['__eq__'] = TFunc([FuncArg('other', TBool, False)], TBool, [])
|
||||
TBool.fields['__ne__'] = TFunc([FuncArg('other', TBool, False)], TBool, [])
|
||||
|
||||
|
||||
def impl_cmp(ty):
|
||||
ty.fields['__lt__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||
ty.fields['__le__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||
ty.fields['__eq__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||
ty.fields['__ne__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||
ty.fields['__gt__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||
ty.fields['__ge__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
|
||||
|
||||
|
||||
def impl_arithmetic(ty):
|
||||
ty.fields['__add__'] = TFunc([FuncArg('other', ty, False)], ty, [])
|
||||
ty.fields['__sub__'] = TFunc([FuncArg('other', ty, False)], ty, [])
|
||||
ty.fields['__mul__'] = TFunc([FuncArg('other', ty, False)], ty, [])
|
||||
|
||||
|
||||
impl_cmp(TInt)
|
||||
impl_cmp(TFloat)
|
||||
impl_arithmetic(TInt)
|
||||
impl_arithmetic(TFloat)
|
||||
|
||||
TNum = TVar([TInt, TFloat])
|
||||
|
||||
TInt.fields['__truediv__'] = TFunc(
|
||||
[FuncArg('other', TNum, False)], TFloat, [TNum])
|
||||
TInt.fields['__floordiv__'] = TFunc(
|
||||
[FuncArg('other', TNum, False)], TInt, [TNum])
|
||||
TFloat.fields['__truediv__'] = TFunc(
|
||||
[FuncArg('other', TNum, False)], TFloat, [TNum])
|
||||
TFloat.fields['__floordiv__'] = TFunc(
|
||||
[FuncArg('other', TNum, False)], TFloat, [TNum])
|
||||
|
@ -2,11 +2,13 @@ from __future__ import annotations
|
||||
import ast
|
||||
from ast_visitor import Visitor
|
||||
from nac3_types import *
|
||||
from primitives import *
|
||||
|
||||
src = """
|
||||
a = 1
|
||||
a = a.__add__(2)
|
||||
b = test_virtual(virtual(bar, Foo))
|
||||
b = test_virtual(virtual(foo, Foo))
|
||||
b = test_virtual(virtual(foo2, Foo))
|
||||
"""
|
||||
|
||||
foo = TObj('Foo', {
|
||||
@ -32,7 +34,7 @@ prelude = {
|
||||
'foo': foo,
|
||||
'foo2': foo2,
|
||||
'bar': bar,
|
||||
'test_virtual': TFunc([FuncArg('a', TVirtual(foo), False)], TInt, set())
|
||||
'test_virtual': TFunc([FuncArg('a', TVirtual(foo), False)], TInt, [])
|
||||
}
|
||||
|
||||
print('-----------')
|
||||
@ -64,5 +66,3 @@ for key, value in v.assignments.items():
|
||||
print(f'{key}: {value.find()}')
|
||||
|
||||
|
||||
|
||||
# TODO: Occur check
|
||||
|
Loading…
Reference in New Issue
Block a user