virtual type
This commit is contained in:
parent
0c8029d7e1
commit
66df55b3d7
@ -4,10 +4,13 @@ from nac3_types import *
|
|||||||
|
|
||||||
|
|
||||||
class Visitor(ast.NodeVisitor):
|
class Visitor(ast.NodeVisitor):
|
||||||
def __init__(self):
|
def __init__(self, src, assignments, type_parser):
|
||||||
super(Visitor, self).__init__()
|
super(Visitor, self).__init__()
|
||||||
self.assignments = {}
|
self.source = src
|
||||||
|
self.assignments = assignments
|
||||||
self.calls = []
|
self.calls = []
|
||||||
|
self.virtuals = []
|
||||||
|
self.type_parser = type_parser
|
||||||
|
|
||||||
def visit_Assign(self, node):
|
def visit_Assign(self, node):
|
||||||
self.visit(node.value)
|
self.visit(node.value)
|
||||||
@ -47,6 +50,15 @@ class Visitor(ast.NodeVisitor):
|
|||||||
node.type = TInt
|
node.type = TInt
|
||||||
|
|
||||||
def visit_Call(self, node):
|
def visit_Call(self, node):
|
||||||
|
if ast.get_source_segment(self.source, node.func) == 'virtual':
|
||||||
|
if len(node.args) != 2:
|
||||||
|
raise UnificationError('Incorrect argument number for virtual')
|
||||||
|
self.visit(node.args[0])
|
||||||
|
ty = self.type_parser(ast.get_source_segment(self.source,
|
||||||
|
node.args[1]))
|
||||||
|
self.virtuals.append((node.args[0].type, ty))
|
||||||
|
node.type = TVirtual(ty)
|
||||||
|
return
|
||||||
self.visit(node.func)
|
self.visit(node.func)
|
||||||
for arg in node.args:
|
for arg in node.args:
|
||||||
self.visit(arg)
|
self.visit(arg)
|
||||||
|
@ -118,7 +118,16 @@ class TVar(Type):
|
|||||||
x.rank += 1
|
x.rank += 1
|
||||||
elif isinstance(y, TVar):
|
elif isinstance(y, TVar):
|
||||||
# check fields
|
# check fields
|
||||||
if isinstance(x, TObj):
|
if isinstance(x, TVirtual):
|
||||||
|
if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]:
|
||||||
|
raise UnificationError(f'Cannot unify {y} with {x}')
|
||||||
|
for k, v in y.fields.items():
|
||||||
|
if k not in x.obj.fields:
|
||||||
|
raise UnificationError(
|
||||||
|
f'Cannot unify {y} with {x}')
|
||||||
|
u = x.obj.fields[k]
|
||||||
|
v.unify(u)
|
||||||
|
elif isinstance(x, TObj):
|
||||||
if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]:
|
if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]:
|
||||||
raise UnificationError(f'Cannot unify {y} with {x}')
|
raise UnificationError(f'Cannot unify {y} with {x}')
|
||||||
for k, v in y.fields.items():
|
for k, v in y.fields.items():
|
||||||
@ -127,13 +136,13 @@ class TVar(Type):
|
|||||||
f'Cannot unify {y} with {x}')
|
f'Cannot unify {y} with {x}')
|
||||||
u = x.fields[k]
|
u = x.fields[k]
|
||||||
v.unify(u)
|
v.unify(u)
|
||||||
if isinstance(x, TList):
|
elif isinstance(x, TList):
|
||||||
if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.LIST]:
|
if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.LIST]:
|
||||||
raise UnificationError(f'Cannot unify {y} with {x}')
|
raise UnificationError(f'Cannot unify {y} with {x}')
|
||||||
for k, v in y.fields.items():
|
for k, v in y.fields.items():
|
||||||
assert isinstance(k, int)
|
assert isinstance(k, int)
|
||||||
v.unify(x.param)
|
v.unify(x.param)
|
||||||
if isinstance(x, TTuple):
|
elif isinstance(x, TTuple):
|
||||||
if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.TUPLE]:
|
if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.TUPLE]:
|
||||||
raise UnificationError(f'Cannot unify {y} with {x}')
|
raise UnificationError(f'Cannot unify {y} with {x}')
|
||||||
for k, v in y.fields.items():
|
for k, v in y.fields.items():
|
||||||
@ -145,6 +154,15 @@ class TVar(Type):
|
|||||||
else:
|
else:
|
||||||
y.unify(x)
|
y.unify(x)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
s = self.find()
|
||||||
|
o = other.find()
|
||||||
|
if not isinstance(s, TVar):
|
||||||
|
return s == o
|
||||||
|
if isinstance(o, TVar):
|
||||||
|
return s.id == o.id
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class FuncArg:
|
class FuncArg:
|
||||||
def __init__(self, name, typ, is_optional):
|
def __init__(self, name, typ, is_optional):
|
||||||
@ -280,10 +298,15 @@ class TFunc(Type):
|
|||||||
|
|
||||||
|
|
||||||
class TObj(Type):
|
class TObj(Type):
|
||||||
def __init__(self, name: str, fields: Dict[str, Type], params: List[Type]):
|
def __init__(self, name: str, fields: Dict[str, Type], params: List[Type],
|
||||||
|
parents=None):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.fields = fields
|
self.fields = fields
|
||||||
self.params = params
|
self.params = params
|
||||||
|
if parents is None:
|
||||||
|
self.parents = []
|
||||||
|
else:
|
||||||
|
self.parents = parents
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
for arg in self.fields.values():
|
for arg in self.fields.values():
|
||||||
@ -323,6 +346,38 @@ class TObj(Type):
|
|||||||
p = ''
|
p = ''
|
||||||
return self.name + p
|
return self.name + p
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
o = other.find()
|
||||||
|
if isinstance(o, TObj):
|
||||||
|
if self.name != o.name:
|
||||||
|
return False
|
||||||
|
for a, b in zip(self.params, o.params):
|
||||||
|
if a != b:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class TVirtual(Type):
|
||||||
|
def __init__(self, obj: TObj):
|
||||||
|
self.obj = obj
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
o = other.find()
|
||||||
|
if isinstance(o, TVirtual):
|
||||||
|
return self == o
|
||||||
|
return False
|
||||||
|
|
||||||
|
def unify(self, other):
|
||||||
|
o = other.find()
|
||||||
|
if isinstance(o, TVirtual):
|
||||||
|
self.obj.unify(o.obj)
|
||||||
|
else:
|
||||||
|
raise UnificationError(f'Cannot unify {self} with {o}')
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f'virtual[{self.obj}]'
|
||||||
|
|
||||||
|
|
||||||
class TList(Type):
|
class TList(Type):
|
||||||
def __init__(self, param: Type):
|
def __init__(self, param: Type):
|
||||||
@ -343,6 +398,12 @@ class TList(Type):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f'List[{self.param}]'
|
return f'List[{self.param}]'
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
o = other.find()
|
||||||
|
if isinstance(o, TList):
|
||||||
|
return self.param == o.param
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class TTuple(Type):
|
class TTuple(Type):
|
||||||
def __init__(self, params: List[Type]):
|
def __init__(self, params: List[Type]):
|
||||||
@ -368,6 +429,17 @@ class TTuple(Type):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f'Tuple[{", ".join(str(p) for p in self.params)}]'
|
return f'Tuple[{", ".join(str(p) for p in self.params)}]'
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
o = other.find()
|
||||||
|
if isinstance(o, TTuple):
|
||||||
|
if len(self.params) != len(o.params):
|
||||||
|
return False
|
||||||
|
for a, b in zip(self.params, o.params):
|
||||||
|
if a != b:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
TBool = TObj('bool', {}, [])
|
TBool = TObj('bool', {}, [])
|
||||||
TInt = TObj('int', {}, [])
|
TInt = TObj('int', {}, [])
|
||||||
|
@ -3,41 +3,59 @@ import ast
|
|||||||
from ast_visitor import Visitor
|
from ast_visitor import Visitor
|
||||||
from nac3_types import *
|
from nac3_types import *
|
||||||
|
|
||||||
|
src = """
|
||||||
var = TVar([TInt, TBool])
|
b = test_virtual(virtual(bar, Foo))
|
||||||
var2 = TVar([TInt, TBool])
|
b = test_virtual(virtual(foo, Foo))
|
||||||
|
b = test_virtual(virtual(foo2, Foo))
|
||||||
|
"""
|
||||||
|
|
||||||
foo = TObj('Foo', {
|
foo = TObj('Foo', {
|
||||||
'foo': TFunc([
|
'a': TInt,
|
||||||
FuncArg('a', var, False),
|
}, [])
|
||||||
FuncArg('b', var2, False)
|
|
||||||
], var2, set([var2]))
|
|
||||||
}, [var])
|
|
||||||
|
|
||||||
v = Visitor()
|
foo2 = TObj('Foo2', {
|
||||||
v.assignments['get_x'] = TFunc([FuncArg('in', var, False)], TInt, set([var]))
|
'a': TInt,
|
||||||
v.assignments['Foo'] = TFunc([FuncArg('a', var, False)], foo, set([var]))
|
}, [])
|
||||||
|
|
||||||
|
bar = TObj('Bar', {
|
||||||
|
'a': TInt,
|
||||||
|
'b': TInt
|
||||||
|
}, [], [foo])
|
||||||
|
|
||||||
|
type_mapping = {
|
||||||
|
'Foo': foo,
|
||||||
|
'Foo2': foo2,
|
||||||
|
'Bar': bar,
|
||||||
|
}
|
||||||
|
|
||||||
|
prelude = {
|
||||||
|
'foo': foo,
|
||||||
|
'foo2': foo2,
|
||||||
|
'bar': bar,
|
||||||
|
'test_virtual': TFunc([FuncArg('a', TVirtual(foo), False)], TInt, set())
|
||||||
|
}
|
||||||
|
|
||||||
prelude = set(v.assignments.keys())
|
|
||||||
print('-----------')
|
print('-----------')
|
||||||
print('prelude')
|
print('prelude')
|
||||||
for key, value in v.assignments.items():
|
for key, value in prelude.items():
|
||||||
print(f'{key}: {value}')
|
print(f'{key}: {value}')
|
||||||
print('-----------')
|
print('-----------')
|
||||||
|
|
||||||
src = """
|
v = Visitor(src, prelude.copy(), lambda x: type_mapping[x])
|
||||||
a = f.foo(1, 2)
|
|
||||||
b = f.foo(1, True)
|
|
||||||
c = g.foo(True, 1)
|
|
||||||
d = g.foo(True, True)
|
|
||||||
|
|
||||||
f = Foo(1)
|
|
||||||
g = Foo(True)
|
|
||||||
"""
|
|
||||||
|
|
||||||
print(src)
|
print(src)
|
||||||
v.visit(ast.parse(src))
|
v.visit(ast.parse(src))
|
||||||
|
|
||||||
|
for a, b in v.virtuals:
|
||||||
|
assert isinstance(a, TObj)
|
||||||
|
assert b is a or b in a.parents
|
||||||
|
|
||||||
|
print('-----------')
|
||||||
|
print('calls')
|
||||||
|
for x in v.calls:
|
||||||
|
x.check()
|
||||||
|
print(f'{x.find()}')
|
||||||
|
|
||||||
print('-----------')
|
print('-----------')
|
||||||
print('assignments')
|
print('assignments')
|
||||||
for key, value in v.assignments.items():
|
for key, value in v.assignments.items():
|
||||||
@ -45,11 +63,6 @@ for key, value in v.assignments.items():
|
|||||||
value.check()
|
value.check()
|
||||||
print(f'{key}: {value.find()}')
|
print(f'{key}: {value.find()}')
|
||||||
|
|
||||||
print('-----------')
|
|
||||||
print('calls')
|
|
||||||
for x in v.calls:
|
|
||||||
x.check()
|
|
||||||
print(f'{x.find()}')
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Occur check
|
# TODO: Occur check
|
||||||
|
Loading…
Reference in New Issue
Block a user