""" :class:`Inferencer` performs unification-based inference on a typedtree. """ from collections import OrderedDict from pythonparser import algorithm, diagnostic, ast from .. import asttyped, types, builtins class Inferencer(algorithm.Visitor): """ :class:`Inferencer` infers types by recursively applying the unification algorithm. It does not treat inability to infer a concrete type as an error; the result can still contain type variables. :class:`Inferencer` is idempotent, but does not guarantee that it will perform all possible inference in a single pass. """ def __init__(self, engine): self.engine = engine self.function = None # currently visited function, for Return inference self.in_loop = False self.has_return = False def _unify(self, typea, typeb, loca, locb, makenotes=None, when=""): try: typea.unify(typeb) except types.UnificationError as e: printer = types.TypePrinter() if makenotes: notes = makenotes(printer, typea, typeb, loca, locb) else: notes = [ diagnostic.Diagnostic("note", "expression of type {typea}", {"typea": printer.name(typea)}, loca) ] if locb: notes.append( diagnostic.Diagnostic("note", "expression of type {typeb}", {"typeb": printer.name(typeb)}, locb)) highlights = [locb] if locb else [] if e.typea.find() == typea.find() and e.typeb.find() == typeb.find() or \ e.typeb.find() == typea.find() and e.typea.find() == typeb.find(): diag = diagnostic.Diagnostic("error", "cannot unify {typea} with {typeb}{when}", {"typea": printer.name(typea), "typeb": printer.name(typeb), "when": when}, loca, highlights, notes) else: # give more detail diag = diagnostic.Diagnostic("error", "cannot unify {typea} with {typeb}{when}: {fraga} is incompatible with {fragb}", {"typea": printer.name(typea), "typeb": printer.name(typeb), "fraga": printer.name(e.typea), "fragb": printer.name(e.typeb), "when": when}, loca, highlights, notes) self.engine.process(diag) # makenotes for the case where types of multiple elements are unified # with the type of parent expression def _makenotes_elts(self, elts, kind): def makenotes(printer, typea, typeb, loca, locb): return [ diagnostic.Diagnostic("note", "{kind} of type {typea}", {"kind": kind, "typea": printer.name(elts[0].type)}, elts[0].loc), diagnostic.Diagnostic("note", "{kind} of type {typeb}", {"kind": kind, "typeb": printer.name(typeb)}, locb) ] return makenotes def visit_ListT(self, node): self.generic_visit(node) elt_type_loc = node.loc for elt in node.elts: self._unify(node.type["elt"], elt.type, elt_type_loc, elt.loc, self._makenotes_elts(node.elts, "a list element")) elt_type_loc = elt.loc def visit_AttributeT(self, node): self.generic_visit(node) self._unify_attribute(result_type=node.type, value_node=node.value, attr_name=node.attr, attr_loc=node.attr_loc, loc=node.loc) def _unify_method_self(self, method_type, attr_name, attr_loc, loc, self_loc): self_type = types.get_method_self(method_type) function_type = types.get_method_function(method_type) if len(function_type.args) < 1: diag = diagnostic.Diagnostic("error", "function '{attr}{type}' of class '{class}' cannot accept a self argument", {"attr": attr_name, "type": types.TypePrinter().name(function_type), "class": self_type.name}, loc) self.engine.process(diag) else: def makenotes(printer, typea, typeb, loca, locb): if attr_loc is None: msgb = "reference to an instance with a method '{attr}{typeb}'" else: msgb = "reference to a method '{attr}{typeb}'" return [ diagnostic.Diagnostic("note", "expression of type {typea}", {"typea": printer.name(typea)}, loca), diagnostic.Diagnostic("note", msgb, {"attr": attr_name, "typeb": printer.name(function_type)}, locb) ] self._unify(self_type, list(function_type.args.values())[0], self_loc, loc, makenotes=makenotes, when=" while inferring the type for self argument") def _unify_attribute(self, result_type, value_node, attr_name, attr_loc, loc): object_type = value_node.type.find() if not types.is_var(object_type): if attr_name in object_type.attributes: def makenotes(printer, typea, typeb, loca, locb): return [ diagnostic.Diagnostic("note", "expression of type {typea}", {"typea": printer.name(typea)}, loca), diagnostic.Diagnostic("note", "expression of type {typeb}", {"typeb": printer.name(object_type)}, value_node.loc) ] attr_type = object_type.attributes[attr_name] self._unify(result_type, attr_type, loc, None, makenotes=makenotes, when=" for attribute '{}'".format(attr_name)) elif types.is_instance(object_type) and \ attr_name in object_type.constructor.attributes: attr_type = object_type.constructor.attributes[attr_name].find() if types.is_function(attr_type): # Convert to a method. attr_type = types.TMethod(object_type, attr_type) self._unify_method_self(attr_type, attr_name, attr_loc, loc, value_node.loc) elif types.is_rpc(attr_type): # Convert to a method. We don't have to bother typechecking # the self argument, since for RPCs anything goes. attr_type = types.TMethod(object_type, attr_type) if not types.is_var(attr_type): self._unify(result_type, attr_type, loc, None) else: if attr_loc.source_buffer == value_node.loc.source_buffer: highlights, notes = [value_node.loc], [] else: # This happens when the object being accessed is embedded # from the host program. note = diagnostic.Diagnostic("note", "object being accessed", {}, value_node.loc) highlights, notes = [], [note] diag = diagnostic.Diagnostic("error", "type {type} does not have an attribute '{attr}'", {"type": types.TypePrinter().name(object_type), "attr": attr_name}, attr_loc, highlights, notes) self.engine.process(diag) def _unify_iterable(self, element, collection): if builtins.is_iterable(collection.type): rhs_type = collection.type.find() rhs_wrapped_lhs_type = types.TMono(rhs_type.name, {"elt": element.type}) self._unify(rhs_wrapped_lhs_type, rhs_type, element.loc, collection.loc) elif not types.is_var(collection.type): diag = diagnostic.Diagnostic("error", "type {type} is not iterable", {"type": types.TypePrinter().name(collection.type)}, collection.loc, []) self.engine.process(diag) def visit_Index(self, node): self.generic_visit(node) value = node.value if types.is_tuple(value.type): diag = diagnostic.Diagnostic("error", "multi-dimensional slices are not supported", {}, node.loc, []) self.engine.process(diag) else: self._unify(value.type, builtins.TInt(), value.loc, None) def visit_SliceT(self, node): if (node.lower, node.upper, node.step) == (None, None, None): self._unify(node.type, builtins.TInt32(), node.loc, None) else: self._unify(node.type, builtins.TInt(), node.loc, None) for operand in (node.lower, node.upper, node.step): if operand is not None: self._unify(operand.type, node.type, operand.loc, None) def visit_ExtSlice(self, node): diag = diagnostic.Diagnostic("error", "multi-dimensional slices are not supported", {}, node.loc, []) self.engine.process(diag) def visit_SubscriptT(self, node): self.generic_visit(node) if isinstance(node.slice, ast.Index): self._unify_iterable(element=node, collection=node.value) elif isinstance(node.slice, ast.Slice): self._unify(node.type, node.value.type, node.loc, node.value.loc) else: # ExtSlice pass # error emitted above def visit_IfExpT(self, node): self.generic_visit(node) self._unify(node.body.type, node.orelse.type, node.body.loc, node.orelse.loc) self._unify(node.type, node.body.type, node.loc, None) def visit_BoolOpT(self, node): self.generic_visit(node) for value in node.values: self._unify(node.type, value.type, node.loc, value.loc, self._makenotes_elts(node.values, "an operand")) def visit_UnaryOpT(self, node): self.generic_visit(node) operand_type = node.operand.type.find() if isinstance(node.op, ast.Not): self._unify(node.type, builtins.TBool(), node.loc, None) elif isinstance(node.op, ast.Invert): if builtins.is_int(operand_type): self._unify(node.type, operand_type, node.loc, None) elif not types.is_var(operand_type): diag = diagnostic.Diagnostic("error", "expected '~' operand to be of integer type, not {type}", {"type": types.TypePrinter().name(operand_type)}, node.operand.loc) self.engine.process(diag) else: # UAdd, USub if builtins.is_numeric(operand_type): self._unify(node.type, operand_type, node.loc, None) elif not types.is_var(operand_type): diag = diagnostic.Diagnostic("error", "expected unary '{op}' operand to be of numeric type, not {type}", {"op": node.op.loc.source(), "type": types.TypePrinter().name(operand_type)}, node.operand.loc) self.engine.process(diag) def visit_CoerceT(self, node): self.generic_visit(node) if builtins.is_numeric(node.type) and builtins.is_numeric(node.value.type): pass else: printer = types.TypePrinter() note = diagnostic.Diagnostic("note", "expression that required coercion to {typeb}", {"typeb": printer.name(node.type)}, node.other_value.loc) diag = diagnostic.Diagnostic("error", "cannot coerce {typea} to {typeb}", {"typea": printer.name(node.value.type), "typeb": printer.name(node.type)}, node.loc, notes=[note]) self.engine.process(diag) def _coerce_one(self, typ, coerced_node, other_node): if coerced_node.type.find() == typ.find(): return coerced_node elif isinstance(coerced_node, asttyped.CoerceT): node = coerced_node node.type, node.other_value = typ, other_node else: node = asttyped.CoerceT(type=typ, value=coerced_node, other_value=other_node, loc=coerced_node.loc) self.visit(node) return node def _coerce_numeric(self, nodes, map_return=lambda typ: typ): # See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex. node_types = [] for node in nodes: if isinstance(node, asttyped.CoerceT): node_types.append(node.value.type) else: node_types.append(node.type) if any(map(types.is_var, node_types)): # not enough info yet return elif not all(map(builtins.is_numeric, node_types)): err_node = next(filter(lambda node: not builtins.is_numeric(node.type), nodes)) diag = diagnostic.Diagnostic("error", "cannot coerce {type} to a numeric type", {"type": types.TypePrinter().name(err_node.type)}, err_node.loc, []) self.engine.process(diag) return elif any(map(builtins.is_float, node_types)): typ = builtins.TFloat() elif any(map(builtins.is_int, node_types)): widths = list(map(builtins.get_int_width, node_types)) if all(widths): typ = builtins.TInt(types.TValue(max(widths))) else: typ = builtins.TInt() else: assert False return map_return(typ) def _order_by_pred(self, pred, left, right): if pred(left.type): return left, right elif pred(right.type): return right, left else: assert False def _coerce_binop(self, op, left, right): if isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor, ast.LShift, ast.RShift)): # bitwise operators require integers for operand in (left, right): if not types.is_var(operand.type) and not builtins.is_int(operand.type): diag = diagnostic.Diagnostic("error", "expected '{op}' operand to be of integer type, not {type}", {"op": op.loc.source(), "type": types.TypePrinter().name(operand.type)}, op.loc, [operand.loc]) self.engine.process(diag) return return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ)) elif isinstance(op, ast.Add): # add works on numbers and also collections if builtins.is_collection(left.type) or builtins.is_collection(right.type): collection, other = \ self._order_by_pred(builtins.is_collection, left, right) if types.is_tuple(collection.type): pred, kind = types.is_tuple, "tuple" elif builtins.is_list(collection.type): pred, kind = builtins.is_list, "list" else: assert False if types.is_var(other.type): return if not pred(other.type): printer = types.TypePrinter() note1 = diagnostic.Diagnostic("note", "{kind} of type {typea}", {"typea": printer.name(collection.type), "kind": kind}, collection.loc) note2 = diagnostic.Diagnostic("note", "{typeb}, which cannot be added to a {kind}", {"typeb": printer.name(other.type), "kind": kind}, other.loc) diag = diagnostic.Diagnostic("error", "expected every '+' operand to be a {kind} in this context", {"kind": kind}, op.loc, [other.loc, collection.loc], [note1, note2]) self.engine.process(diag) return if types.is_tuple(collection.type): return types.TTuple(left.type.find().elts + right.type.find().elts), left.type, right.type elif builtins.is_list(collection.type): self._unify(left.type, right.type, left.loc, right.loc) return left.type, left.type, right.type else: return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ)) elif isinstance(op, ast.Mult): # mult works on numbers and also number & collection if types.is_tuple(left.type) or types.is_tuple(right.type): tuple_, other = self._order_by_pred(types.is_tuple, left, right) diag = diagnostic.Diagnostic("error", "passing tuples to '*' is not supported", {}, op.loc, [tuple_.loc]) self.engine.process(diag) return elif builtins.is_list(left.type) or builtins.is_list(right.type): list_, other = self._order_by_pred(builtins.is_list, left, right) if not builtins.is_int(other.type): printer = types.TypePrinter() note1 = diagnostic.Diagnostic("note", "list operand of type {typea}", {"typea": printer.name(list_.type)}, list_.loc) note2 = diagnostic.Diagnostic("note", "operand of type {typeb}, which is not a valid repetition amount", {"typeb": printer.name(other.type)}, other.loc) diag = diagnostic.Diagnostic("error", "expected '*' operands to be a list and an integer in this context", {}, op.loc, [list_.loc, other.loc], [note1, note2]) self.engine.process(diag) return return list_.type, left.type, right.type else: return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ)) elif isinstance(op, (ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)): # numeric operators work on any kind of number return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ)) elif isinstance(op, ast.Div): # division always returns a float return self._coerce_numeric((left, right), lambda typ: (builtins.TFloat(), builtins.TFloat(), builtins.TFloat())) else: # MatMult diag = diagnostic.Diagnostic("error", "operator '{op}' is not supported", {"op": op.loc.source()}, op.loc) self.engine.process(diag) return def visit_BinOpT(self, node): self.generic_visit(node) coerced = self._coerce_binop(node.op, node.left, node.right) if coerced: return_type, left_type, right_type = coerced node.left = self._coerce_one(left_type, node.left, other_node=node.right) node.right = self._coerce_one(right_type, node.right, other_node=node.left) def makenotes(printer, typea, typeb, loca, locb): def makenote(typ, coerced, loc): if typ == coerced: return diagnostic.Diagnostic("note", "expression of type {type}", {"type": printer.name(typ)}, loc) else: return diagnostic.Diagnostic("note", "expression of type {typea} (coerced to {typeb})", {"typea": printer.name(typ), "typeb": printer.name(coerced)}, loc) if node.type == return_type: note = diagnostic.Diagnostic("note", "expression of type {type}", {"type": printer.name(typea)}, loca) else: note = diagnostic.Diagnostic("note", "expression of type {typea} (but {typeb} was expected)", {"typea": printer.name(typea), "typeb": printer.name(typeb)}, loca) return [ makenote(node.left.type, left_type, node.left.loc), makenote(node.right.type, right_type, node.right.loc), note ] self._unify(node.type, return_type, node.loc, None, makenotes=makenotes) def visit_CompareT(self, node): self.generic_visit(node) pairs = zip([node.left] + node.comparators, node.comparators) if all(map(lambda op: isinstance(op, (ast.Is, ast.IsNot)), node.ops)): for left, right in pairs: self._unify(left.type, right.type, left.loc, right.loc) elif all(map(lambda op: isinstance(op, (ast.In, ast.NotIn)), node.ops)): for left, right in pairs: self._unify_iterable(element=left, collection=right) else: # Eq, NotEq, Lt, LtE, Gt, GtE operands = [node.left] + node.comparators operand_types = [operand.type for operand in operands] if any(map(builtins.is_collection, operand_types)): for left, right in pairs: self._unify(left.type, right.type, left.loc, right.loc) elif any(map(builtins.is_numeric, operand_types)): typ = self._coerce_numeric(operands) if typ: try: other_node = next(filter(lambda operand: operand.type.find() == typ.find(), operands)) except StopIteration: # can't find an argument with an exact type, meaning # the return value is more generic than any of the inputs, meaning # the type is known (typ is not None), but its width is not def wide_enough(opreand): return types.is_mono(opreand.type) and \ opreand.type.find().name == typ.find().name other_node = next(filter(wide_enough, operands)) node.left, *node.comparators = \ [self._coerce_one(typ, operand, other_node) for operand in operands] else: pass # No coercion required. self._unify(node.type, builtins.TBool(), node.loc, None) def visit_ListCompT(self, node): if len(node.generators) > 1: diag = diagnostic.Diagnostic("error", "multiple for clauses in comprehensions are not supported", {}, node.generators[1].for_loc) self.engine.process(diag) self.generic_visit(node) self._unify(node.type, builtins.TList(node.elt.type), node.loc, None) def visit_comprehension(self, node): if any(node.ifs): diag = diagnostic.Diagnostic("error", "if clauses in comprehensions are not supported", {}, node.if_locs[0]) self.engine.process(diag) self.generic_visit(node) self._unify_iterable(element=node.target, collection=node.iter) def visit_builtin_call(self, node): typ = node.func.type.find() def valid_form(signature): return diagnostic.Diagnostic("note", "{func} can be invoked as: {signature}", {"func": typ.name, "signature": signature}, node.func.loc) def diagnose(valid_forms): printer = types.TypePrinter() args = [printer.name(arg.type) for arg in node.args] args += ["%s=%s" % (kw.arg, printer.name(kw.value.type)) for kw in node.keywords] diag = diagnostic.Diagnostic("error", "{func} cannot be invoked with the arguments ({args})", {"func": typ.name, "args": ", ".join(args)}, node.func.loc, notes=valid_forms) self.engine.process(diag) def simple_form(info, arg_types=[], return_type=builtins.TNone()): self._unify(node.type, return_type, node.loc, None) if len(node.args) == len(arg_types) and len(node.keywords) == 0: for index, arg_type in enumerate(arg_types): self._unify(node.args[index].type, arg_type, node.args[index].loc, None) else: diagnose([ valid_form(info) ]) if types.is_exn_constructor(typ): valid_forms = lambda: [ valid_form("{exn}() -> {exn}".format(exn=typ.name)), valid_form("{exn}(message:str) -> {exn}".format(exn=typ.name)), valid_form("{exn}(message:str, param1:int(width=64)) -> {exn}".format(exn=typ.name)), valid_form("{exn}(message:str, param1:int(width=64), " "param2:int(width=64)) -> {exn}".format(exn=typ.name)), valid_form("{exn}(message:str, param1:int(width=64), " "param2:int(width=64), param3:int(width=64)) " "-> {exn}".format(exn=typ.name)), ] if len(node.args) == 0 and len(node.keywords) == 0: pass # Default message, zeroes as parameters elif len(node.args) >= 1 and len(node.args) <= 4 and len(node.keywords) == 0: message, *params = node.args self._unify(message.type, builtins.TStr(), message.loc, None) for param in params: self._unify(param.type, builtins.TInt64(), param.loc, None) else: diagnose(valid_forms()) self._unify(node.type, typ.instance, node.loc, None) elif types.is_builtin(typ, "bool"): valid_forms = lambda: [ valid_form("bool() -> bool"), valid_form("bool(x:'a) -> bool") ] if len(node.args) == 0 and len(node.keywords) == 0: pass # False elif len(node.args) == 1 and len(node.keywords) == 0: arg, = node.args pass # anything goes else: diagnose(valid_forms()) self._unify(node.type, builtins.TBool(), node.loc, None) elif types.is_builtin(typ, "int"): valid_forms = lambda: [ valid_form("int() -> int(width='a)"), valid_form("int(x:'a) -> int(width='b) where 'a is numeric"), valid_form("int(x:'a, width='b:) -> int(width='b) where 'a is numeric") ] self._unify(node.type, builtins.TInt(), node.loc, None) if len(node.args) == 0 and len(node.keywords) == 0: pass # 0 elif len(node.args) == 1 and len(node.keywords) == 0 and \ types.is_var(node.args[0].type): pass # undetermined yet elif len(node.args) == 1 and len(node.keywords) == 0 and \ builtins.is_numeric(node.args[0].type): self._unify(node.type, builtins.TInt(), node.loc, None) elif len(node.args) == 1 and len(node.keywords) == 1 and \ builtins.is_numeric(node.args[0].type) and \ node.keywords[0].arg == 'width': width = node.keywords[0].value if not (isinstance(width, asttyped.NumT) and isinstance(width.n, int)): diag = diagnostic.Diagnostic("error", "the width argument of int() must be an integer literal", {}, node.keywords[0].loc) self.engine.process(diag) return self._unify(node.type, builtins.TInt(types.TValue(width.n)), node.loc, None) else: diagnose(valid_forms()) elif types.is_builtin(typ, "float"): valid_forms = lambda: [ valid_form("float() -> float"), valid_form("float(x:'a) -> float where 'a is numeric") ] self._unify(node.type, builtins.TFloat(), node.loc, None) if len(node.args) == 0 and len(node.keywords) == 0: pass # 0.0 elif len(node.args) == 1 and len(node.keywords) == 0 and \ types.is_var(node.args[0].type): pass # undetermined yet elif len(node.args) == 1 and len(node.keywords) == 0 and \ builtins.is_numeric(node.args[0].type): pass else: diagnose(valid_forms()) elif types.is_builtin(typ, "list"): valid_forms = lambda: [ valid_form("list() -> list(elt='a)"), valid_form("list(x:'a) -> list(elt='b) where 'a is iterable") ] self._unify(node.type, builtins.TList(), node.loc, None) if len(node.args) == 0 and len(node.keywords) == 0: pass # [] elif len(node.args) == 1 and len(node.keywords) == 0: arg, = node.args if builtins.is_iterable(arg.type): def makenotes(printer, typea, typeb, loca, locb): return [ diagnostic.Diagnostic("note", "iterator returning elements of type {typea}", {"typea": printer.name(typea)}, loca), diagnostic.Diagnostic("note", "iterator returning elements of type {typeb}", {"typeb": printer.name(typeb)}, locb) ] self._unify(node.type.find().params["elt"], arg.type.find().params["elt"], node.loc, arg.loc, makenotes=makenotes) elif types.is_var(arg.type): pass # undetermined yet else: note = diagnostic.Diagnostic("note", "this expression has type {type}", {"type": types.TypePrinter().name(arg.type)}, arg.loc) diag = diagnostic.Diagnostic("error", "the argument of list() must be of an iterable type", {}, node.func.loc, notes=[note]) self.engine.process(diag) else: diagnose(valid_forms()) elif types.is_builtin(typ, "range"): valid_forms = lambda: [ valid_form("range(max:int(width='a)) -> range(elt=int(width='a))"), valid_form("range(min:int(width='a), max:int(width='a)) " "-> range(elt=int(width='a))"), valid_form("range(min:int(width='a), max:int(width='a), " "step:int(width='a)) -> range(elt=int(width='a))"), ] range_elt = builtins.TInt(types.TVar()) self._unify(node.type, builtins.TRange(range_elt), node.loc, None) if len(node.args) in (1, 2, 3) and len(node.keywords) == 0: for arg in node.args: self._unify(arg.type, range_elt, arg.loc, None) else: diagnose(valid_forms()) elif types.is_builtin(typ, "len"): valid_forms = lambda: [ valid_form("len(x:'a) -> int(width='b) where 'a is iterable"), ] if len(node.args) == 1 and len(node.keywords) == 0: arg, = node.args if builtins.is_range(arg.type): self._unify(node.type, builtins.get_iterable_elt(arg.type), node.loc, None) elif builtins.is_list(arg.type): # TODO: should be ssize_t-sized self._unify(node.type, builtins.TInt32(), node.loc, None) elif types.is_var(arg.type): pass # undetermined yet else: note = diagnostic.Diagnostic("note", "this expression has type {type}", {"type": types.TypePrinter().name(arg.type)}, arg.loc) diag = diagnostic.Diagnostic("error", "the argument of len() must be of an iterable type", {}, node.func.loc, notes=[note]) self.engine.process(diag) else: diagnose(valid_forms()) elif types.is_builtin(typ, "round"): valid_forms = lambda: [ valid_form("round(x:float) -> int(width='a)"), valid_form("round(x:float, width='b:) -> int(width='b)") ] self._unify(node.type, builtins.TInt(), node.loc, None) if len(node.args) == 1 and len(node.keywords) == 0: arg, = node.args self._unify(arg.type, builtins.TFloat(), arg.loc, None) elif len(node.args) == 1 and len(node.keywords) == 1 and \ builtins.is_numeric(node.args[0].type) and \ node.keywords[0].arg == 'width': width = node.keywords[0].value if not (isinstance(width, asttyped.NumT) and isinstance(width.n, int)): diag = diagnostic.Diagnostic("error", "the width argument of round() must be an integer literal", {}, node.keywords[0].loc) self.engine.process(diag) return self._unify(node.type, builtins.TInt(types.TValue(width.n)), node.loc, None) else: diagnose(valid_forms()) elif types.is_builtin(typ, "print"): valid_forms = lambda: [ valid_form("print(args...) -> None"), ] self._unify(node.type, builtins.TNone(), node.loc, None) if len(node.keywords) == 0: # We can print any arguments. pass else: diagnose(valid_forms()) elif types.is_builtin(typ, "rtio_log"): valid_forms = lambda: [ valid_form("rtio_log(channel:str, args...) -> None"), ] self._unify(node.type, builtins.TNone(), node.loc, None) if len(node.args) >= 1 and len(node.keywords) == 0: arg = node.args[0] self._unify(arg.type, builtins.TStr(), arg.loc, None) else: diagnose(valid_forms()) elif types.is_builtin(typ, "now"): simple_form("now() -> float", [], builtins.TFloat()) elif types.is_builtin(typ, "delay"): simple_form("delay(time:float) -> None", [builtins.TFloat()]) elif types.is_builtin(typ, "at"): simple_form("at(time:float) -> None", [builtins.TFloat()]) elif types.is_builtin(typ, "now_mu"): simple_form("now_mu() -> int(width=64)", [], builtins.TInt64()) elif types.is_builtin(typ, "delay_mu"): simple_form("delay_mu(time_mu:int(width=64)) -> None", [builtins.TInt64()]) elif types.is_builtin(typ, "at_mu"): simple_form("at_mu(time_mu:int(width=64)) -> None", [builtins.TInt64()]) elif types.is_builtin(typ, "mu_to_seconds"): simple_form("mu_to_seconds(time_mu:int(width=64)) -> float", [builtins.TInt64()], builtins.TFloat()) elif types.is_builtin(typ, "seconds_to_mu"): simple_form("seconds_to_mu(time:float) -> int(width=64)", [builtins.TFloat()], builtins.TInt64()) elif types.is_builtin(typ, "watchdog"): simple_form("watchdog(time:float) -> [builtin context manager]", [builtins.TFloat()], builtins.TNone()) elif types.is_constructor(typ): # An user-defined class. self._unify(node.type, typ.find().instance, node.loc, None) elif types.is_builtin(typ, "kernel"): # Ignored. self._unify(node.type, builtins.TNone(), node.loc, None) else: assert False def visit_CallT(self, node): self.generic_visit(node) for (sigil_loc, vararg) in ((node.star_loc, node.starargs), (node.dstar_loc, node.kwargs)): if vararg: diag = diagnostic.Diagnostic("error", "variadic arguments are not supported", {}, sigil_loc, [vararg.loc]) self.engine.process(diag) return typ = node.func.type.find() if types.is_var(typ): return # not enough info yet elif types.is_builtin(typ): return self.visit_builtin_call(node) elif types.is_rpc(typ): self._unify(node.type, typ.ret, node.loc, None) return elif not (types.is_function(typ) or types.is_method(typ)): diag = diagnostic.Diagnostic("error", "cannot call this expression of type {type}", {"type": types.TypePrinter().name(typ)}, node.func.loc, []) self.engine.process(diag) return if types.is_function(typ): typ_arity = typ.arity() typ_args = typ.args typ_optargs = typ.optargs typ_ret = typ.ret else: typ = types.get_method_function(typ) if types.is_var(typ): return # not enough info yet elif types.is_rpc(typ): self._unify(node.type, typ.ret, node.loc, None) return elif typ.arity() == 0: return # error elsewhere typ_arity = typ.arity() - 1 typ_args = OrderedDict(list(typ.args.items())[1:]) typ_optargs = typ.optargs typ_ret = typ.ret passed_args = dict() if len(node.args) > typ_arity: note = diagnostic.Diagnostic("note", "extraneous argument(s)", {}, node.args[typ_arity].loc.join(node.args[-1].loc)) diag = diagnostic.Diagnostic("error", "this function of type {type} accepts at most {num} arguments", {"type": types.TypePrinter().name(node.func.type), "num": typ_arity}, node.func.loc, [], [note]) self.engine.process(diag) return for actualarg, (formalname, formaltyp) in \ zip(node.args, list(typ_args.items()) + list(typ_optargs.items())): self._unify(actualarg.type, formaltyp, actualarg.loc, None) passed_args[formalname] = actualarg.loc for keyword in node.keywords: if keyword.arg in passed_args: diag = diagnostic.Diagnostic("error", "the argument '{name}' has been passed earlier as positional", {"name": keyword.arg}, keyword.arg_loc, [passed_args[keyword.arg]]) self.engine.process(diag) return if keyword.arg in typ_args: self._unify(keyword.value.type, typ_args[keyword.arg], keyword.value.loc, None) elif keyword.arg in typ_optargs: self._unify(keyword.value.type, typ_optargs[keyword.arg], keyword.value.loc, None) passed_args[keyword.arg] = keyword.arg_loc for formalname in typ_args: if formalname not in passed_args: note = diagnostic.Diagnostic("note", "the called function is of type {type}", {"type": types.TypePrinter().name(node.func.type)}, node.func.loc) diag = diagnostic.Diagnostic("error", "mandatory argument '{name}' is not passed", {"name": formalname}, node.begin_loc.join(node.end_loc), [], [note]) self.engine.process(diag) return self._unify(node.type, typ_ret, node.loc, None) def visit_LambdaT(self, node): self.generic_visit(node) signature_type = self._type_from_arguments(node.args, node.body.type) if signature_type: self._unify(node.type, signature_type, node.loc, None) def visit_Assign(self, node): self.generic_visit(node) for target in node.targets: self._unify(target.type, node.value.type, target.loc, node.value.loc) def visit_AugAssign(self, node): self.generic_visit(node) coerced = self._coerce_binop(node.op, node.target, node.value) if coerced: return_type, target_type, value_type = coerced try: node.target.type.unify(target_type) except types.UnificationError as e: printer = types.TypePrinter() note = diagnostic.Diagnostic("note", "expression of type {typec}", {"typec": printer.name(node.value.type)}, node.value.loc) diag = diagnostic.Diagnostic("error", "expression of type {typea} has to be coerced to {typeb}, " "which makes assignment invalid", {"typea": printer.name(node.target.type), "typeb": printer.name(target_type)}, node.op.loc, [node.target.loc], [note]) self.engine.process(diag) return try: node.target.type.unify(return_type) except types.UnificationError as e: printer = types.TypePrinter() note = diagnostic.Diagnostic("note", "expression of type {typec}", {"typec": printer.name(node.value.type)}, node.value.loc) diag = diagnostic.Diagnostic("error", "the result of this operation has type {typeb}, " "which makes assignment to a slot of type {typea} invalid", {"typea": printer.name(node.target.type), "typeb": printer.name(return_type)}, node.op.loc, [node.target.loc], [note]) self.engine.process(diag) return node.value = self._coerce_one(value_type, node.value, other_node=node.target) def visit_ForT(self, node): old_in_loop, self.in_loop = self.in_loop, True self.generic_visit(node) self.in_loop = old_in_loop self._unify_iterable(node.target, node.iter) def visit_While(self, node): old_in_loop, self.in_loop = self.in_loop, True self.generic_visit(node) self.in_loop = old_in_loop def visit_Break(self, node): if not self.in_loop: diag = diagnostic.Diagnostic("error", "break statement outside of a loop", {}, node.keyword_loc) self.engine.process(diag) def visit_Continue(self, node): if not self.in_loop: diag = diagnostic.Diagnostic("error", "continue statement outside of a loop", {}, node.keyword_loc) self.engine.process(diag) def visit_withitemT(self, node): self.generic_visit(node) typ = node.context_expr.type if (types.is_builtin(typ, "interleave") or types.is_builtin(typ, "sequential") or types.is_builtin(typ, "parallel") or (isinstance(node.context_expr, asttyped.CallT) and types.is_builtin(node.context_expr.func.type, "watchdog"))): # builtin context managers if node.optional_vars is not None: self._unify(node.optional_vars.type, builtins.TNone(), node.optional_vars.loc, None) elif types.is_instance(typ) or types.is_constructor(typ): # user-defined context managers self._unify_attribute(result_type=node.enter_type, value_node=node.context_expr, attr_name='__enter__', attr_loc=None, loc=node.loc) self._unify_attribute(result_type=node.exit_type, value_node=node.context_expr, attr_name='__exit__', attr_loc=None, loc=node.loc) printer = types.TypePrinter() def check_callback(attr_name, typ, arity): if types.is_var(typ): return if not (types.is_method(typ) or types.is_function(typ)): diag = diagnostic.Diagnostic("error", "attribute '{attr}' of type {manager_type} must be a function", {"attr": attr_name, "manager_type": printer.name(node.context_expr.type)}, node.context_expr.loc) self.engine.process(diag) return if types.is_method(typ): typ = types.get_method_function(typ).find() else: typ = typ.find() if not (len(typ.args) == arity and len(typ.optargs) == 0): diag = diagnostic.Diagnostic("error", "function '{attr}{attr_type}' must accept " "{arity} positional argument{s} and no optional arguments", {"attr": attr_name, "attr_type": printer.name(typ), "arity": arity, "s": "s" if arity > 1 else ""}, node.context_expr.loc) self.engine.process(diag) for formal_arg_name in list(typ.args)[1:]: formal_arg_type = typ.args[formal_arg_name] def makenotes(printer, typea, typeb, loca, locb): return [ diagnostic.Diagnostic("note", "exception handling via context managers is not supported; " "the argument '{arg}' of function '{attr}{attr_type}' " "will always be None", {"arg": formal_arg_name, "attr": attr_name, "attr_type": printer.name(typ)}, loca), ] self._unify(formal_arg_type, builtins.TNone(), node.context_expr.loc, None, makenotes=makenotes) check_callback('__enter__', node.enter_type, 1) check_callback('__exit__', node.exit_type, 4) if node.optional_vars is not None: if types.is_method(node.exit_type): var_type = types.get_method_function(node.exit_type).find().ret else: var_type = node.exit_type.find().ret def makenotes(printer, typea, typeb, loca, locb): return [ diagnostic.Diagnostic("note", "expression of type {typea}", {"typea": printer.name(typea)}, loca), diagnostic.Diagnostic("note", "context manager with an '__enter__' method returning {typeb}", {"typeb": printer.name(typeb)}, locb) ] self._unify(node.optional_vars.type, var_type, node.optional_vars.loc, node.context_expr.loc, makenotes=makenotes) elif not types.is_var(typ): diag = diagnostic.Diagnostic("error", "value of type {type} cannot act as a context manager", {"type": types.TypePrinter().name(typ)}, node.context_expr.loc) self.engine.process(diag) def visit_With(self, node): self.generic_visit(node) for item_node in node.items: typ = item_node.context_expr.type.find() if (types.is_builtin(typ, "parallel") or types.is_builtin(typ, "interleave") or types.is_builtin(typ, "sequential")) and len(node.items) != 1: diag = diagnostic.Diagnostic("error", "the '{kind}' context manager must be the only one in a 'with' statement", {"kind": typ.name}, node.keyword_loc.join(node.colon_loc)) self.engine.process(diag) def visit_ExceptHandlerT(self, node): self.generic_visit(node) if node.filter is not None: if not types.is_exn_constructor(node.filter.type): diag = diagnostic.Diagnostic("error", "this expression must refer to an exception constructor", {"type": types.TypePrinter().name(node.filter.type)}, node.filter.loc) self.engine.process(diag) else: def makenotes(printer, typea, typeb, loca, locb): return [ diagnostic.Diagnostic("note", "expression of type {typea}", {"typea": printer.name(typea)}, loca), diagnostic.Diagnostic("note", "constructor of an exception of type {typeb}", {"typeb": printer.name(typeb)}, locb) ] self._unify(node.name_type, node.filter.type.instance, node.name_loc, node.filter.loc, makenotes) def _type_from_arguments(self, node, ret): self.generic_visit(node) for (sigil_loc, vararg) in ((node.star_loc, node.vararg), (node.dstar_loc, node.kwarg)): if vararg: diag = diagnostic.Diagnostic("error", "variadic arguments are not supported", {}, sigil_loc, [vararg.loc]) self.engine.process(diag) return def extract_args(arg_nodes): args = [(arg_node.arg, arg_node.type) for arg_node in arg_nodes] return OrderedDict(args) return types.TFunction(extract_args(node.args[:len(node.args) - len(node.defaults)]), extract_args(node.args[len(node.args) - len(node.defaults):]), ret) def visit_arguments(self, node): self.generic_visit(node) for arg, default in zip(node.args[len(node.args) - len(node.defaults):], node.defaults): self._unify(arg.type, default.type, arg.loc, default.loc) def visit_FunctionDefT(self, node): for index, decorator in enumerate(node.decorator_list): if types.is_builtin(decorator.type, "kernel") or \ isinstance(decorator, asttyped.CallT) and \ types.is_builtin(decorator.func.type, "kernel"): continue diag = diagnostic.Diagnostic("error", "decorators are not supported", {}, node.at_locs[index], []) self.engine.process(diag) try: old_function, self.function = self.function, node old_in_loop, self.in_loop = self.in_loop, False old_has_return, self.has_return = self.has_return, False self.generic_visit(node) # Lack of return statements is not the only case where the return # type cannot be inferred. The other one is infinite (possibly mutual) # recursion. Since Python functions don't have to return a value, # we ignore that one. if not self.has_return: def makenotes(printer, typea, typeb, loca, locb): return [ diagnostic.Diagnostic("note", "function with return type {typea}", {"typea": printer.name(typea)}, node.name_loc), ] self._unify(node.return_type, builtins.TNone(), node.name_loc, None, makenotes) finally: self.function = old_function self.in_loop = old_in_loop self.has_return = old_has_return signature_type = self._type_from_arguments(node.args, node.return_type) if signature_type: self._unify(node.signature_type, signature_type, node.name_loc, None) visit_QuotedFunctionDefT = visit_FunctionDefT def visit_ClassDefT(self, node): if any(node.decorator_list): diag = diagnostic.Diagnostic("error", "decorators are not supported", {}, node.at_locs[0], [node.decorator_list[0].loc]) self.engine.process(diag) self.generic_visit(node) def visit_Return(self, node): if not self.function: diag = diagnostic.Diagnostic("error", "return statement outside of a function", {}, node.keyword_loc) self.engine.process(diag) return self.has_return = True self.generic_visit(node) def makenotes(printer, typea, typeb, loca, locb): return [ diagnostic.Diagnostic("note", "function with return type {typea}", {"typea": printer.name(typea)}, self.function.name_loc), diagnostic.Diagnostic("note", "a statement returning {typeb}", {"typeb": printer.name(typeb)}, node.loc) ] if node.value is None: self._unify(self.function.return_type, builtins.TNone(), self.function.name_loc, node.loc, makenotes) else: self._unify(self.function.return_type, node.value.type, self.function.name_loc, node.value.loc, makenotes) def visit_Raise(self, node): self.generic_visit(node) if node.exc is not None: exc_type = node.exc.type if types.is_exn_constructor(exc_type): pass # short form elif not types.is_var(exc_type) and not builtins.is_exception(exc_type): diag = diagnostic.Diagnostic("error", "cannot raise a value of type {type}, which is not an exception", {"type": types.TypePrinter().name(exc_type)}, node.loc) self.engine.process(diag) def visit_Assert(self, node): self.generic_visit(node) self._unify(node.test.type, builtins.TBool(), node.test.loc, None) if node.msg is not None: if not isinstance(node.msg, asttyped.StrT): diag = diagnostic.Diagnostic("error", "assertion message must be a string literal", {}, node.msg.loc) self.engine.process(diag)