allows recursive type, implementing primitives

This commit is contained in:
pca006132 2021-07-10 15:11:27 +08:00
parent 66df55b3d7
commit b1020352ce
4 changed files with 80 additions and 10 deletions

View File

@ -1,7 +1,7 @@
import ast
from itertools import chain
from nac3_types import *
from primitives import *
class Visitor(ast.NodeVisitor):
def __init__(self, src, assignments, type_parser):
@ -35,7 +35,7 @@ class Visitor(ast.NodeVisitor):
self.visit(node.args)
self.visit(node.body)
self.assignments = old
node.type = TFunc(node.args.type, node.body.type, set())
node.type = TFunc(node.args.type, node.body.type, [])
def visit_arguments(self, node):
for arg in node.args:

View File

@ -1,7 +1,6 @@
from __future__ import annotations
from typing import Dict, Mapping, List, Set
from enum import Enum
from itertools import chain
class UnificationError(Exception):
@ -53,6 +52,7 @@ class TVar(Type):
self.type = TVarType.UNDETERMINED
self.rank = 0
self.parent = self
self.checked = False
self.fields = {}
self.range = vrange
@ -60,6 +60,9 @@ class TVar(Type):
TVar.next_id += 1
def check(self):
if self.checked:
return
self.checked = True
if self.range is not None:
ty = self.find()
# maybe we should replace this with explicit eq
@ -179,8 +182,12 @@ class TCall(Type):
self.calls = [[posargs, kwargs, ret, None]]
self.parent = self
self.rank = 0
self.checked = False
def check(self):
if self.checked:
return
self.checked = True
self.calls[0][3].check()
def find(self):
@ -246,13 +253,17 @@ class TCall(Type):
class TFunc(Type):
def __init__(self, args: List[FuncArg], ret: Type, vars: Set[TVar]):
def __init__(self, args: List[FuncArg], ret: Type, vars: List[TVar]):
self.args = args
self.ret = ret
self.vars = vars
self.instantiated = False
self.checked = False
def check(self):
if self.checked:
return
self.checked = True
for arg in self.args:
arg.typ.check()
self.ret.check()
@ -303,12 +314,16 @@ class TObj(Type):
self.name = name
self.fields = fields
self.params = params
self.checked = False
if parents is None:
self.parents = []
else:
self.parents = parents
def check(self):
if self.checked:
return
self.checked = True
for arg in self.fields.values():
arg.check()
@ -361,6 +376,13 @@ class TObj(Type):
class TVirtual(Type):
def __init__(self, obj: TObj):
self.obj = obj
self.checked = False
def check(self):
if self.checked:
return
self.checked = True
self.obj.check()
def __eq__(self, other):
o = other.find()
@ -382,8 +404,12 @@ class TVirtual(Type):
class TList(Type):
def __init__(self, param: Type):
self.param = param
self.checked = False
def check(self):
if self.checked:
return
self.checked = True
self.param.check()
def unify(self, other):
@ -408,8 +434,12 @@ class TList(Type):
class TTuple(Type):
def __init__(self, params: List[Type]):
self.params = params
self.checked = False
def check(self):
if self.checked:
return
self.checked = True
for p in self.params:
p.check()
@ -441,5 +471,3 @@ class TTuple(Type):
return False
TBool = TObj('bool', {}, [])
TInt = TObj('int', {}, [])

View File

@ -0,0 +1,42 @@
from nac3_types import *
TBool = TObj('bool', {}, [])
TInt = TObj('int', {}, [])
TFloat = TObj('float', {}, [])
TBool.fields['__eq__'] = TFunc([FuncArg('other', TBool, False)], TBool, [])
TBool.fields['__ne__'] = TFunc([FuncArg('other', TBool, False)], TBool, [])
def impl_cmp(ty):
ty.fields['__lt__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
ty.fields['__le__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
ty.fields['__eq__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
ty.fields['__ne__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
ty.fields['__gt__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
ty.fields['__ge__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
def impl_arithmetic(ty):
ty.fields['__add__'] = TFunc([FuncArg('other', ty, False)], ty, [])
ty.fields['__sub__'] = TFunc([FuncArg('other', ty, False)], ty, [])
ty.fields['__mul__'] = TFunc([FuncArg('other', ty, False)], ty, [])
impl_cmp(TInt)
impl_cmp(TFloat)
impl_arithmetic(TInt)
impl_arithmetic(TFloat)
TNum = TVar([TInt, TFloat])
TInt.fields['__truediv__'] = TFunc(
[FuncArg('other', TNum, False)], TFloat, [TNum])
TInt.fields['__floordiv__'] = TFunc(
[FuncArg('other', TNum, False)], TInt, [TNum])
TFloat.fields['__truediv__'] = TFunc(
[FuncArg('other', TNum, False)], TFloat, [TNum])
TFloat.fields['__floordiv__'] = TFunc(
[FuncArg('other', TNum, False)], TFloat, [TNum])

View File

@ -2,11 +2,13 @@ from __future__ import annotations
import ast
from ast_visitor import Visitor
from nac3_types import *
from primitives import *
src = """
a = 1
a = a.__add__(2)
b = test_virtual(virtual(bar, Foo))
b = test_virtual(virtual(foo, Foo))
b = test_virtual(virtual(foo2, Foo))
"""
foo = TObj('Foo', {
@ -32,7 +34,7 @@ prelude = {
'foo': foo,
'foo2': foo2,
'bar': bar,
'test_virtual': TFunc([FuncArg('a', TVirtual(foo), False)], TInt, set())
'test_virtual': TFunc([FuncArg('a', TVirtual(foo), False)], TInt, [])
}
print('-----------')
@ -64,5 +66,3 @@ for key, value in v.assignments.items():
print(f'{key}: {value.find()}')
# TODO: Occur check