implemented basic operations

pca
pca006132 2021-07-10 15:45:53 +08:00
parent b1020352ce
commit 902e1a892c
3 changed files with 74 additions and 7 deletions

View File

@ -3,6 +3,33 @@ from itertools import chain
from nac3_types import *
from primitives import *
def get_magic_method(op):
if isinstance(op, ast.Add):
return '__add__'
if isinstance(op, ast.Sub):
return '__sub__'
if isinstance(op, ast.Mult):
return '__mul__'
if isinstance(op, ast.Div):
return '__truediv__'
if isinstance(op, ast.FloorDiv):
return '__floordiv__'
if isinstance(op, ast.Eq):
return '__eq__'
if isinstance(op, ast.NotEq):
return '__ne__'
if isinstance(op, ast.Lt):
return '__lt__'
if isinstance(op, ast.LtE):
return '__le__'
if isinstance(op, ast.Gt):
return '__gt__'
if isinstance(op, ast.GtE):
return '__ge__'
raise Exception
class Visitor(ast.NodeVisitor):
def __init__(self, src, assignments, type_parser):
super(Visitor, self).__init__()
@ -48,6 +75,8 @@ class Visitor(ast.NodeVisitor):
node.type = TBool
elif isinstance(node.value, int):
node.type = TInt
elif isinstance(node.value, float):
node.type = TFloat
def visit_Call(self, node):
if ast.get_source_segment(self.source, node.func) == 'virtual':
@ -76,10 +105,7 @@ class Visitor(ast.NodeVisitor):
node.type = TVar()
v = TVar()
v.type = TVarType.RECORD
if node.attr in v.fields:
v.fields[node.attr].unify(node.type)
else:
v.fields[node.attr] = node.type
v.fields[node.attr] = node.type
node.value.type.unify(v)
def visit_Tuple(self, node):
@ -171,3 +197,38 @@ class Visitor(ast.NodeVisitor):
for stmt in chain(node.body, node.orelse):
self.visit(stmt)
def visit_BoolOp(self, node):
self.visit(node.values[0])
self.visit(node.values[1])
node.values[0].type.unify(TBool)
node.values[1].type.unify(TBool)
node.type = TBool
def visit_BinOp(self, node):
self.visit(node.left)
self.visit(node.right)
# call method...
method = get_magic_method(node.op)
ret = TVar()
node.type = ret
call = TCall([node.right.type], {}, ret)
self.calls.append(call)
v = TVar()
v.type = TVarType.RECORD
v.fields[method] = call
node.left.type.unify(v)
def visit_Compare(self, node):
self.visit(node.left)
for c in node.comparators:
self.visit(c)
for a, b, c in zip(chain([node.left], node.comparators[:-1]),
node.comparators, node.ops):
method = get_magic_method(c)
call = TCall([b.type], {}, TBool)
self.calls.append(call)
v = TVar()
v.type = TVarType.RECORD
v.fields[method] = call
a.type.unify(v)
node.type = TBool

View File

@ -331,12 +331,15 @@ class TObj(Type):
if len(mapping) == 0:
return self
new_params = []
changed = False
for v in self.params:
if isinstance(v, TVar) and v.id in mapping:
new_params.append(mapping[v.id])
changed = True
else:
new_params.append(v)
if not changed:
return self
return TObj(self.name, {k: v.subst(mapping) for k, v in
self.fields.items()}, new_params)

View File

@ -5,8 +5,11 @@ from nac3_types import *
from primitives import *
src = """
a = 1
a = a.__add__(2)
a = 1 // 1
if 1 <= 1 < 2:
pass
b = test_virtual(virtual(bar, Foo))
b = test_virtual(virtual(foo, Foo))
"""