238 lines
8.0 KiB
Python
238 lines
8.0 KiB
Python
import ast
|
|
from itertools import chain
|
|
from nac3_types import *
|
|
from primitives import *
|
|
|
|
|
|
# (ast operator class, dunder method looked up on the left operand) pairs.
# Checked with isinstance, so subclasses of these ast classes also match.
_MAGIC_METHODS = (
    (ast.Add, '__add__'),
    (ast.Sub, '__sub__'),
    (ast.Mult, '__mul__'),
    (ast.Div, '__truediv__'),
    (ast.FloorDiv, '__floordiv__'),
    (ast.Eq, '__eq__'),
    (ast.NotEq, '__ne__'),
    (ast.Lt, '__lt__'),
    (ast.LtE, '__le__'),
    (ast.Gt, '__gt__'),
    (ast.GtE, '__ge__'),
)


def get_magic_method(op):
    """Return the dunder method name that implements an ast operator.

    Args:
        op: an ``ast`` operator node instance, e.g. ``ast.Add()`` or
            ``ast.Lt()``.

    Returns:
        The method name as a string, e.g. ``'__add__'`` or ``'__lt__'``.

    Raises:
        NotImplementedError: if ``op`` is not a supported operator. This is a
            subclass of ``Exception``, so handlers written for the previous
            bare ``Exception`` still catch it.
    """
    for op_cls, method in _MAGIC_METHODS:
        if isinstance(op, op_cls):
            return method
    raise NotImplementedError(f'unsupported operator: {type(op).__name__}')
