implemented basic operations
This commit is contained in:
parent
b1020352ce
commit
902e1a892c
@ -3,6 +3,33 @@ from itertools import chain
|
|||||||
from nac3_types import *
|
from nac3_types import *
|
||||||
from primitives import *
|
from primitives import *
|
||||||
|
|
||||||
|
|
||||||
|
def get_magic_method(op):
|
||||||
|
if isinstance(op, ast.Add):
|
||||||
|
return '__add__'
|
||||||
|
if isinstance(op, ast.Sub):
|
||||||
|
return '__sub__'
|
||||||
|
if isinstance(op, ast.Mult):
|
||||||
|
return '__mul__'
|
||||||
|
if isinstance(op, ast.Div):
|
||||||
|
return '__truediv__'
|
||||||
|
if isinstance(op, ast.FloorDiv):
|
||||||
|
return '__floordiv__'
|
||||||
|
if isinstance(op, ast.Eq):
|
||||||
|
return '__eq__'
|
||||||
|
if isinstance(op, ast.NotEq):
|
||||||
|
return '__ne__'
|
||||||
|
if isinstance(op, ast.Lt):
|
||||||
|
return '__lt__'
|
||||||
|
if isinstance(op, ast.LtE):
|
||||||
|
return '__le__'
|
||||||
|
if isinstance(op, ast.Gt):
|
||||||
|
return '__gt__'
|
||||||
|
if isinstance(op, ast.GtE):
|
||||||
|
return '__ge__'
|
||||||
|
raise Exception
|
||||||
|
|
||||||
|
|
||||||
class Visitor(ast.NodeVisitor):
|
class Visitor(ast.NodeVisitor):
|
||||||
def __init__(self, src, assignments, type_parser):
|
def __init__(self, src, assignments, type_parser):
|
||||||
super(Visitor, self).__init__()
|
super(Visitor, self).__init__()
|
||||||
@ -48,6 +75,8 @@ class Visitor(ast.NodeVisitor):
|
|||||||
node.type = TBool
|
node.type = TBool
|
||||||
elif isinstance(node.value, int):
|
elif isinstance(node.value, int):
|
||||||
node.type = TInt
|
node.type = TInt
|
||||||
|
elif isinstance(node.value, float):
|
||||||
|
node.type = TFloat
|
||||||
|
|
||||||
def visit_Call(self, node):
|
def visit_Call(self, node):
|
||||||
if ast.get_source_segment(self.source, node.func) == 'virtual':
|
if ast.get_source_segment(self.source, node.func) == 'virtual':
|
||||||
@ -76,9 +105,6 @@ class Visitor(ast.NodeVisitor):
|
|||||||
node.type = TVar()
|
node.type = TVar()
|
||||||
v = TVar()
|
v = TVar()
|
||||||
v.type = TVarType.RECORD
|
v.type = TVarType.RECORD
|
||||||
if node.attr in v.fields:
|
|
||||||
v.fields[node.attr].unify(node.type)
|
|
||||||
else:
|
|
||||||
v.fields[node.attr] = node.type
|
v.fields[node.attr] = node.type
|
||||||
node.value.type.unify(v)
|
node.value.type.unify(v)
|
||||||
|
|
||||||
@ -171,3 +197,38 @@ class Visitor(ast.NodeVisitor):
|
|||||||
for stmt in chain(node.body, node.orelse):
|
for stmt in chain(node.body, node.orelse):
|
||||||
self.visit(stmt)
|
self.visit(stmt)
|
||||||
|
|
||||||
|
def visit_BoolOp(self, node):
|
||||||
|
self.visit(node.values[0])
|
||||||
|
self.visit(node.values[1])
|
||||||
|
node.values[0].type.unify(TBool)
|
||||||
|
node.values[1].type.unify(TBool)
|
||||||
|
node.type = TBool
|
||||||
|
|
||||||
|
def visit_BinOp(self, node):
|
||||||
|
self.visit(node.left)
|
||||||
|
self.visit(node.right)
|
||||||
|
# call method...
|
||||||
|
method = get_magic_method(node.op)
|
||||||
|
ret = TVar()
|
||||||
|
node.type = ret
|
||||||
|
call = TCall([node.right.type], {}, ret)
|
||||||
|
self.calls.append(call)
|
||||||
|
v = TVar()
|
||||||
|
v.type = TVarType.RECORD
|
||||||
|
v.fields[method] = call
|
||||||
|
node.left.type.unify(v)
|
||||||
|
|
||||||
|
def visit_Compare(self, node):
|
||||||
|
self.visit(node.left)
|
||||||
|
for c in node.comparators:
|
||||||
|
self.visit(c)
|
||||||
|
for a, b, c in zip(chain([node.left], node.comparators[:-1]),
|
||||||
|
node.comparators, node.ops):
|
||||||
|
method = get_magic_method(c)
|
||||||
|
call = TCall([b.type], {}, TBool)
|
||||||
|
self.calls.append(call)
|
||||||
|
v = TVar()
|
||||||
|
v.type = TVarType.RECORD
|
||||||
|
v.fields[method] = call
|
||||||
|
a.type.unify(v)
|
||||||
|
node.type = TBool
|
||||||
|
@ -331,12 +331,15 @@ class TObj(Type):
|
|||||||
if len(mapping) == 0:
|
if len(mapping) == 0:
|
||||||
return self
|
return self
|
||||||
new_params = []
|
new_params = []
|
||||||
|
changed = False
|
||||||
for v in self.params:
|
for v in self.params:
|
||||||
if isinstance(v, TVar) and v.id in mapping:
|
if isinstance(v, TVar) and v.id in mapping:
|
||||||
new_params.append(mapping[v.id])
|
new_params.append(mapping[v.id])
|
||||||
|
changed = True
|
||||||
else:
|
else:
|
||||||
new_params.append(v)
|
new_params.append(v)
|
||||||
|
if not changed:
|
||||||
|
return self
|
||||||
return TObj(self.name, {k: v.subst(mapping) for k, v in
|
return TObj(self.name, {k: v.subst(mapping) for k, v in
|
||||||
self.fields.items()}, new_params)
|
self.fields.items()}, new_params)
|
||||||
|
|
||||||
|
@ -5,8 +5,11 @@ from nac3_types import *
|
|||||||
from primitives import *
|
from primitives import *
|
||||||
|
|
||||||
src = """
|
src = """
|
||||||
a = 1
|
a = 1 // 1
|
||||||
a = a.__add__(2)
|
|
||||||
|
if 1 <= 1 < 2:
|
||||||
|
pass
|
||||||
|
|
||||||
b = test_virtual(virtual(bar, Foo))
|
b = test_virtual(virtual(bar, Foo))
|
||||||
b = test_virtual(virtual(foo, Foo))
|
b = test_virtual(virtual(foo, Foo))
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user