allows recursive type, implementing primitives

This commit is contained in:
pca006132 2021-07-10 15:11:27 +08:00
parent 66df55b3d7
commit b1020352ce
4 changed files with 80 additions and 10 deletions

View File

@ -1,7 +1,7 @@
import ast import ast
from itertools import chain from itertools import chain
from nac3_types import * from nac3_types import *
from primitives import *
class Visitor(ast.NodeVisitor): class Visitor(ast.NodeVisitor):
def __init__(self, src, assignments, type_parser): def __init__(self, src, assignments, type_parser):
@ -35,7 +35,7 @@ class Visitor(ast.NodeVisitor):
self.visit(node.args) self.visit(node.args)
self.visit(node.body) self.visit(node.body)
self.assignments = old self.assignments = old
node.type = TFunc(node.args.type, node.body.type, set()) node.type = TFunc(node.args.type, node.body.type, [])
def visit_arguments(self, node): def visit_arguments(self, node):
for arg in node.args: for arg in node.args:

View File

@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Dict, Mapping, List, Set from typing import Dict, Mapping, List, Set
from enum import Enum from enum import Enum
from itertools import chain
class UnificationError(Exception): class UnificationError(Exception):
@ -53,6 +52,7 @@ class TVar(Type):
self.type = TVarType.UNDETERMINED self.type = TVarType.UNDETERMINED
self.rank = 0 self.rank = 0
self.parent = self self.parent = self
self.checked = False
self.fields = {} self.fields = {}
self.range = vrange self.range = vrange
@ -60,6 +60,9 @@ class TVar(Type):
TVar.next_id += 1 TVar.next_id += 1
def check(self): def check(self):
if self.checked:
return
self.checked = True
if self.range is not None: if self.range is not None:
ty = self.find() ty = self.find()
# maybe we should replace this with explicit eq # maybe we should replace this with explicit eq
@ -179,8 +182,12 @@ class TCall(Type):
self.calls = [[posargs, kwargs, ret, None]] self.calls = [[posargs, kwargs, ret, None]]
self.parent = self self.parent = self
self.rank = 0 self.rank = 0
self.checked = False
def check(self): def check(self):
if self.checked:
return
self.checked = True
self.calls[0][3].check() self.calls[0][3].check()
def find(self): def find(self):
@ -246,13 +253,17 @@ class TCall(Type):
class TFunc(Type): class TFunc(Type):
def __init__(self, args: List[FuncArg], ret: Type, vars: Set[TVar]): def __init__(self, args: List[FuncArg], ret: Type, vars: List[TVar]):
self.args = args self.args = args
self.ret = ret self.ret = ret
self.vars = vars self.vars = vars
self.instantiated = False self.instantiated = False
self.checked = False
def check(self): def check(self):
if self.checked:
return
self.checked = True
for arg in self.args: for arg in self.args:
arg.typ.check() arg.typ.check()
self.ret.check() self.ret.check()
@ -303,12 +314,16 @@ class TObj(Type):
self.name = name self.name = name
self.fields = fields self.fields = fields
self.params = params self.params = params
self.checked = False
if parents is None: if parents is None:
self.parents = [] self.parents = []
else: else:
self.parents = parents self.parents = parents
def check(self): def check(self):
if self.checked:
return
self.checked = True
for arg in self.fields.values(): for arg in self.fields.values():
arg.check() arg.check()
@ -361,6 +376,13 @@ class TObj(Type):
class TVirtual(Type): class TVirtual(Type):
def __init__(self, obj: TObj): def __init__(self, obj: TObj):
self.obj = obj self.obj = obj
self.checked = False
def check(self):
if self.checked:
return
self.checked = True
self.obj.check()
def __eq__(self, other): def __eq__(self, other):
o = other.find() o = other.find()
@ -382,8 +404,12 @@ class TVirtual(Type):
class TList(Type): class TList(Type):
def __init__(self, param: Type): def __init__(self, param: Type):
self.param = param self.param = param
self.checked = False
def check(self): def check(self):
if self.checked:
return
self.checked = True
self.param.check() self.param.check()
def unify(self, other): def unify(self, other):
@ -408,8 +434,12 @@ class TList(Type):
class TTuple(Type): class TTuple(Type):
def __init__(self, params: List[Type]): def __init__(self, params: List[Type]):
self.params = params self.params = params
self.checked = False
def check(self): def check(self):
if self.checked:
return
self.checked = True
for p in self.params: for p in self.params:
p.check() p.check()
@ -441,5 +471,3 @@ class TTuple(Type):
return False return False
TBool = TObj('bool', {}, [])
TInt = TObj('int', {}, [])

View File

@ -0,0 +1,42 @@
from nac3_types import *
TBool = TObj('bool', {}, [])
TInt = TObj('int', {}, [])
TFloat = TObj('float', {}, [])
TBool.fields['__eq__'] = TFunc([FuncArg('other', TBool, False)], TBool, [])
TBool.fields['__ne__'] = TFunc([FuncArg('other', TBool, False)], TBool, [])
def impl_cmp(ty):
ty.fields['__lt__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
ty.fields['__le__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
ty.fields['__eq__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
ty.fields['__ne__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
ty.fields['__gt__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
ty.fields['__ge__'] = TFunc([FuncArg('other', ty, False)], TBool, [])
def impl_arithmetic(ty):
ty.fields['__add__'] = TFunc([FuncArg('other', ty, False)], ty, [])
ty.fields['__sub__'] = TFunc([FuncArg('other', ty, False)], ty, [])
ty.fields['__mul__'] = TFunc([FuncArg('other', ty, False)], ty, [])
impl_cmp(TInt)
impl_cmp(TFloat)
impl_arithmetic(TInt)
impl_arithmetic(TFloat)
TNum = TVar([TInt, TFloat])
TInt.fields['__truediv__'] = TFunc(
[FuncArg('other', TNum, False)], TFloat, [TNum])
TInt.fields['__floordiv__'] = TFunc(
[FuncArg('other', TNum, False)], TInt, [TNum])
TFloat.fields['__truediv__'] = TFunc(
[FuncArg('other', TNum, False)], TFloat, [TNum])
TFloat.fields['__floordiv__'] = TFunc(
[FuncArg('other', TNum, False)], TFloat, [TNum])

View File

@ -2,11 +2,13 @@ from __future__ import annotations
import ast import ast
from ast_visitor import Visitor from ast_visitor import Visitor
from nac3_types import * from nac3_types import *
from primitives import *
src = """ src = """
a = 1
a = a.__add__(2)
b = test_virtual(virtual(bar, Foo)) b = test_virtual(virtual(bar, Foo))
b = test_virtual(virtual(foo, Foo)) b = test_virtual(virtual(foo, Foo))
b = test_virtual(virtual(foo2, Foo))
""" """
foo = TObj('Foo', { foo = TObj('Foo', {
@ -32,7 +34,7 @@ prelude = {
'foo': foo, 'foo': foo,
'foo2': foo2, 'foo2': foo2,
'bar': bar, 'bar': bar,
'test_virtual': TFunc([FuncArg('a', TVirtual(foo), False)], TInt, set()) 'test_virtual': TFunc([FuncArg('a', TVirtual(foo), False)], TInt, [])
} }
print('-----------') print('-----------')
@ -64,5 +66,3 @@ for key, value in v.assignments.items():
print(f'{key}: {value.find()}') print(f'{key}: {value.find()}')
# TODO: Occur check