From 902e1a892cd1bd604111e59f39ca7e22cbf3d80a Mon Sep 17 00:00:00 2001 From: pca006132 Date: Sat, 10 Jul 2021 15:45:53 +0800 Subject: [PATCH] implemented basic operations --- hm-inference/ast_visitor.py | 69 ++++++++++++++++++++++++++++++++++--- hm-inference/nac3_types.py | 5 ++- hm-inference/test.py | 7 ++-- 3 files changed, 74 insertions(+), 7 deletions(-) diff --git a/hm-inference/ast_visitor.py b/hm-inference/ast_visitor.py index bda045e..e8ccebe 100644 --- a/hm-inference/ast_visitor.py +++ b/hm-inference/ast_visitor.py @@ -3,6 +3,33 @@ from itertools import chain from nac3_types import * from primitives import * + +def get_magic_method(op): + if isinstance(op, ast.Add): + return '__add__' + if isinstance(op, ast.Sub): + return '__sub__' + if isinstance(op, ast.Mult): + return '__mul__' + if isinstance(op, ast.Div): + return '__truediv__' + if isinstance(op, ast.FloorDiv): + return '__floordiv__' + if isinstance(op, ast.Eq): + return '__eq__' + if isinstance(op, ast.NotEq): + return '__ne__' + if isinstance(op, ast.Lt): + return '__lt__' + if isinstance(op, ast.LtE): + return '__le__' + if isinstance(op, ast.Gt): + return '__gt__' + if isinstance(op, ast.GtE): + return '__ge__' + raise Exception + + class Visitor(ast.NodeVisitor): def __init__(self, src, assignments, type_parser): super(Visitor, self).__init__() @@ -48,6 +75,8 @@ class Visitor(ast.NodeVisitor): node.type = TBool elif isinstance(node.value, int): node.type = TInt + elif isinstance(node.value, float): + node.type = TFloat def visit_Call(self, node): if ast.get_source_segment(self.source, node.func) == 'virtual': @@ -76,10 +105,7 @@ class Visitor(ast.NodeVisitor): node.type = TVar() v = TVar() v.type = TVarType.RECORD - if node.attr in v.fields: - v.fields[node.attr].unify(node.type) - else: - v.fields[node.attr] = node.type + v.fields[node.attr] = node.type node.value.type.unify(v) def visit_Tuple(self, node): @@ -171,3 +197,38 @@ class Visitor(ast.NodeVisitor): for stmt in chain(node.body, node.orelse): self.visit(stmt) + def visit_BoolOp(self, node): + self.visit(node.values[0]) + self.visit(node.values[1]) + node.values[0].type.unify(TBool) + node.values[1].type.unify(TBool) + node.type = TBool + + def visit_BinOp(self, node): + self.visit(node.left) + self.visit(node.right) + # call method... + method = get_magic_method(node.op) + ret = TVar() + node.type = ret + call = TCall([node.right.type], {}, ret) + self.calls.append(call) + v = TVar() + v.type = TVarType.RECORD + v.fields[method] = call + node.left.type.unify(v) + + def visit_Compare(self, node): + self.visit(node.left) + for c in node.comparators: + self.visit(c) + for a, b, c in zip(chain([node.left], node.comparators[:-1]), + node.comparators, node.ops): + method = get_magic_method(c) + call = TCall([b.type], {}, TBool) + self.calls.append(call) + v = TVar() + v.type = TVarType.RECORD + v.fields[method] = call + a.type.unify(v) + node.type = TBool diff --git a/hm-inference/nac3_types.py b/hm-inference/nac3_types.py index c8d579e..b74800e 100644 --- a/hm-inference/nac3_types.py +++ b/hm-inference/nac3_types.py @@ -331,12 +331,15 @@ class TObj(Type): if len(mapping) == 0: return self new_params = [] + changed = False for v in self.params: if isinstance(v, TVar) and v.id in mapping: new_params.append(mapping[v.id]) + changed = True else: new_params.append(v) - + if not changed: + return self return TObj(self.name, {k: v.subst(mapping) for k, v in self.fields.items()}, new_params) diff --git a/hm-inference/test.py b/hm-inference/test.py index 6e8adea..283a9d3 100644 --- a/hm-inference/test.py +++ b/hm-inference/test.py @@ -5,8 +5,11 @@ from nac3_types import * from primitives import * src = """ -a = 1 -a = a.__add__(2) +a = 1 // 1 + +if 1 <= 1 < 2: + pass + b = test_virtual(virtual(bar, Foo)) b = test_virtual(virtual(foo, Foo)) """