import ast
from itertools import chain
from nac3_types import *

# Python 3.8 wraps a plain subscript index in ast.Index and a multi-dimensional
# slice in ast.ExtSlice; Python 3.9+ drops both wrappers (the index expression
# is stored directly on Subscript.slice and extended slices become an
# ast.Tuple). The deprecated classes may be absent on newer interpreters, so
# they are looked up defensively; isinstance(x, ()) is always False.
_AST_INDEX = getattr(ast, 'Index', ())
_AST_EXT_SLICE = getattr(ast, 'ExtSlice', ())
# Expression contexts that mark a subscript as an assignment target.
# ast.AugStore is a deprecated class that the parser does not emit, but it is
# still honoured when it exists (e.g. for hand-built trees).
_STORE_CTXS = tuple(
    getattr(ast, name) for name in ('Store', 'AugStore') if hasattr(ast, name))


def _subscript_index(node):
    """Return the index expression of Subscript ``node``.

    Unwraps the Python 3.8 ``ast.Index`` node when present so that callers
    see the same shape on every supported Python version.
    """
    index = node.slice
    if isinstance(index, _AST_INDEX):
        return index.value
    return index


def _is_ext_slice(index):
    """Return True if ``index`` is a multi-dimensional slice like ``a[1:2, 3]``."""
    if isinstance(index, _AST_EXT_SLICE):
        return True
    # Python 3.9+ represents extended slices as a Tuple containing Slice nodes.
    return isinstance(index, ast.Tuple) and any(
        isinstance(elt, ast.Slice) for elt in index.elts)


def _add_field(tvar, key, ty):
    """Record ``ty`` under ``key`` in ``tvar.fields``.

    If the key is already present, the existing field type is unified with
    ``ty`` instead of being replaced.
    """
    if key in tvar.fields:
        tvar.fields[key].unify(ty)
    else:
        tvar.fields[key] = ty


