import ast
from itertools import chain

from nac3_types import *
from primitives import *

# Dunder method invoked for each supported binary / comparison operator.
_MAGIC_METHODS = {
    ast.Add: '__add__',
    ast.Sub: '__sub__',
    ast.Mult: '__mul__',
    ast.Div: '__truediv__',
    ast.FloorDiv: '__floordiv__',
    ast.Eq: '__eq__',
    ast.NotEq: '__ne__',
    ast.Lt: '__lt__',
    ast.LtE: '__le__',
    ast.Gt: '__gt__',
    ast.GtE: '__ge__',
}

# Python < 3.9 wraps subscript indices in ``ast.Index``, represents
# multi-dimensional slices as ``ast.ExtSlice`` and defines ``ast.AugStore``.
# Newer versions deprecate these classes, so look them up defensively.
_AST_INDEX = getattr(ast, 'Index', None)
_AST_EXT_SLICE = getattr(ast, 'ExtSlice', None)
_STORE_CTXS = tuple(
    ctx for ctx in (ast.Store, getattr(ast, 'AugStore', None)) if ctx is not None
)


def get_magic_method(op):
    """Return the dunder method name implementing the AST operator ``op``.

    Args:
        op: an ``ast`` operator node (e.g. ``ast.Add()``, ``ast.Lt()``).

    Returns:
        The method name, e.g. ``'__add__'`` for ``ast.Add``.

    Raises:
        NotImplementedError: if ``op`` is not a supported operator.
    """
    for op_type, method in _MAGIC_METHODS.items():
        if isinstance(op, op_type):
            return method
    raise NotImplementedError(f'Unsupported operator: {type(op).__name__}')


def _subscript_index(slice_node):
    """Return the index expression of a non-slice subscript.

    Python < 3.9 wraps it in ``ast.Index``; newer versions store it directly.
    """
    if _AST_INDEX is not None and isinstance(slice_node, _AST_INDEX):
        return slice_node.value
    return slice_node


def _is_ext_slice(slice_node):
    """Return True for multi-dimensional slices such as ``a[1:2, 3]``."""
    if _AST_EXT_SLICE is not None and isinstance(slice_node, _AST_EXT_SLICE):
        return True
    # Python >= 3.9 represents these as a Tuple containing at least one Slice.
    return isinstance(slice_node, ast.Tuple) and any(
        isinstance(elt, ast.Slice) for elt in slice_node.elts
    )


