optional virtual type annotation

This commit is contained in:
pca006132 2021-07-12 09:35:01 +08:00
parent dbf9c17d9f
commit 1ad21f0d67
3 changed files with 13 additions and 10 deletions

View File

@ -80,11 +80,14 @@ class Visitor(ast.NodeVisitor):
def visit_Call(self, node): def visit_Call(self, node):
if ast.get_source_segment(self.source, node.func) == 'virtual': if ast.get_source_segment(self.source, node.func) == 'virtual':
if len(node.args) != 2: if len(node.args) > 2 or len(node.args) < 1:
raise UnificationError('Incorrect argument number for virtual') raise UnificationError('Incorrect argument number for virtual')
self.visit(node.args[0]) self.visit(node.args[0])
if len(node.args) == 2:
ty = self.type_parser(ast.get_source_segment(self.source, ty = self.type_parser(ast.get_source_segment(self.source,
node.args[1])) node.args[1]))
else:
ty = TVar()
self.virtuals.append((node.args[0].type, ty)) self.virtuals.append((node.args[0].type, ty))
node.type = TVirtual(ty) node.type = TVirtual(ty)
return return

View File

@ -377,7 +377,7 @@ class TObj(Type):
class TVirtual(Type): class TVirtual(Type):
def __init__(self, obj: TObj): def __init__(self, obj: Type):
self.obj = obj self.obj = obj
self.checked = False self.checked = False

View File

@ -5,13 +5,13 @@ from nac3_types import *
from primitives import * from primitives import *
src = """ src = """
a = 1 // 1 a = virtual(bar)
b = test_virtual(a)
if 1 <= 1 < 2: a = virtual(foo)
pass b = test_virtual(a)
b = test_virtual(virtual(bar, Foo)) c = virtual(foo, Foo)
b = test_virtual(virtual(foo, Foo))
""" """
foo = TObj('Foo', { foo = TObj('Foo', {
@ -53,7 +53,7 @@ v.visit(ast.parse(src))
for a, b in v.virtuals: for a, b in v.virtuals:
assert isinstance(a, TObj) assert isinstance(a, TObj)
assert b is a or b in a.parents assert b.find() is a or b.find() in a.parents
print('-----------') print('-----------')
print('calls') print('calls')