# nac3-spec/hm-inference/ast_visitor.py
#
# (Repository listing metadata: 161 lines, 5.5 KiB, Python)

import ast
from itertools import chain
from nac3_types import *
class Visitor(ast.NodeVisitor):
    """AST visitor that attaches a ``type`` attribute to the nodes it handles.

    While walking the tree, each handled node gets ``node.type`` set to one of
    the type objects imported from ``nac3_types``, and related types are
    connected by calling their ``unify`` method.

    Attributes:
        assignments: dict mapping a name (``ast.Name.id`` or argument name)
            to the type object recorded for it in the current scope.
        calls: list of the callee types (``node.func.type``) of every visited
            call, in visiting order.
    """

    # ``ast.Index`` / ``ast.ExtSlice`` are only produced by the parser before
    # Python 3.9 and may eventually disappear from the ``ast`` module.
    # ``isinstance(x, ())`` is always False, so an empty tuple is a safe
    # stand-in when a class is missing.
    _LEGACY_INDEX = getattr(ast, 'Index', ())
    _LEGACY_EXT_SLICE = getattr(ast, 'ExtSlice', ())
    # Expression contexts treated as "this subscript is being written to".
    _STORE_CONTEXTS = (ast.Store,) + (
        (ast.AugStore,) if hasattr(ast, 'AugStore') else ())

    def __init__(self):
        super().__init__()
        self.assignments = {}
        self.calls = []

    def visit_Assign(self, node):
        """Type the assigned value, then unify every target with it."""
        self.visit(node.value)
        for target in node.targets:
            self.visit(target)
            target.type.unify(node.value.type)

    def visit_Name(self, node):
        """Set ``node.type`` from ``self.assignments``.

        ``_`` always gets a fresh ``TVar`` that is not recorded. A name seen
        for the first time is recorded with a fresh ``TVar``. A recorded
        ``TFunc`` is replaced by the result of its ``instantiate()`` call on
        each use.
        """
        if node.id == '_':
            node.type = TVar()
            return
        if node.id not in self.assignments:
            self.assignments[node.id] = TVar()
        ty = self.assignments[node.id]
        if isinstance(ty, TFunc):
            ty = ty.instantiate()
        node.type = ty

    def visit_Lambda(self, node):
        """Type a lambda as ``TFunc(argument types, body type, set())``.

        Arguments and body are visited against a shallow copy of
        ``self.assignments`` so names bound by the lambda do not leak into
        the enclosing scope.
        """
        outer_scope = self.assignments
        self.assignments = outer_scope.copy()
        try:
            self.visit(node.args)
            self.visit(node.body)
        finally:
            # Restore the enclosing scope even when typing the body fails, so
            # the visitor is not left holding the lambda's bindings.
            self.assignments = outer_scope
        node.type = TFunc(node.args.type, node.body.type, set())

    def visit_arguments(self, node):
        """Bind each entry of ``node.args`` to a fresh ``TVar``.

        ``node.type`` becomes a list of ``FuncArg(name, type, False)``. Only
        ``node.args`` is processed; other argument kinds (varargs,
        keyword-only, defaults) are not looked at here.
        """
        for arg in node.args:
            self.assignments[arg.arg] = TVar()
            arg.type = self.assignments[arg.arg]
        node.type = [FuncArg(arg.arg, arg.type, False) for arg in node.args]

    def visit_Constant(self, node):
        """Type ``bool`` constants as ``TBool`` and ``int`` ones as ``TInt``.

        Constants of any other Python type are left without a ``type``
        attribute.
        """
        # bool must be tested first: bool is a subclass of int in Python.
        if isinstance(node.value, bool):
            node.type = TBool
        elif isinstance(node.value, int):
            node.type = TInt

    def visit_Call(self, node):
        """Type a call expression.

        The callee, the positional arguments and the keyword values are
        visited; the call result gets a fresh ``TVar``; the callee type is
        unified with a ``TCall`` built from the argument types and that
        result, and is then appended to ``self.calls``.
        """
        self.visit(node.func)
        for arg in node.args:
            self.visit(arg)
        for keyword in node.keywords:
            self.visit(keyword.value)
        node.type = TVar()
        call = TCall([arg.type for arg in node.args],
                     {keyword.arg: keyword.value.type
                      for keyword in node.keywords},
                     node.type)
        node.func.type.unify(call)
        self.calls.append(node.func.type)

    def visit_Attribute(self, node):
        """Type ``value.attr`` with a fresh ``TVar``.

        A second ``TVar`` marked ``TVarType.RECORD`` is given a field named
        ``node.attr`` tied to the result type, and the type of ``value`` is
        unified with it.
        """
        self.visit(node.value)
        node.type = TVar()
        v = TVar()
        v.type = TVarType.RECORD
        if node.attr in v.fields:
            v.fields[node.attr].unify(node.type)
        else:
            v.fields[node.attr] = node.type
        node.value.type.unify(v)

    def visit_Tuple(self, node):
        """Type a tuple display as ``TTuple`` of its element types."""
        for e in node.elts:
            self.visit(e)
        node.type = TTuple([e.type for e in node.elts])

    def visit_List(self, node):
        """Type a list display as ``TList`` of one shared element type.

        Every element type is unified with the same fresh ``TVar``.
        """
        ty = TVar()
        for e in node.elts:
            self.visit(e)
            ty.unify(e.type)
        node.type = TList(ty)

    @classmethod
    def _is_extended_slice(cls, slice_node):
        """Return True for a multi-dimensional slice such as ``a[1:2, 3]``."""
        if isinstance(slice_node, cls._LEGACY_EXT_SLICE):
            return True
        # Python 3.9+ represents an extended slice as a Tuple holding Slices.
        return (isinstance(slice_node, ast.Tuple)
                and any(isinstance(e, ast.Slice) for e in slice_node.elts))

    @classmethod
    def _index_expr(cls, slice_node):
        """Return the index expression of a non-slice subscript.

        Before Python 3.9 the expression is wrapped in ``ast.Index``; from
        3.9 on ``Subscript.slice`` is the expression itself.
        """
        if isinstance(slice_node, cls._LEGACY_INDEX):
            return slice_node.value
        return slice_node

    def visit_Subscript(self, node):
        """Type ``value[slice]``.

        * Simple slice: the result has the type of ``value``, which must be a
          ``TList`` or a ``TVar`` (then constrained with ``TVarType.LIST``).
        * Extended slice: not supported.
        * Index: the index is unified with ``TInt`` and the result is a fresh
          ``TVar`` tied to the element type of ``value``.

        Raises:
            NotImplementedError: for extended (multi-dimensional) slices.
            UnificationError: when ``value`` cannot be subscripted this way
                (sliced non-list, assignment into a tuple, non-constant or
                out-of-range tuple index, unsupported type).
        """
        self.visit(node.value)

        if isinstance(node.slice, ast.Slice):
            # NOTE(review): lower/upper/step are neither visited nor typed.
            node.type = node.value.type
            if isinstance(node.type, TVar):
                node.type.type = node.type.type.unifier(TVarType.LIST)
            elif not isinstance(node.type, TList):
                raise UnificationError(f'{node.type} should be a list')
            return

        if self._is_extended_slice(node.slice):
            raise NotImplementedError()

        # Lists and tuples need different handling from here on.
        index_node = self._index_expr(node.slice)
        self.visit(index_node)
        index_node.type.unify(TInt)
        ty = node.value.type
        node.type = TVar()
        is_store = isinstance(node.ctx, self._STORE_CONTEXTS)
        constant_index = isinstance(index_node, ast.Constant)

        if isinstance(ty, TVar):
            # Non-constant indices all share field 0.
            index = 0
            if constant_index:
                # A constant index being written to constrains the variable
                # with LIST; a constant index being read uses SEQUENCE.
                seq_ty = TVarType.LIST if is_store else TVarType.SEQUENCE
                ty.type = ty.type.unifier(seq_ty)
                index = index_node.value
            else:
                ty.type = ty.type.unifier(TVarType.LIST)
            if index in ty.fields:
                ty.fields[index].unify(node.type)
            else:
                ty.fields[index] = node.type
        elif isinstance(ty, TList):
            ty.param.unify(node.type)
        elif isinstance(ty, TTuple):
            if is_store:
                raise UnificationError('Cannot assign to tuple')
            if not constant_index:
                raise UnificationError('Tuple index must be a constant')
            index = index_node.value
            if index >= len(ty.params):
                raise UnificationError('Index out of range for tuple')
            ty.params[index].unify(node.type)
        else:
            raise UnificationError(f'Cannot use subscript for {ty}')

    def visit_For(self, node):
        """Type a ``for`` loop; only ``TList`` or ``TVar`` iterables work.

        The loop target is tied to the element type of the iterable, then the
        statements of the body and the ``else`` clause are visited.

        Raises:
            UnificationError: when the iterable is neither a ``TVar`` nor a
                ``TList``.
        """
        self.visit(node.target)
        self.visit(node.iter)
        ty = node.iter.type
        if isinstance(ty, TVar):
            # Only iteration over lists is supported for now.
            ty.type = ty.type.unifier(TVarType.LIST)
            if 0 in ty.fields:
                ty.fields[0].unify(node.target.type)
            else:
                ty.fields[0] = node.target.type
        elif isinstance(ty, TList):
            ty.param.unify(node.target.type)
        else:
            raise UnificationError(f'Cannot iterate over {ty}')
        for stmt in chain(node.body, node.orelse):
            self.visit(stmt)

    def _visit_conditional(self, node):
        """Unify ``node.test`` with ``TBool``, then visit body and orelse."""
        self.visit(node.test)
        node.test.type.unify(TBool)
        for stmt in chain(node.body, node.orelse):
            self.visit(stmt)

    def visit_If(self, node):
        """Type an ``if`` statement: boolean test, then both branches."""
        self._visit_conditional(node)

    def visit_While(self, node):
        """Type a ``while`` loop: boolean test, then body and ``else``."""
        self._visit_conditional(node)