implemented basic operations
This commit is contained in:
parent
b1020352ce
commit
902e1a892c
@ -3,6 +3,33 @@ from itertools import chain
|
||||
from nac3_types import *
|
||||
from primitives import *
|
||||
|
||||
|
||||
def get_magic_method(op):
|
||||
if isinstance(op, ast.Add):
|
||||
return '__add__'
|
||||
if isinstance(op, ast.Sub):
|
||||
return '__sub__'
|
||||
if isinstance(op, ast.Mult):
|
||||
return '__mul__'
|
||||
if isinstance(op, ast.Div):
|
||||
return '__truediv__'
|
||||
if isinstance(op, ast.FloorDiv):
|
||||
return '__floordiv__'
|
||||
if isinstance(op, ast.Eq):
|
||||
return '__eq__'
|
||||
if isinstance(op, ast.NotEq):
|
||||
return '__ne__'
|
||||
if isinstance(op, ast.Lt):
|
||||
return '__lt__'
|
||||
if isinstance(op, ast.LtE):
|
||||
return '__le__'
|
||||
if isinstance(op, ast.Gt):
|
||||
return '__gt__'
|
||||
if isinstance(op, ast.GtE):
|
||||
return '__ge__'
|
||||
raise Exception
|
||||
|
||||
|
||||
class Visitor(ast.NodeVisitor):
|
||||
def __init__(self, src, assignments, type_parser):
|
||||
super(Visitor, self).__init__()
|
||||
@ -48,6 +75,8 @@ class Visitor(ast.NodeVisitor):
|
||||
node.type = TBool
|
||||
elif isinstance(node.value, int):
|
||||
node.type = TInt
|
||||
elif isinstance(node.value, float):
|
||||
node.type = TFloat
|
||||
|
||||
def visit_Call(self, node):
|
||||
if ast.get_source_segment(self.source, node.func) == 'virtual':
|
||||
@ -76,10 +105,7 @@ class Visitor(ast.NodeVisitor):
|
||||
node.type = TVar()
|
||||
v = TVar()
|
||||
v.type = TVarType.RECORD
|
||||
if node.attr in v.fields:
|
||||
v.fields[node.attr].unify(node.type)
|
||||
else:
|
||||
v.fields[node.attr] = node.type
|
||||
v.fields[node.attr] = node.type
|
||||
node.value.type.unify(v)
|
||||
|
||||
def visit_Tuple(self, node):
|
||||
@ -171,3 +197,38 @@ class Visitor(ast.NodeVisitor):
|
||||
for stmt in chain(node.body, node.orelse):
|
||||
self.visit(stmt)
|
||||
|
||||
def visit_BoolOp(self, node):
|
||||
self.visit(node.values[0])
|
||||
self.visit(node.values[1])
|
||||
node.values[0].type.unify(TBool)
|
||||
node.values[1].type.unify(TBool)
|
||||
node.type = TBool
|
||||
|
||||
def visit_BinOp(self, node):
|
||||
self.visit(node.left)
|
||||
self.visit(node.right)
|
||||
# call method...
|
||||
method = get_magic_method(node.op)
|
||||
ret = TVar()
|
||||
node.type = ret
|
||||
call = TCall([node.right.type], {}, ret)
|
||||
self.calls.append(call)
|
||||
v = TVar()
|
||||
v.type = TVarType.RECORD
|
||||
v.fields[method] = call
|
||||
node.left.type.unify(v)
|
||||
|
||||
def visit_Compare(self, node):
|
||||
self.visit(node.left)
|
||||
for c in node.comparators:
|
||||
self.visit(c)
|
||||
for a, b, c in zip(chain([node.left], node.comparators[:-1]),
|
||||
node.comparators, node.ops):
|
||||
method = get_magic_method(c)
|
||||
call = TCall([b.type], {}, TBool)
|
||||
self.calls.append(call)
|
||||
v = TVar()
|
||||
v.type = TVarType.RECORD
|
||||
v.fields[method] = call
|
||||
a.type.unify(v)
|
||||
node.type = TBool
|
||||
|
@ -331,12 +331,15 @@ class TObj(Type):
|
||||
if len(mapping) == 0:
|
||||
return self
|
||||
new_params = []
|
||||
changed = False
|
||||
for v in self.params:
|
||||
if isinstance(v, TVar) and v.id in mapping:
|
||||
new_params.append(mapping[v.id])
|
||||
changed = True
|
||||
else:
|
||||
new_params.append(v)
|
||||
|
||||
if not changed:
|
||||
return self
|
||||
return TObj(self.name, {k: v.subst(mapping) for k, v in
|
||||
self.fields.items()}, new_params)
|
||||
|
||||
|
@ -5,8 +5,11 @@ from nac3_types import *
|
||||
from primitives import *
|
||||
|
||||
src = """
|
||||
a = 1
|
||||
a = a.__add__(2)
|
||||
a = 1 // 1
|
||||
|
||||
if 1 <= 1 < 2:
|
||||
pass
|
||||
|
||||
b = test_virtual(virtual(bar, Foo))
|
||||
b = test_virtual(virtual(foo, Foo))
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user