|
|
|
|
|
|
class Visitor(ast.NodeVisitor):
    """Walk a Python AST, attaching inferred types and recording constraints.

    Expression nodes handled by the ``visit_*`` methods below get a ``type``
    attribute. Relationships between types are recorded by calling ``unify``
    on them while the tree is walked.

    Attributes:
        source: the source text the AST was parsed from. It is used with
            ``ast.get_source_segment`` to recover literal sub-expression text.
        assignments: mapping from variable name to its type. A name seen for
            the first time is given a fresh ``TVar``.
        calls: every ``TCall`` constraint created, in visit order.
        virtuals: one ``(argument type, element type)`` pair per
            ``virtual(...)`` call.
        type_parser: callable applied to the source text of the optional
            second argument of ``virtual(...)`` to obtain a type.
    """

    # ast.AugStore is a deprecated context class that the Python 3 parser
    # never emits; still accept it where the running Python defines it.
    _STORE_CTXS = (ast.Store,) + (
        (ast.AugStore,) if hasattr(ast, 'AugStore') else ())

    def __init__(self, src, assignments, type_parser):
        super().__init__()
        self.source = src
        self.assignments = assignments
        self.calls = []
        self.virtuals = []
        self.type_parser = type_parser

    @staticmethod
    def _index_expr(slice_node):
        """Return the index expression held in a ``Subscript.slice`` field.

        Python < 3.9 wraps a plain index in ``ast.Index``. Python 3.9+ stores
        the expression directly. This unwraps the former so both work.
        """
        index_cls = getattr(ast, 'Index', None)
        if index_cls is not None and isinstance(slice_node, index_cls):
            return slice_node.value
        return slice_node

    @staticmethod
    def _is_ext_slice(slice_node):
        """Return True for a multi-dimensional subscript containing a slice.

        Python < 3.9 represents e.g. ``a[1:2, 3]`` as ``ast.ExtSlice``.
        Python 3.9+ uses an ``ast.Tuple`` with an ``ast.Slice`` element.
        """
        ext_cls = getattr(ast, 'ExtSlice', None)
        if ext_cls is not None and isinstance(slice_node, ext_cls):
            return True
        return (isinstance(slice_node, ast.Tuple)
                and any(isinstance(e, ast.Slice) for e in slice_node.elts))

    def visit_Assign(self, node):
        """Unify every assignment target's type with the assigned value's."""
        self.visit(node.value)
        for target in node.targets:
            self.visit(target)
            target.type.unify(node.value.type)

    def visit_Name(self, node):
        """Look up (or create) the type bound to a variable name.

        ``_`` always gets a fresh, unrelated ``TVar``. Function types are
        instantiated at each use via ``TFunc.instantiate``.
        """
        if node.id == '_':
            node.type = TVar()
            return
        if node.id not in self.assignments:
            self.assignments[node.id] = TVar()
        ty = self.assignments[node.id]
        if isinstance(ty, TFunc):
            ty = ty.instantiate()
        node.type = ty

    def visit_Lambda(self, node):
        """Type a lambda as ``TFunc(params, body type, [])``.

        Parameters are bound in a copy of ``assignments``. The outer mapping
        is restored afterwards, so they do not leak out of the lambda.
        """
        old = self.assignments
        self.assignments = old.copy()
        self.visit(node.args)
        self.visit(node.body)
        self.assignments = old
        node.type = TFunc(node.args.type, node.body.type, [])

    def visit_arguments(self, node):
        """Bind each positional parameter to a fresh ``TVar``.

        Only ``node.args`` is handled. Positional-only, keyword-only,
        ``*args`` and ``**kwargs`` parameters are not typed here.
        """
        for arg in node.args:
            self.assignments[arg.arg] = TVar()
            arg.type = self.assignments[arg.arg]
        node.type = [FuncArg(arg.arg, arg.type, False) for arg in node.args]

    def visit_Constant(self, node):
        """Type bool, int and float literals.

        Other constants (e.g. str or None) leave ``node.type`` unset.
        """
        # bool is checked first because bool is a subclass of int.
        if isinstance(node.value, bool):
            node.type = TBool
        elif isinstance(node.value, int):
            node.type = TInt
        elif isinstance(node.value, float):
            node.type = TFloat

    def visit_Call(self, node):
        """Type a call expression.

        ``virtual(expr[, type])`` is special-cased. Its element type comes
        from ``type_parser`` applied to the second argument's source text, or
        is a fresh ``TVar`` when no type is given. The pair is recorded in
        ``self.virtuals`` and the result type is ``TVirtual``.

        For any other call, the callee's type is unified with a ``TCall``
        built from the argument and keyword types plus a fresh return
        variable. That ``TCall`` is also recorded in ``self.calls``.

        Raises:
            UnificationError: if ``virtual`` gets other than 1 or 2 arguments.
        """
        if ast.get_source_segment(self.source, node.func) == 'virtual':
            if len(node.args) > 2 or len(node.args) < 1:
                raise UnificationError('Incorrect argument number for virtual')
            self.visit(node.args[0])
            if len(node.args) == 2:
                ty = self.type_parser(ast.get_source_segment(self.source,
                                                             node.args[1]))
            else:
                ty = TVar()
            self.virtuals.append((node.args[0].type, ty))
            node.type = TVirtual(ty)
            return

        self.visit(node.func)
        for arg in node.args:
            self.visit(arg)
        for keyword in node.keywords:
            self.visit(keyword.value)
        node.type = TVar()
        call = TCall([arg.type for arg in node.args],
                     {keyword.arg: keyword.value.type for keyword
                      in node.keywords}, node.type)
        node.func.type.unify(call)
        self.calls.append(call)

    def visit_Attribute(self, node):
        """Require the object to be a record with field ``node.attr``."""
        self.visit(node.value)
        node.type = TVar()
        record = TVar()
        record.type = TVarType.RECORD
        record.fields[node.attr] = node.type
        node.value.type.unify(record)

    def visit_Tuple(self, node):
        """Type a tuple literal from the types of its elements."""
        for e in node.elts:
            self.visit(e)
        node.type = TTuple([e.type for e in node.elts])

    def visit_List(self, node):
        """Type a list literal; all elements must share one element type."""
        ty = TVar()
        for e in node.elts:
            self.visit(e)
            ty.unify(e.type)
        node.type = TList(ty)

    def visit_Subscript(self, node):
        """Type ``value[...]`` for list slicing and list/tuple indexing.

        Works with both the pre-3.9 AST (``ast.Index`` / ``ast.ExtSlice``)
        and the 3.9+ AST, where ``node.slice`` is the index expression.

        Raises:
            UnificationError: when slicing a non-list, assigning into a tuple,
                indexing a tuple with a non-constant or out-of-range index,
                or subscripting any other type.
            NotImplementedError: for multi-dimensional slices.
        """
        self.visit(node.value)

        if isinstance(node.slice, ast.Slice):
            # A slice has the same (list) type as the sliced value.
            # NOTE(review): slice bounds are not visited or constrained here.
            node.type = node.value.type
            if isinstance(node.type, TVar):
                node.type.type = node.type.type.unifier(TVarType.LIST)
            elif not isinstance(node.type, TList):
                raise UnificationError(f'{node.type} should be a list')
            return

        if self._is_ext_slice(node.slice):
            raise NotImplementedError()

        # Plain indexing: complicated because lists and tuples are handled
        # differently.
        index = self._index_expr(node.slice)
        self.visit(index)
        index.type.unify(TInt)
        is_store = isinstance(node.ctx, self._STORE_CTXS)
        ty = node.value.type
        node.type = TVar()

        if isinstance(ty, TVar):
            # Field key: the constant index itself; otherwise 0 (the same key
            # visit_For uses for a list's element type).
            key = 0
            if isinstance(index, ast.Constant):
                # A constant read could be a list or a tuple; a store
                # forces a list.
                seq_kind = TVarType.LIST if is_store else TVarType.SEQUENCE
                ty.type = ty.type.unifier(seq_kind)
                key = index.value
            else:
                ty.type = ty.type.unifier(TVarType.LIST)
            if key in ty.fields:
                ty.fields[key].unify(node.type)
            else:
                ty.fields[key] = node.type
        elif isinstance(ty, TList):
            ty.param.unify(node.type)
        elif isinstance(ty, TTuple):
            if is_store:
                raise UnificationError('Cannot assign to tuple')
            if not isinstance(index, ast.Constant):
                raise UnificationError('Tuple index must be a constant')
            position = index.value
            if position >= len(ty.params):
                raise UnificationError('Index out of range for tuple')
            ty.params[position].unify(node.type)
        else:
            raise UnificationError(f'Cannot use subscript for {ty}')

    def visit_For(self, node):
        """Unify the loop target with the iterated list's element type.

        Raises:
            UnificationError: if the iterable is not (or cannot be) a list.
        """
        self.visit(node.target)
        self.visit(node.iter)
        ty = node.iter.type
        if isinstance(ty, TVar):
            # we currently only support iterator over lists
            ty.type = ty.type.unifier(TVarType.LIST)
            if 0 in ty.fields:
                ty.fields[0].unify(node.target.type)
            else:
                ty.fields[0] = node.target.type
        elif isinstance(ty, TList):
            ty.param.unify(node.target.type)
        else:
            raise UnificationError(f'Cannot iterate over {ty}')

        for stmt in chain(node.body, node.orelse):
            self.visit(stmt)

    def visit_If(self, node):
        """Require a bool condition, then visit both branches."""
        self.visit(node.test)
        node.test.type.unify(TBool)
        for stmt in chain(node.body, node.orelse):
            self.visit(stmt)

    def visit_While(self, node):
        """Require a bool condition, then visit the body and ``else``."""
        self.visit(node.test)
        node.test.type.unify(TBool)
        for stmt in chain(node.body, node.orelse):
            self.visit(stmt)

    def visit_BoolOp(self, node):
        """Type ``and``/``or`` chains: every operand and the result are bool.

        Handles any number of operands (``a and b and c`` is a single
        BoolOp with three values).
        """
        for value in node.values:
            self.visit(value)
        for value in node.values:
            value.type.unify(TBool)
        node.type = TBool

    def visit_BinOp(self, node):
        """Type ``left <op> right`` as a call of the operator's dunder method.

        The left operand must be a record whose field for the dunder method
        is a call taking the right operand's type. The result is a fresh
        variable.
        """
        self.visit(node.left)
        self.visit(node.right)
        method = get_magic_method(node.op)
        ret = TVar()
        node.type = ret
        call = TCall([node.right.type], {}, ret)
        self.calls.append(call)
        record = TVar()
        record.type = TVarType.RECORD
        record.fields[method] = call
        node.left.type.unify(record)

    def visit_Compare(self, node):
        """Type (possibly chained) comparisons; each link returns bool.

        ``a < b <= c`` is checked as ``a.__lt__(b)`` and ``b.__le__(c)``.
        """
        self.visit(node.left)
        for comparator in node.comparators:
            self.visit(comparator)
        lhs_nodes = chain([node.left], node.comparators[:-1])
        for lhs, rhs, op in zip(lhs_nodes, node.comparators, node.ops):
            method = get_magic_method(op)
            call = TCall([rhs.type], {}, TBool)
            self.calls.append(call)
            record = TVar()
            record.type = TVarType.RECORD
            record.fields[method] = call
            lhs.type.unify(record)
        node.type = TBool
|