diff --git a/hm-inference/ast_visitor.py b/hm-inference/ast_visitor.py index e8ccebe..d1f81ff 100644 --- a/hm-inference/ast_visitor.py +++ b/hm-inference/ast_visitor.py @@ -80,11 +80,14 @@ class Visitor(ast.NodeVisitor): def visit_Call(self, node): if ast.get_source_segment(self.source, node.func) == 'virtual': - if len(node.args) != 2: + if len(node.args) > 2 or len(node.args) < 1: raise UnificationError('Incorrect argument number for virtual') self.visit(node.args[0]) - ty = self.type_parser(ast.get_source_segment(self.source, - node.args[1])) + if len(node.args) == 2: + ty = self.type_parser(ast.get_source_segment(self.source, + node.args[1])) + else: + ty = TVar() self.virtuals.append((node.args[0].type, ty)) node.type = TVirtual(ty) return diff --git a/hm-inference/nac3_types.py b/hm-inference/nac3_types.py index b74800e..2401c2b 100644 --- a/hm-inference/nac3_types.py +++ b/hm-inference/nac3_types.py @@ -377,7 +377,7 @@ class TObj(Type): class TVirtual(Type): - def __init__(self, obj: TObj): + def __init__(self, obj: Type): self.obj = obj self.checked = False diff --git a/hm-inference/test.py b/hm-inference/test.py index 283a9d3..8595807 100644 --- a/hm-inference/test.py +++ b/hm-inference/test.py @@ -5,13 +5,13 @@ from nac3_types import * from primitives import * src = """ -a = 1 // 1 +a = virtual(bar) +b = test_virtual(a) -if 1 <= 1 < 2: - pass +a = virtual(foo) +b = test_virtual(a) -b = test_virtual(virtual(bar, Foo)) -b = test_virtual(virtual(foo, Foo)) +c = virtual(foo, Foo) """ foo = TObj('Foo', { @@ -53,7 +53,7 @@ v.visit(ast.parse(src)) for a, b in v.virtuals: assert isinstance(a, TObj) - assert b is a or b in a.parents + assert b.find() is a or b.find() in a.parents print('-----------') print('calls')