class Visitor(ast.NodeVisitor):
    """Infer types for a Python AST by unification.

    Each visited expression node receives a ``type`` attribute holding a
    nac3_types type, and relationships between nodes are expressed by calling
    ``unify`` on those types. Errors are raised as ``UnificationError``
    (directly here, and presumably from ``unify``) or ``NotImplementedError``
    for unsupported multi-dimensional slices.

    Attributes:
        source: Source text the AST was parsed from; used with
            ``ast.get_source_segment`` to recover sub-expression text.
        assignments: Mapping from variable name to its type. The dict passed
            to the constructor is mutated in place (lambda bodies temporarily
            work on a shallow copy).
        calls: Every ``TCall`` built by ``visit_Call``, in visit order.
        virtuals: ``(object_type, declared_type)`` pairs, one per
            ``virtual(obj, T)`` call.
        type_parser: Callable that turns the source text of the second
            argument of ``virtual(...)`` into a type.
    """

    def __init__(self, src, assignments, type_parser):
        super().__init__()
        self.source = src
        self.assignments = assignments
        self.calls = []
        self.virtuals = []
        self.type_parser = type_parser

    def visit_Assign(self, node):
        """Unify the type of every assignment target with the value's type."""
        self.visit(node.value)
        for target in node.targets:
            self.visit(target)
            target.type.unify(node.value.type)

    def visit_Name(self, node):
        """Type a variable reference.

        ``_`` always gets a fresh type variable that is not recorded. Any
        other name is looked up in ``assignments``; a fresh type variable is
        recorded the first time a name is seen. TFunc-typed names are passed
        through ``instantiate()`` at every use.
        """
        if node.id == '_':
            node.type = TVar()
            return
        if node.id not in self.assignments:
            self.assignments[node.id] = TVar()
        ty = self.assignments[node.id]
        if isinstance(ty, TFunc):
            ty = ty.instantiate()
        node.type = ty

    def visit_Lambda(self, node):
        """Type a lambda as ``TFunc(parameter types, body type, set())``.

        Parameters are bound in a shallow copy of ``assignments`` so they do
        not leak into the enclosing scope. The outer mapping is restored even
        when inference inside the lambda raises.
        """
        outer = self.assignments
        self.assignments = outer.copy()
        try:
            self.visit(node.args)
            self.visit(node.body)
        finally:
            self.assignments = outer
        node.type = TFunc(node.args.type, node.body.type, set())

    def visit_arguments(self, node):
        """Bind each parameter in ``node.args`` to a fresh type variable.

        Only ``node.args`` is processed. ``node.type`` becomes the list of
        ``FuncArg`` entries for those parameters.
        """
        for arg in node.args:
            self.assignments[arg.arg] = TVar()
            arg.type = self.assignments[arg.arg]
        node.type = [FuncArg(arg.arg, arg.type, False) for arg in node.args]

    def visit_Constant(self, node):
        """Type bool and int literals.

        ``bool`` is tested first because it is a subclass of ``int``. Other
        constants (str, float, None, ...) are left without a ``type``.
        """
        if isinstance(node.value, bool):
            node.type = TBool
        elif isinstance(node.value, int):
            node.type = TInt

    def visit_Call(self, node):
        """Type a call expression.

        ``virtual(obj, T)`` is special-cased. ``T`` is parsed from its source
        text by ``type_parser``, the pair is appended to ``virtuals``, and the
        call's type is ``TVirtual(T)``.

        Any other call gets a fresh result type variable. The callee's type
        is then unified with a ``TCall`` built from the positional and keyword
        argument types, and that ``TCall`` is appended to ``calls``.

        Raises:
            UnificationError: if ``virtual`` is not given exactly two
                positional arguments.
        """
        if ast.get_source_segment(self.source, node.func) == 'virtual':
            if len(node.args) != 2:
                raise UnificationError('Incorrect argument number for virtual')
            obj, type_expr = node.args
            self.visit(obj)
            ty = self.type_parser(ast.get_source_segment(self.source, type_expr))
            self.virtuals.append((obj.type, ty))
            node.type = TVirtual(ty)
            return
        self.visit(node.func)
        for arg in node.args:
            self.visit(arg)
        for keyword in node.keywords:
            self.visit(keyword.value)
        node.type = TVar()
        call = TCall([arg.type for arg in node.args],
                     {keyword.arg: keyword.value.type
                      for keyword in node.keywords},
                     node.type)
        node.func.type.unify(call)
        self.calls.append(call)

    def visit_Attribute(self, node):
        """Type ``value.attr``.

        ``value`` must unify with a record type variable whose field ``attr``
        has the attribute expression's type.
        """
        self.visit(node.value)
        node.type = TVar()
        record = TVar()
        record.type = TVarType.RECORD
        _add_field(record, node.attr, node.type)
        node.value.type.unify(record)

    def visit_Tuple(self, node):
        """Type a tuple display as the TTuple of its element types."""
        for e in node.elts:
            self.visit(e)
        node.type = TTuple([e.type for e in node.elts])

    def visit_List(self, node):
        """Type a list display as TList of a single unified element type."""
        ty = TVar()
        for e in node.elts:
            self.visit(e)
            ty.unify(e.type)
        node.type = TList(ty)

    def visit_Subscript(self, node):
        """Type ``value[index]`` and ``value[lo:hi]``.

        Works with both the Python 3.8 and the Python 3.9+ AST layout of
        ``Subscript.slice``.

        A simple slice keeps the value's type, which must be a list. The
        slice bounds are not visited.

        A single index must unify with int. Handling then depends on the
        value's type:

        * type variable: it is constrained to a sequence when the index is a
          constant and the subscript is read, and to a list otherwise. The
          element type is recorded in its ``fields`` under the constant
          index, or under 0 for a non-constant index.
        * TList: the element type is the list's parameter.
        * TTuple: the subscript must not be an assignment target, and the
          index must be a constant within range.

        Raises:
            UnificationError: for the type errors described above.
            NotImplementedError: for multi-dimensional slices.
        """
        self.visit(node.value)
        index = _subscript_index(node)

        if isinstance(index, ast.Slice):
            node.type = node.value.type
            if isinstance(node.type, TVar):
                node.type.type = node.type.type.unifier(TVarType.LIST)
            elif not isinstance(node.type, TList):
                raise UnificationError(f'{node.type} should be a list')
            return
        if _is_ext_slice(index):
            raise NotImplementedError()

        # Single index: lists and tuples need different handling, so
        # dispatch on the type of the subscripted value.
        self.visit(index)
        index.type.unify(TInt)
        ty = node.value.type
        node.type = TVar()
        is_store = isinstance(node.ctx, _STORE_CTXS)
        is_const = isinstance(index, ast.Constant)

        if isinstance(ty, TVar):
            if is_const:
                seq_ty = TVarType.LIST if is_store else TVarType.SEQUENCE
                ty.type = ty.type.unifier(seq_ty)
                key = index.value
            else:
                ty.type = ty.type.unifier(TVarType.LIST)
                key = 0
            _add_field(ty, key, node.type)
        elif isinstance(ty, TList):
            ty.param.unify(node.type)
        elif isinstance(ty, TTuple):
            if is_store:
                raise UnificationError('Cannot assign to tuple')
            if not is_const:
                raise UnificationError('Tuple index must be a constant')
            if index.value >= len(ty.params):
                raise UnificationError('Index out of range for tuple')
            ty.params[index.value].unify(node.type)
        else:
            raise UnificationError(f'Cannot use subscript for {ty}')

    def visit_For(self, node):
        """Type ``for target in iter`` and then visit the body and else clause.

        Only iteration over lists is supported. The iterable must be a TList
        (or a type variable, which is constrained to a list), and its element
        type is unified with the loop target's type.

        Raises:
            UnificationError: if the iterable is neither a type variable nor
                a TList.
        """
        self.visit(node.target)
        self.visit(node.iter)
        ty = node.iter.type
        if isinstance(ty, TVar):
            # we currently only support iterator over lists
            ty.type = ty.type.unifier(TVarType.LIST)
            _add_field(ty, 0, node.target.type)
        elif isinstance(ty, TList):
            ty.param.unify(node.target.type)
        else:
            raise UnificationError(f'Cannot iterate over {ty}')
        self._visit_body_and_orelse(node)

    def visit_If(self, node):
        """Require the condition to be bool, then visit both branches."""
        self._visit_conditional(node)

    def visit_While(self, node):
        """Require the condition to be bool, then visit body and else clause."""
        self._visit_conditional(node)

    def _visit_conditional(self, node):
        # Shared by If and While: the test must unify with bool.
        self.visit(node.test)
        node.test.type.unify(TBool)
        self._visit_body_and_orelse(node)

    def _visit_body_and_orelse(self, node):
        # Visit the statements of ``body`` followed by those of ``orelse``.
        for stmt in chain(node.body, node.orelse):
            self.visit(stmt)