virtual type
This commit is contained in:
parent
0c8029d7e1
commit
66df55b3d7
@ -4,10 +4,13 @@ from nac3_types import *
|
||||
|
||||
|
||||
class Visitor(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
def __init__(self, src, assignments, type_parser):
|
||||
super(Visitor, self).__init__()
|
||||
self.assignments = {}
|
||||
self.source = src
|
||||
self.assignments = assignments
|
||||
self.calls = []
|
||||
self.virtuals = []
|
||||
self.type_parser = type_parser
|
||||
|
||||
def visit_Assign(self, node):
|
||||
self.visit(node.value)
|
||||
@ -47,6 +50,15 @@ class Visitor(ast.NodeVisitor):
|
||||
node.type = TInt
|
||||
|
||||
def visit_Call(self, node):
|
||||
if ast.get_source_segment(self.source, node.func) == 'virtual':
|
||||
if len(node.args) != 2:
|
||||
raise UnificationError('Incorrect argument number for virtual')
|
||||
self.visit(node.args[0])
|
||||
ty = self.type_parser(ast.get_source_segment(self.source,
|
||||
node.args[1]))
|
||||
self.virtuals.append((node.args[0].type, ty))
|
||||
node.type = TVirtual(ty)
|
||||
return
|
||||
self.visit(node.func)
|
||||
for arg in node.args:
|
||||
self.visit(arg)
|
||||
|
@ -118,7 +118,16 @@ class TVar(Type):
|
||||
x.rank += 1
|
||||
elif isinstance(y, TVar):
|
||||
# check fields
|
||||
if isinstance(x, TObj):
|
||||
if isinstance(x, TVirtual):
|
||||
if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]:
|
||||
raise UnificationError(f'Cannot unify {y} with {x}')
|
||||
for k, v in y.fields.items():
|
||||
if k not in x.obj.fields:
|
||||
raise UnificationError(
|
||||
f'Cannot unify {y} with {x}')
|
||||
u = x.obj.fields[k]
|
||||
v.unify(u)
|
||||
elif isinstance(x, TObj):
|
||||
if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]:
|
||||
raise UnificationError(f'Cannot unify {y} with {x}')
|
||||
for k, v in y.fields.items():
|
||||
@ -127,13 +136,13 @@ class TVar(Type):
|
||||
f'Cannot unify {y} with {x}')
|
||||
u = x.fields[k]
|
||||
v.unify(u)
|
||||
if isinstance(x, TList):
|
||||
elif isinstance(x, TList):
|
||||
if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.LIST]:
|
||||
raise UnificationError(f'Cannot unify {y} with {x}')
|
||||
for k, v in y.fields.items():
|
||||
assert isinstance(k, int)
|
||||
v.unify(x.param)
|
||||
if isinstance(x, TTuple):
|
||||
elif isinstance(x, TTuple):
|
||||
if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.TUPLE]:
|
||||
raise UnificationError(f'Cannot unify {y} with {x}')
|
||||
for k, v in y.fields.items():
|
||||
@ -145,6 +154,15 @@ class TVar(Type):
|
||||
else:
|
||||
y.unify(x)
|
||||
|
||||
def __eq__(self, other):
|
||||
s = self.find()
|
||||
o = other.find()
|
||||
if not isinstance(s, TVar):
|
||||
return s == o
|
||||
if isinstance(o, TVar):
|
||||
return s.id == o.id
|
||||
return False
|
||||
|
||||
|
||||
class FuncArg:
|
||||
def __init__(self, name, typ, is_optional):
|
||||
@ -280,10 +298,15 @@ class TFunc(Type):
|
||||
|
||||
|
||||
class TObj(Type):
|
||||
def __init__(self, name: str, fields: Dict[str, Type], params: List[Type]):
|
||||
def __init__(self, name: str, fields: Dict[str, Type], params: List[Type],
|
||||
parents=None):
|
||||
self.name = name
|
||||
self.fields = fields
|
||||
self.params = params
|
||||
if parents is None:
|
||||
self.parents = []
|
||||
else:
|
||||
self.parents = parents
|
||||
|
||||
def check(self):
|
||||
for arg in self.fields.values():
|
||||
@ -323,6 +346,38 @@ class TObj(Type):
|
||||
p = ''
|
||||
return self.name + p
|
||||
|
||||
def __eq__(self, other):
|
||||
o = other.find()
|
||||
if isinstance(o, TObj):
|
||||
if self.name != o.name:
|
||||
return False
|
||||
for a, b in zip(self.params, o.params):
|
||||
if a != b:
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class TVirtual(Type):
|
||||
def __init__(self, obj: TObj):
|
||||
self.obj = obj
|
||||
|
||||
def __eq__(self, other):
|
||||
o = other.find()
|
||||
if isinstance(o, TVirtual):
|
||||
return self == o
|
||||
return False
|
||||
|
||||
def unify(self, other):
|
||||
o = other.find()
|
||||
if isinstance(o, TVirtual):
|
||||
self.obj.unify(o.obj)
|
||||
else:
|
||||
raise UnificationError(f'Cannot unify {self} with {o}')
|
||||
|
||||
def __str__(self):
|
||||
return f'virtual[{self.obj}]'
|
||||
|
||||
|
||||
class TList(Type):
|
||||
def __init__(self, param: Type):
|
||||
@ -343,6 +398,12 @@ class TList(Type):
|
||||
def __str__(self):
|
||||
return f'List[{self.param}]'
|
||||
|
||||
def __eq__(self, other):
|
||||
o = other.find()
|
||||
if isinstance(o, TList):
|
||||
return self.param == o.param
|
||||
return False
|
||||
|
||||
|
||||
class TTuple(Type):
|
||||
def __init__(self, params: List[Type]):
|
||||
@ -368,6 +429,17 @@ class TTuple(Type):
|
||||
def __str__(self):
|
||||
return f'Tuple[{", ".join(str(p) for p in self.params)}]'
|
||||
|
||||
def __eq__(self, other):
|
||||
o = other.find()
|
||||
if isinstance(o, TTuple):
|
||||
if len(self.params) != len(o.params):
|
||||
return False
|
||||
for a, b in zip(self.params, o.params):
|
||||
if a != b:
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
TBool = TObj('bool', {}, [])
|
||||
TInt = TObj('int', {}, [])
|
||||
|
@ -3,41 +3,59 @@ import ast
|
||||
from ast_visitor import Visitor
|
||||
from nac3_types import *
|
||||
|
||||
|
||||
var = TVar([TInt, TBool])
|
||||
var2 = TVar([TInt, TBool])
|
||||
src = """
|
||||
b = test_virtual(virtual(bar, Foo))
|
||||
b = test_virtual(virtual(foo, Foo))
|
||||
b = test_virtual(virtual(foo2, Foo))
|
||||
"""
|
||||
|
||||
foo = TObj('Foo', {
|
||||
'foo': TFunc([
|
||||
FuncArg('a', var, False),
|
||||
FuncArg('b', var2, False)
|
||||
], var2, set([var2]))
|
||||
}, [var])
|
||||
'a': TInt,
|
||||
}, [])
|
||||
|
||||
v = Visitor()
|
||||
v.assignments['get_x'] = TFunc([FuncArg('in', var, False)], TInt, set([var]))
|
||||
v.assignments['Foo'] = TFunc([FuncArg('a', var, False)], foo, set([var]))
|
||||
foo2 = TObj('Foo2', {
|
||||
'a': TInt,
|
||||
}, [])
|
||||
|
||||
bar = TObj('Bar', {
|
||||
'a': TInt,
|
||||
'b': TInt
|
||||
}, [], [foo])
|
||||
|
||||
type_mapping = {
|
||||
'Foo': foo,
|
||||
'Foo2': foo2,
|
||||
'Bar': bar,
|
||||
}
|
||||
|
||||
prelude = {
|
||||
'foo': foo,
|
||||
'foo2': foo2,
|
||||
'bar': bar,
|
||||
'test_virtual': TFunc([FuncArg('a', TVirtual(foo), False)], TInt, set())
|
||||
}
|
||||
|
||||
prelude = set(v.assignments.keys())
|
||||
print('-----------')
|
||||
print('prelude')
|
||||
for key, value in v.assignments.items():
|
||||
for key, value in prelude.items():
|
||||
print(f'{key}: {value}')
|
||||
print('-----------')
|
||||
|
||||
src = """
|
||||
a = f.foo(1, 2)
|
||||
b = f.foo(1, True)
|
||||
c = g.foo(True, 1)
|
||||
d = g.foo(True, True)
|
||||
|
||||
f = Foo(1)
|
||||
g = Foo(True)
|
||||
"""
|
||||
v = Visitor(src, prelude.copy(), lambda x: type_mapping[x])
|
||||
|
||||
print(src)
|
||||
v.visit(ast.parse(src))
|
||||
|
||||
for a, b in v.virtuals:
|
||||
assert isinstance(a, TObj)
|
||||
assert b is a or b in a.parents
|
||||
|
||||
print('-----------')
|
||||
print('calls')
|
||||
for x in v.calls:
|
||||
x.check()
|
||||
print(f'{x.find()}')
|
||||
|
||||
print('-----------')
|
||||
print('assignments')
|
||||
for key, value in v.assignments.items():
|
||||
@ -45,11 +63,6 @@ for key, value in v.assignments.items():
|
||||
value.check()
|
||||
print(f'{key}: {value.find()}')
|
||||
|
||||
print('-----------')
|
||||
print('calls')
|
||||
for x in v.calls:
|
||||
x.check()
|
||||
print(f'{x.find()}')
|
||||
|
||||
|
||||
# TODO: Occur check
|
||||
|
Loading…
Reference in New Issue
Block a user