virtual type

This commit is contained in:
pca006132 2021-07-10 14:36:28 +08:00
parent 0c8029d7e1
commit 66df55b3d7
3 changed files with 130 additions and 33 deletions

View File

@ -4,10 +4,13 @@ from nac3_types import *
class Visitor(ast.NodeVisitor): class Visitor(ast.NodeVisitor):
def __init__(self): def __init__(self, src, assignments, type_parser):
super(Visitor, self).__init__() super(Visitor, self).__init__()
self.assignments = {} self.source = src
self.assignments = assignments
self.calls = [] self.calls = []
self.virtuals = []
self.type_parser = type_parser
def visit_Assign(self, node): def visit_Assign(self, node):
self.visit(node.value) self.visit(node.value)
@ -47,6 +50,15 @@ class Visitor(ast.NodeVisitor):
node.type = TInt node.type = TInt
def visit_Call(self, node): def visit_Call(self, node):
if ast.get_source_segment(self.source, node.func) == 'virtual':
if len(node.args) != 2:
raise UnificationError('Incorrect argument number for virtual')
self.visit(node.args[0])
ty = self.type_parser(ast.get_source_segment(self.source,
node.args[1]))
self.virtuals.append((node.args[0].type, ty))
node.type = TVirtual(ty)
return
self.visit(node.func) self.visit(node.func)
for arg in node.args: for arg in node.args:
self.visit(arg) self.visit(arg)

View File

@ -118,7 +118,16 @@ class TVar(Type):
x.rank += 1 x.rank += 1
elif isinstance(y, TVar): elif isinstance(y, TVar):
# check fields # check fields
if isinstance(x, TObj): if isinstance(x, TVirtual):
if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]:
raise UnificationError(f'Cannot unify {y} with {x}')
for k, v in y.fields.items():
if k not in x.obj.fields:
raise UnificationError(
f'Cannot unify {y} with {x}')
u = x.obj.fields[k]
v.unify(u)
elif isinstance(x, TObj):
if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]: if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]:
raise UnificationError(f'Cannot unify {y} with {x}') raise UnificationError(f'Cannot unify {y} with {x}')
for k, v in y.fields.items(): for k, v in y.fields.items():
@ -127,13 +136,13 @@ class TVar(Type):
f'Cannot unify {y} with {x}') f'Cannot unify {y} with {x}')
u = x.fields[k] u = x.fields[k]
v.unify(u) v.unify(u)
if isinstance(x, TList): elif isinstance(x, TList):
if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.LIST]: if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.LIST]:
raise UnificationError(f'Cannot unify {y} with {x}') raise UnificationError(f'Cannot unify {y} with {x}')
for k, v in y.fields.items(): for k, v in y.fields.items():
assert isinstance(k, int) assert isinstance(k, int)
v.unify(x.param) v.unify(x.param)
if isinstance(x, TTuple): elif isinstance(x, TTuple):
if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.TUPLE]: if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.TUPLE]:
raise UnificationError(f'Cannot unify {y} with {x}') raise UnificationError(f'Cannot unify {y} with {x}')
for k, v in y.fields.items(): for k, v in y.fields.items():
@ -145,6 +154,15 @@ class TVar(Type):
else: else:
y.unify(x) y.unify(x)
def __eq__(self, other):
s = self.find()
o = other.find()
if not isinstance(s, TVar):
return s == o
if isinstance(o, TVar):
return s.id == o.id
return False
class FuncArg: class FuncArg:
def __init__(self, name, typ, is_optional): def __init__(self, name, typ, is_optional):
@ -280,10 +298,15 @@ class TFunc(Type):
class TObj(Type): class TObj(Type):
def __init__(self, name: str, fields: Dict[str, Type], params: List[Type]): def __init__(self, name: str, fields: Dict[str, Type], params: List[Type],
parents=None):
self.name = name self.name = name
self.fields = fields self.fields = fields
self.params = params self.params = params
if parents is None:
self.parents = []
else:
self.parents = parents
def check(self): def check(self):
for arg in self.fields.values(): for arg in self.fields.values():
@ -323,6 +346,38 @@ class TObj(Type):
p = '' p = ''
return self.name + p return self.name + p
def __eq__(self, other):
o = other.find()
if isinstance(o, TObj):
if self.name != o.name:
return False
for a, b in zip(self.params, o.params):
if a != b:
return False
return True
return False
class TVirtual(Type):
def __init__(self, obj: TObj):
self.obj = obj
def __eq__(self, other):
o = other.find()
if isinstance(o, TVirtual):
return self == o
return False
def unify(self, other):
o = other.find()
if isinstance(o, TVirtual):
self.obj.unify(o.obj)
else:
raise UnificationError(f'Cannot unify {self} with {o}')
def __str__(self):
return f'virtual[{self.obj}]'
class TList(Type): class TList(Type):
def __init__(self, param: Type): def __init__(self, param: Type):
@ -343,6 +398,12 @@ class TList(Type):
def __str__(self): def __str__(self):
return f'List[{self.param}]' return f'List[{self.param}]'
def __eq__(self, other):
o = other.find()
if isinstance(o, TList):
return self.param == o.param
return False
class TTuple(Type): class TTuple(Type):
def __init__(self, params: List[Type]): def __init__(self, params: List[Type]):
@ -368,6 +429,17 @@ class TTuple(Type):
def __str__(self): def __str__(self):
return f'Tuple[{", ".join(str(p) for p in self.params)}]' return f'Tuple[{", ".join(str(p) for p in self.params)}]'
def __eq__(self, other):
o = other.find()
if isinstance(o, TTuple):
if len(self.params) != len(o.params):
return False
for a, b in zip(self.params, o.params):
if a != b:
return False
return True
return False
TBool = TObj('bool', {}, []) TBool = TObj('bool', {}, [])
TInt = TObj('int', {}, []) TInt = TObj('int', {}, [])

