nac3-spec/hm-inference/ast_visitor.py

238 lines
8.0 KiB
Python
Raw Normal View History

2021-07-09 15:27:02 +08:00
import ast
from itertools import chain
from nac3_types import *
from primitives import *
2021-07-09 15:27:02 +08:00
2021-07-10 15:45:53 +08:00
# Operator node class -> name of the special method used to type it.
# Operators are typed as a call of this method on the left operand.
_MAGIC_METHODS = {
    ast.Add: '__add__',
    ast.Sub: '__sub__',
    ast.Mult: '__mul__',
    ast.Div: '__truediv__',
    ast.FloorDiv: '__floordiv__',
    ast.Mod: '__mod__',
    ast.Pow: '__pow__',
    ast.LShift: '__lshift__',
    ast.RShift: '__rshift__',
    ast.BitOr: '__or__',
    ast.BitXor: '__xor__',
    ast.BitAnd: '__and__',
    ast.Eq: '__eq__',
    ast.NotEq: '__ne__',
    ast.Lt: '__lt__',
    ast.LtE: '__le__',
    ast.Gt: '__gt__',
    ast.GtE: '__ge__',
}


def get_magic_method(op):
    """Return the name of the special method that implements operator *op*.

    Args:
        op: an ``ast.operator`` (e.g. ``ast.Add()``) or ``ast.cmpop``
            (e.g. ``ast.Lt()``) instance.

    Returns:
        The method name, e.g. ``'__add__'`` for ``ast.Add``.

    Raises:
        NotImplementedError: if *op* has no single-method equivalent in the
            table above (e.g. ``is``, ``in``).  This is an ``Exception``
            subclass, so existing ``except Exception`` handlers still work.
    """
    # isinstance (rather than a type(op) lookup) keeps subclasses working.
    for op_class, method in _MAGIC_METHODS.items():
        if isinstance(op, op_class):
            return method
    raise NotImplementedError(f'Unsupported operator: {type(op).__name__}')
