diff --git a/hm-inference/ast_visitor.py b/hm-inference/ast_visitor.py index bd453b1..634caff 100644 --- a/hm-inference/ast_visitor.py +++ b/hm-inference/ast_visitor.py @@ -4,10 +4,13 @@ from nac3_types import * class Visitor(ast.NodeVisitor): - def __init__(self): + def __init__(self, src, assignments, type_parser): super(Visitor, self).__init__() - self.assignments = {} + self.source = src + self.assignments = assignments self.calls = [] + self.virtuals = [] + self.type_parser = type_parser def visit_Assign(self, node): self.visit(node.value) @@ -47,6 +50,15 @@ class Visitor(ast.NodeVisitor): node.type = TInt def visit_Call(self, node): + if ast.get_source_segment(self.source, node.func) == 'virtual': + if len(node.args) != 2: + raise UnificationError('Incorrect argument number for virtual') + self.visit(node.args[0]) + ty = self.type_parser(ast.get_source_segment(self.source, + node.args[1])) + self.virtuals.append((node.args[0].type, ty)) + node.type = TVirtual(ty) + return self.visit(node.func) for arg in node.args: self.visit(arg) diff --git a/hm-inference/nac3_types.py b/hm-inference/nac3_types.py index dcad880..e7d3e18 100644 --- a/hm-inference/nac3_types.py +++ b/hm-inference/nac3_types.py @@ -118,7 +118,16 @@ class TVar(Type): x.rank += 1 elif isinstance(y, TVar): # check fields - if isinstance(x, TObj): + if isinstance(x, TVirtual): + if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]: + raise UnificationError(f'Cannot unify {y} with {x}') + for k, v in y.fields.items(): + if k not in x.obj.fields: + raise UnificationError( + f'Cannot unify {y} with {x}') + u = x.obj.fields[k] + v.unify(u) + elif isinstance(x, TObj): if y.type not in [TVarType.UNDETERMINED, TVarType.RECORD]: raise UnificationError(f'Cannot unify {y} with {x}') for k, v in y.fields.items(): @@ -127,13 +136,13 @@ class TVar(Type): f'Cannot unify {y} with {x}') u = x.fields[k] v.unify(u) - if isinstance(x, TList): + elif isinstance(x, TList): if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.LIST]: raise UnificationError(f'Cannot unify {y} with {x}') for k, v in y.fields.items(): assert isinstance(k, int) v.unify(x.param) - if isinstance(x, TTuple): + elif isinstance(x, TTuple): if y.type not in [TVarType.UNDETERMINED, TVarType.SEQUENCE, TVarType.TUPLE]: raise UnificationError(f'Cannot unify {y} with {x}') for k, v in y.fields.items(): @@ -145,6 +154,15 @@ class TVar(Type): else: y.unify(x) + def __eq__(self, other): + s = self.find() + o = other.find() + if not isinstance(s, TVar): + return s == o + if isinstance(o, TVar): + return s.id == o.id + return False + class FuncArg: def __init__(self, name, typ, is_optional): @@ -280,10 +298,15 @@ class TFunc(Type): class TObj(Type): - def __init__(self, name: str, fields: Dict[str, Type], params: List[Type]): + def __init__(self, name: str, fields: Dict[str, Type], params: List[Type], + parents=None): self.name = name self.fields = fields self.params = params + if parents is None: + self.parents = [] + else: + self.parents = parents def check(self): for arg in self.fields.values(): @@ -323,6 +346,38 @@ class TObj(Type): p = '' return self.name + p + def __eq__(self, other): + o = other.find() + if isinstance(o, TObj): + if self.name != o.name: + return False + for a, b in zip(self.params, o.params): + if a != b: + return False + return True + return False + + +class TVirtual(Type): + def __init__(self, obj: TObj): + self.obj = obj + + def __eq__(self, other): + o = other.find() + if isinstance(o, TVirtual): + return self == o + return False + + def unify(self, other): + o = other.find() + if isinstance(o, TVirtual): + self.obj.unify(o.obj) + else: + raise UnificationError(f'Cannot unify {self} with {o}') + + def __str__(self): + return f'virtual[{self.obj}]' + class TList(Type): def __init__(self, param: Type): @@ -343,6 +398,12 @@ class TList(Type): def __str__(self): return f'List[{self.param}]' + def __eq__(self, other): + o = other.find() + if isinstance(o, TList): + return self.param == o.param + return False + class TTuple(Type): def __init__(self, params: List[Type]): @@ -368,6 +429,17 @@ class TTuple(Type): def __str__(self): return f'Tuple[{", ".join(str(p) for p in self.params)}]' + def __eq__(self, other): + o = other.find() + if isinstance(o, TTuple): + if len(self.params) != len(o.params): + return False + for a, b in zip(self.params, o.params): + if a != b: + return False + return True + return False + TBool = TObj('bool', {}, []) TInt = TObj('int', {}, []) diff --git a/hm-inference/test.py b/hm-inference/test.py index d58fefa..6ded82d 100644 --- a/hm-inference/test.py +++ b/hm-inference/test.py @@ -3,41 +3,59 @@ import ast from ast_visitor import Visitor from nac3_types import * - -var = TVar([TInt, TBool]) -var2 = TVar([TInt, TBool]) +src = """ +b = test_virtual(virtual(bar, Foo)) +b = test_virtual(virtual(foo, Foo)) +b = test_virtual(virtual(foo2, Foo)) +""" foo = TObj('Foo', { - 'foo': TFunc([ - FuncArg('a', var, False), - FuncArg('b', var2, False) - ], var2, set([var2])) -}, [var]) + 'a': TInt, +}, []) -v = Visitor() -v.assignments['get_x'] = TFunc([FuncArg('in', var, False)], TInt, set([var])) -v.assignments['Foo'] = TFunc([FuncArg('a', var, False)], foo, set([var])) +foo2 = TObj('Foo2', { + 'a': TInt, +}, []) + +bar = TObj('Bar', { + 'a': TInt, + 'b': TInt +}, [], [foo]) + +type_mapping = { + 'Foo': foo, + 'Foo2': foo2, + 'Bar': bar, +} + +prelude = { + 'foo': foo, + 'foo2': foo2, + 'bar': bar, + 'test_virtual': TFunc([FuncArg('a', TVirtual(foo), False)], TInt, set()) +} -prelude = set(v.assignments.keys()) print('-----------') print('prelude') -for key, value in v.assignments.items(): +for key, value in prelude.items(): print(f'{key}: {value}') print('-----------') -src = """ -a = f.foo(1, 2) -b = f.foo(1, True) -c = g.foo(True, 1) -d = g.foo(True, True) - -f = Foo(1) -g = Foo(True) -""" +v = Visitor(src, prelude.copy(), lambda x: type_mapping[x]) print(src) v.visit(ast.parse(src)) +for a, b in v.virtuals: + assert isinstance(a, TObj) + assert b is a or b in a.parents + +print('-----------') +print('calls') +for x in v.calls: + x.check() + print(f'{x.find()}') + print('-----------') print('assignments') for key, value in v.assignments.items(): @@ -45,11 +63,6 @@ for key, value in v.assignments.items(): value.check() print(f'{key}: {value.find()}') -print('-----------') -print('calls') -for x in v.calls: - x.check() - print(f'{x.find()}') # TODO: Occur check