class Visitor(ast.NodeVisitor):
    """Infer types for a Python AST by building and unifying constraints.

    Each visited expression node gets a ``.type`` attribute holding its
    (possibly not yet resolved) type.

    Attributes:
        source: the source text, used with ``ast.get_source_segment``.
        assignments: maps variable names to their types; unknown names are
            added on first use with a fresh ``TVar``.
        calls: every ``TCall`` constraint created, both for explicit calls
            and for operators lowered to dunder-method calls.
        virtuals: ``(object_type, declared_type)`` pairs collected from
            ``virtual(obj[, type])`` calls.
        type_parser: callable that turns the source text of a type
            expression into a type (used for ``virtual``'s second argument).
    """

    def __init__(self, src, assignments, type_parser):
        super().__init__()
        self.source = src
        self.assignments = assignments
        self.calls = []
        self.virtuals = []
        self.type_parser = type_parser

    def visit_Assign(self, node):
        """Unify every assignment target with the assigned value's type."""
        self.visit(node.value)
        for target in node.targets:
            self.visit(target)
            target.type.unify(node.value.type)

    def visit_Name(self, node):
        """Look up (or create) the type of a variable reference."""
        if node.id == '_':
            # ``_`` is a wildcard: every occurrence gets an independent TVar.
            node.type = TVar()
            return
        if node.id not in self.assignments:
            self.assignments[node.id] = TVar()
        ty = self.assignments[node.id]
        if isinstance(ty, TFunc):
            # Function types are instantiated at each use site, presumably
            # so each use gets fresh type variables -- confirm in nac3_types.
            ty = ty.instantiate()
        node.type = ty

    def visit_Lambda(self, node):
        """Type a lambda; its parameters are scoped to the lambda body."""
        outer = self.assignments
        self.assignments = outer.copy()
        self.visit(node.args)
        self.visit(node.body)
        self.assignments = outer
        node.type = TFunc(node.args.type, node.body.type, [])

    def visit_arguments(self, node):
        """Give each positional parameter a fresh TVar.

        Only ``node.args`` is handled; positional-only, keyword-only,
        ``*args``/``**kwargs`` parameters and defaults are ignored.
        """
        for arg in node.args:
            self.assignments[arg.arg] = TVar()
            arg.type = self.assignments[arg.arg]
        # NOTE(review): the third FuncArg field is always False here;
        # presumably "has default value" -- confirm in nac3_types.
        node.type = [FuncArg(arg.arg, arg.type, False) for arg in node.args]

    def visit_Constant(self, node):
        """Type bool/int/float literals; other constants get no ``.type``."""
        # bool must be tested first because bool is a subclass of int.
        if isinstance(node.value, bool):
            node.type = TBool
        elif isinstance(node.value, int):
            node.type = TInt
        elif isinstance(node.value, float):
            node.type = TFloat

    def visit_Call(self, node):
        """Type a call expression.

        ``virtual(obj)`` / ``virtual(obj, type)`` is special-cased: it records
        ``(type_of_obj, declared_type)`` in ``self.virtuals`` and yields
        ``TVirtual(declared_type)``. Any other call unifies the callee's type
        with a ``TCall`` built from the argument types and records that call.
        """
        if ast.get_source_segment(self.source, node.func) == 'virtual':
            if len(node.args) > 2 or len(node.args) < 1:
                raise UnificationError('Incorrect argument number for virtual')
            self.visit(node.args[0])
            if len(node.args) == 2:
                ty = self.type_parser(
                    ast.get_source_segment(self.source, node.args[1]))
            else:
                ty = TVar()
            self.virtuals.append((node.args[0].type, ty))
            node.type = TVirtual(ty)
            return

        self.visit(node.func)
        for arg in node.args:
            self.visit(arg)
        for keyword in node.keywords:
            self.visit(keyword.value)
        node.type = TVar()
        call = TCall(
            [arg.type for arg in node.args],
            {keyword.arg: keyword.value.type for keyword in node.keywords},
            node.type,
        )
        node.func.type.unify(call)
        self.calls.append(call)

    def visit_Attribute(self, node):
        """Require the object to be a record with a field named ``attr``."""
        self.visit(node.value)
        node.type = TVar()
        record = TVar()
        record.type = TVarType.RECORD
        record.fields[node.attr] = node.type
        node.value.type.unify(record)

    def visit_Tuple(self, node):
        """A tuple literal's type is the tuple of its element types."""
        for elt in node.elts:
            self.visit(elt)
        node.type = TTuple([elt.type for elt in node.elts])

    def visit_List(self, node):
        """All elements of a list literal must unify to one element type."""
        elem_ty = TVar()
        for elt in node.elts:
            self.visit(elt)
            elem_ty.unify(elt.type)
        node.type = TList(elem_ty)

    def visit_Subscript(self, node):
        """Type ``value[slice]`` for lists, tuples and not-yet-known types.

        * ``a[i:j]`` has the same type as ``a``, which must be a list.
        * ``a[i]`` requires an int index. For a TVar, a constant index
          constrains it to a sequence (a list when assigning) and records the
          element type under that index; a non-constant index forces a list
          and records the element type under index 0. Tuples need a constant,
          in-range index and cannot be assigned to.

        Raises:
            NotImplementedError: for multi-dimensional slices.
            UnificationError: for invalid subscripts.
        """
        self.visit(node.value)
        if isinstance(node.slice, ast.Slice):
            node.type = node.value.type
            if isinstance(node.type, TVar):
                node.type.type = node.type.type.unifier(TVarType.LIST)
            elif not isinstance(node.type, TList):
                raise UnificationError(f'{node.type} should be a list')
            return
        if _is_ext_slice(node.slice):
            raise NotImplementedError()

        # Complicated because lists and tuples must be handled differently.
        index_node = _subscript_index(node.slice)
        self.visit(index_node)
        index_node.type.unify(TInt)
        ty = node.value.type
        node.type = TVar()
        is_store = isinstance(node.ctx, _STORE_CTXS)

        if isinstance(ty, TVar):
            index = 0
            if isinstance(index_node, ast.Constant):
                seq_ty = TVarType.LIST if is_store else TVarType.SEQUENCE
                ty.type = ty.type.unifier(seq_ty)
                index = index_node.value
            else:
                ty.type = ty.type.unifier(TVarType.LIST)
            if index in ty.fields:
                ty.fields[index].unify(node.type)
            else:
                ty.fields[index] = node.type
        elif isinstance(ty, TList):
            ty.param.unify(node.type)
        elif isinstance(ty, TTuple):
            if is_store:
                raise UnificationError('Cannot assign to tuple')
            if not isinstance(index_node, ast.Constant):
                raise UnificationError('Tuple index must be a constant')
            index = index_node.value
            if index >= len(ty.params):
                raise UnificationError('Index out of range for tuple')
            ty.params[index].unify(node.type)
        else:
            raise UnificationError(f'Cannot use subscript for {ty}')

    def visit_For(self, node):
        """Unify the loop target with the element type of the iterable."""
        self.visit(node.target)
        self.visit(node.iter)
        ty = node.iter.type
        if isinstance(ty, TVar):
            # Only iteration over lists is currently supported; the element
            # type of an unresolved list TVar is kept under field 0.
            ty.type = ty.type.unifier(TVarType.LIST)
            if 0 in ty.fields:
                ty.fields[0].unify(node.target.type)
            else:
                ty.fields[0] = node.target.type
        elif isinstance(ty, TList):
            ty.param.unify(node.target.type)
        else:
            raise UnificationError(f'Cannot iterate over {ty}')
        for stmt in chain(node.body, node.orelse):
            self.visit(stmt)

    def visit_If(self, node):
        """The condition must be a bool; then visit both branches."""
        self.visit(node.test)
        node.test.type.unify(TBool)
        for stmt in chain(node.body, node.orelse):
            self.visit(stmt)

    def visit_While(self, node):
        """The condition must be a bool; then visit body and else-branch."""
        self.visit(node.test)
        node.test.type.unify(TBool)
        for stmt in chain(node.body, node.orelse):
            self.visit(stmt)

    def visit_BoolOp(self, node):
        """``and``/``or``: every operand must be a bool, result is bool.

        ``a and b and c`` is a single BoolOp with three values, so all
        operands (not just the first two) are visited and checked.
        """
        for value in node.values:
            self.visit(value)
        for value in node.values:
            value.type.unify(TBool)
        node.type = TBool

    def visit_BinOp(self, node):
        """Lower ``left <op> right`` to a call of ``left.__op__(right)``."""
        self.visit(node.left)
        self.visit(node.right)
        method = get_magic_method(node.op)
        ret = TVar()
        node.type = ret
        call = TCall([node.right.type], {}, ret)
        self.calls.append(call)
        record = TVar()
        record.type = TVarType.RECORD
        record.fields[method] = call
        node.left.type.unify(record)

    def visit_Compare(self, node):
        """Type a (possibly chained) comparison; the result is bool.

        ``a < b < c`` is split into the pairs ``(a, b)`` and ``(b, c)``; each
        pair becomes a call ``lhs.__op__(rhs)`` that must return bool.
        """
        self.visit(node.left)
        for comparator in node.comparators:
            self.visit(comparator)
        lhs_nodes = chain([node.left], node.comparators[:-1])
        for lhs, rhs, op in zip(lhs_nodes, node.comparators, node.ops):
            method = get_magic_method(op)
            call = TCall([rhs.type], {}, TBool)
            self.calls.append(call)
            record = TVar()
            record.type = TVarType.RECORD
            record.fields[method] = call
            lhs.type.unify(record)
        node.type = TBool