2021-07-09 15:27:02 +08:00
class Visitor(ast.NodeVisitor):
    """Infer types for a Python AST by unification.

    Every expression node that is visited receives a ``type`` attribute;
    constraints between those types are imposed by calling their ``unify``
    method while the tree is walked.

    Attributes:
        source: source text the tree was parsed from.  Used with
            ``ast.get_source_segment`` to recognise ``virtual(...)`` calls
            and to extract the type argument passed to them.
        assignments: mapping of variable name -> type.  The caller's dict is
            used and extended in place.
        calls: every ``TCall`` constraint created, including the implicit
            ones generated for binary operators and comparisons.
        virtuals: ``(argument_type, declared_type)`` pairs collected from
            ``virtual(obj)`` / ``virtual(obj, T)`` calls.
        type_parser: callable mapping the source text of a type expression
            to a type object.
    """

    def __init__(self, src, assignments, type_parser):
        super().__init__()
        self.source = src
        self.assignments = assignments
        self.calls = []
        self.virtuals = []
        self.type_parser = type_parser

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_store(ctx):
        """Return True if *ctx* is an assignment-target context.

        The parser emits ``ast.Store`` for all assignment targets, including
        augmented ones.  ``ast.AugStore`` is a deprecated class (since
        Python 3.9) and is only checked when the running Python defines it.
        """
        aug_store = getattr(ast, 'AugStore', None)
        return isinstance(ctx, ast.Store) or (
            aug_store is not None and isinstance(ctx, aug_store))

    @staticmethod
    def _subscript_index(node):
        """Return the index expression of a non-slice ``ast.Subscript``.

        Python < 3.9 wraps the index in ``ast.Index`` (expression in
        ``.value``); Python >= 3.9 stores the expression directly in
        ``node.slice``.
        """
        index_cls = getattr(ast, 'Index', None)
        if index_cls is not None and isinstance(node.slice, index_cls):
            return node.slice.value
        return node.slice

    @staticmethod
    def _is_ext_slice(node):
        """Return True for multi-dimensional slicing such as ``a[1:2, 3]``.

        Python < 3.9 produces ``ast.ExtSlice``; Python >= 3.9 produces an
        ``ast.Tuple`` containing at least one ``ast.Slice``.
        """
        ext_cls = getattr(ast, 'ExtSlice', None)
        if ext_cls is not None and isinstance(node.slice, ext_cls):
            return True
        return (isinstance(node.slice, ast.Tuple)
                and any(isinstance(e, ast.Slice) for e in node.slice.elts))

    @staticmethod
    def _unify_field(tvar, key, field_type):
        """Unify ``tvar.fields[key]`` with *field_type*, adding it if absent."""
        if key in tvar.fields:
            tvar.fields[key].unify(field_type)
        else:
            tvar.fields[key] = field_type

    def _require_field(self, receiver_type, name, field_type):
        """Constrain *receiver_type* to be a record with field *name*.

        A fresh record-kind type variable holding only ``name: field_type``
        is unified with *receiver_type*.
        """
        record = TVar()
        record.type = TVarType.RECORD
        record.fields[name] = field_type
        receiver_type.unify(record)

    def _require_method(self, receiver_type, method, arg_type, ret_type):
        """Model ``receiver.method(arg)`` returning *ret_type*.

        The ``TCall`` is recorded in ``self.calls`` before the receiver is
        constrained to have a field *method* of that call type.
        """
        call = TCall([arg_type], {}, ret_type)
        self.calls.append(call)
        self._require_field(receiver_type, method, call)

    def _visit_statements(self, node):
        """Visit the ``body`` and then the ``orelse`` statements of *node*."""
        for stmt in chain(node.body, node.orelse):
            self.visit(stmt)

    # ------------------------------------------------------------------
    # Statements
    # ------------------------------------------------------------------

    def visit_Assign(self, node):
        """``t1 = t2 = value``: every target gets the value's type."""
        self.visit(node.value)
        for target in node.targets:
            self.visit(target)
            target.type.unify(node.value.type)

    def visit_For(self, node):
        """``for target in iter``: only iteration over lists is supported."""
        self.visit(node.target)
        self.visit(node.iter)
        iter_type = node.iter.type
        if isinstance(iter_type, TVar):
            iter_type.type = iter_type.type.unifier(TVarType.LIST)
            # The element type of a list-like type variable lives at key 0.
            self._unify_field(iter_type, 0, node.target.type)
        elif isinstance(iter_type, TList):
            iter_type.param.unify(node.target.type)
        else:
            raise UnificationError(f'Cannot iterate over {iter_type}')
        self._visit_statements(node)

    def visit_If(self, node):
        """``if``: the test must be a bool."""
        self._visit_conditional(node)

    def visit_While(self, node):
        """``while``: the test must be a bool."""
        self._visit_conditional(node)

    def _visit_conditional(self, node):
        """Shared by ``if``/``while``: type the test as bool, then the bodies."""
        self.visit(node.test)
        node.test.type.unify(TBool)
        self._visit_statements(node)

    # ------------------------------------------------------------------
    # Names, literals and functions
    # ------------------------------------------------------------------

    def visit_Name(self, node):
        """Look up (or create) the type of a variable reference.

        ``_`` is a wildcard and always gets a fresh type variable.  Unknown
        names get a fresh type variable stored in ``assignments``.  ``TFunc``
        types are instantiated at each use (presumably with fresh type
        variables, i.e. polymorphic use -- see ``TFunc.instantiate``).
        """
        if node.id == '_':
            node.type = TVar()
            return
        if node.id not in self.assignments:
            self.assignments[node.id] = TVar()
        ty = self.assignments[node.id]
        if isinstance(ty, TFunc):
            ty = ty.instantiate()
        node.type = ty

    def visit_Lambda(self, node):
        """``lambda args: body`` has type ``TFunc(arg_types, body_type, [])``."""
        # Parameters are bound in a copy of the environment so they do not
        # leak into the enclosing scope; the outer mapping is restored even
        # if inference of the body fails.
        outer = self.assignments
        self.assignments = outer.copy()
        try:
            self.visit(node.args)
            self.visit(node.body)
        finally:
            self.assignments = outer
        node.type = TFunc(node.args.type, node.body.type, [])

    def visit_arguments(self, node):
        """Bind each positional parameter to a fresh type variable.

        Only ``node.args`` is handled; positional-only, keyword-only and
        ``*``/``**`` parameters as well as defaults are ignored.  The result
        (a list of ``FuncArg``) is stored in ``node.type``; the ``False``
        flag presumably means "no default value" -- confirm against FuncArg.
        """
        for arg in node.args:
            self.assignments[arg.arg] = TVar()
            arg.type = self.assignments[arg.arg]
        node.type = [FuncArg(arg.arg, arg.type, False) for arg in node.args]

    def visit_Constant(self, node):
        """Type bool/int/float literals.

        ``bool`` is tested before ``int`` because ``bool`` subclasses
        ``int``.  Other constants (strings, ``None``, ...) are left without a
        ``type`` attribute, so e.g. a docstring statement is accepted, but
        using such a constant inside an expression fails when its ``type``
        is read.
        """
        if isinstance(node.value, bool):
            node.type = TBool
        elif isinstance(node.value, int):
            node.type = TInt
        elif isinstance(node.value, float):
            node.type = TFloat

    def visit_Call(self, node):
        """Type a call by unifying the callee with a ``TCall`` signature.

        ``virtual(...)`` calls are special-cased by ``_visit_virtual``.
        """
        if ast.get_source_segment(self.source, node.func) == 'virtual':
            self._visit_virtual(node)
            return
        self.visit(node.func)
        for arg in node.args:
            self.visit(arg)
        for keyword in node.keywords:
            self.visit(keyword.value)
        node.type = TVar()
        call = TCall([arg.type for arg in node.args],
                     {kw.arg: kw.value.type for kw in node.keywords},
                     node.type)
        node.func.type.unify(call)
        self.calls.append(call)

    def _visit_virtual(self, node):
        """Handle ``virtual(obj)`` / ``virtual(obj, T)``.

        ``T`` is parsed from its source text with ``type_parser`` (a fresh
        type variable when omitted).  ``(type_of_obj, T)`` is appended to
        ``self.virtuals`` and the call expression gets type ``TVirtual(T)``.

        Raises:
            UnificationError: if there are not 1 or 2 positional arguments.
        """
        if not 1 <= len(node.args) <= 2:
            raise UnificationError('Incorrect argument number for virtual')
        self.visit(node.args[0])
        if len(node.args) == 2:
            declared = self.type_parser(
                ast.get_source_segment(self.source, node.args[1]))
        else:
            declared = TVar()
        self.virtuals.append((node.args[0].type, declared))
        node.type = TVirtual(declared)

    # ------------------------------------------------------------------
    # Compound expressions
    # ------------------------------------------------------------------

    def visit_Attribute(self, node):
        """``value.attr``: *value* must be a record with field ``attr``."""
        self.visit(node.value)
        node.type = TVar()
        self._require_field(node.value.type, node.attr, node.type)

    def visit_Tuple(self, node):
        """A tuple literal has the tuple of its element types."""
        for elt in node.elts:
            self.visit(elt)
        node.type = TTuple([elt.type for elt in node.elts])

    def visit_List(self, node):
        """A list literal is ``TList(T)`` where all elements unify with T."""
        elem_type = TVar()
        for elt in node.elts:
            self.visit(elt)
            elem_type.unify(elt.type)
        node.type = TList(elem_type)

    def visit_Subscript(self, node):
        """``value[slice]`` / ``value[index]``; lists and tuples differ.

        Raises:
            NotImplementedError: for multi-dimensional slicing.
            UnificationError: for ill-typed subscripts.
        """
        self.visit(node.value)
        if isinstance(node.slice, ast.Slice):
            self._visit_slice(node)
        elif self._is_ext_slice(node):
            raise NotImplementedError()
        else:
            self._visit_index(node)

    def _visit_slice(self, node):
        """``value[lo:hi:step]``: value must be a list; result is the same type.

        Every bound that is present must be an int.
        """
        node.type = node.value.type
        if isinstance(node.type, TVar):
            node.type.type = node.type.type.unifier(TVarType.LIST)
        elif not isinstance(node.type, TList):
            raise UnificationError(f'{node.type} should be a list')
        for bound in (node.slice.lower, node.slice.upper, node.slice.step):
            if bound is not None:
                self.visit(bound)
                bound.type.unify(TInt)

    def _visit_index(self, node):
        """``value[index]`` with an int index."""
        index = self._subscript_index(node)
        self.visit(index)
        index.type.unify(TInt)
        container = node.value.type
        node.type = TVar()
        is_store = self._is_store(node.ctx)
        if isinstance(container, TVar):
            if isinstance(index, ast.Constant):
                # A constant index may address a tuple element, so only a
                # generic sequence is required -- unless the element is
                # written to, which tuples do not allow.
                seq_kind = TVarType.LIST if is_store else TVarType.SEQUENCE
                container.type = container.type.unifier(seq_kind)
                key = index.value
            else:
                # A runtime index is only allowed on lists, whose element
                # type lives at key 0.
                container.type = container.type.unifier(TVarType.LIST)
                key = 0
            self._unify_field(container, key, node.type)
        elif isinstance(container, TList):
            container.param.unify(node.type)
        elif isinstance(container, TTuple):
            if is_store:
                raise UnificationError('Cannot assign to tuple')
            if not isinstance(index, ast.Constant):
                raise UnificationError('Tuple index must be a constant')
            position = index.value
            # Negative indices count from the end, as in Python.
            if not -len(container.params) <= position < len(container.params):
                raise UnificationError('Index out of range for tuple')
            container.params[position].unify(node.type)
        else:
            raise UnificationError(f'Cannot use subscript for {container}')

    # ------------------------------------------------------------------
    # Operators
    # ------------------------------------------------------------------

    def visit_BoolOp(self, node):
        """``a and b and ...``: every operand must be a bool; result is bool.

        A chain like ``a and b and c`` is a single BoolOp with three values,
        so all of ``node.values`` are handled, not just the first two.
        """
        for value in node.values:
            self.visit(value)
        for value in node.values:
            value.type.unify(TBool)
        node.type = TBool

    def visit_BinOp(self, node):
        """``a <op> b`` is typed as the call ``a.__op__(b)``."""
        self.visit(node.left)
        self.visit(node.right)
        method = get_magic_method(node.op)
        node.type = TVar()
        self._require_method(node.left.type, method, node.right.type,
                             node.type)

    def visit_Compare(self, node):
        """``a < b <= c`` is checked pairwise (``a.__lt__(b)``,
        ``b.__le__(c)``); each comparison must return bool."""
        self.visit(node.left)
        for comparator in node.comparators:
            self.visit(comparator)
        left_operands = chain([node.left], node.comparators[:-1])
        for lhs, rhs, op in zip(left_operands, node.comparators, node.ops):
            self._require_method(lhs.type, get_magic_method(op), rhs.type,
                                 TBool)
        node.type = TBool