View File

@ -3,41 +3,59 @@ import ast
from ast_visitor import Visitor from ast_visitor import Visitor
from nac3_types import * from nac3_types import *
src = """
var = TVar([TInt, TBool]) b = test_virtual(virtual(bar, Foo))
var2 = TVar([TInt, TBool]) b = test_virtual(virtual(foo, Foo))
b = test_virtual(virtual(foo2, Foo))
"""
foo = TObj('Foo', { foo = TObj('Foo', {
'foo': TFunc([ 'a': TInt,
FuncArg('a', var, False), }, [])
FuncArg('b', var2, False)
], var2, set([var2]))
}, [var])
v = Visitor() foo2 = TObj('Foo2', {
v.assignments['get_x'] = TFunc([FuncArg('in', var, False)], TInt, set([var])) 'a': TInt,
v.assignments['Foo'] = TFunc([FuncArg('a', var, False)], foo, set([var])) }, [])
bar = TObj('Bar', {
'a': TInt,
'b': TInt
}, [], [foo])
type_mapping = {
'Foo': foo,
'Foo2': foo2,
'Bar': bar,
}
prelude = {
'foo': foo,
'foo2': foo2,
'bar': bar,
'test_virtual': TFunc([FuncArg('a', TVirtual(foo), False)], TInt, set())
}
prelude = set(v.assignments.keys())
print('-----------') print('-----------')
print('prelude') print('prelude')
for key, value in v.assignments.items(): for key, value in prelude.items():
print(f'{key}: {value}') print(f'{key}: {value}')
print('-----------') print('-----------')
src = """ v = Visitor(src, prelude.copy(), lambda x: type_mapping[x])
a = f.foo(1, 2)
b = f.foo(1, True)
c = g.foo(True, 1)
d = g.foo(True, True)
f = Foo(1)
g = Foo(True)
"""
print(src) print(src)
v.visit(ast.parse(src)) v.visit(ast.parse(src))
for a, b in v.virtuals:
assert isinstance(a, TObj)
assert b is a or b in a.parents
print('-----------')
print('calls')
for x in v.calls:
x.check()
print(f'{x.find()}')
print('-----------') print('-----------')
print('assignments') print('assignments')
for key, value in v.assignments.items(): for key, value in v.assignments.items():
@ -45,11 +63,6 @@ for key, value in v.assignments.items():
value.check() value.check()
print(f'{key}: {value.find()}') print(f'{key}: {value.find()}')
print('-----------')
print('calls')
for x in v.calls:
x.check()
print(f'{x.find()}')
# TODO: Occur check # TODO: Occur check