artiq/artiq/compiler/transforms/inferencer.py

1915 lines
84 KiB
Python

"""
:class:`Inferencer` performs unification-based inference on a typedtree.
"""
from collections import OrderedDict
from pythonparser import algorithm, diagnostic, ast
from .. import asttyped, types, builtins
from .typedtree_printer import TypedtreePrinter
from artiq.experiment import kernel
def is_nested_empty_list(node):
"""If the passed AST node is an empty list, or a regularly nested list thereof,
returns the number of nesting layers, or ``None`` otherwise.
For instance, ``is_nested_empty_list([]) == 1`` and
``is_nested_empty_list([[], []]) == 2``, but
``is_nested_empty_list([[[]], []]) == None`` as the number of nesting layers doesn't
match.
"""
if not isinstance(node, ast.List):
return None
if not node.elts:
return 1
result = is_nested_empty_list(node.elts[0])
if result is None:
return None
for elt in node.elts[:1]:
if result != is_nested_empty_list(elt):
return None
return result + 1
class Inferencer(algorithm.Visitor):
"""
:class:`Inferencer` infers types by recursively applying the unification
algorithm. It does not treat inability to infer a concrete type as an error;
the result can still contain type variables.
:class:`Inferencer` is idempotent, but does not guarantee that it will
perform all possible inference in a single pass.
"""
def __init__(self, engine):
self.engine = engine
self.function = None # currently visited function, for Return inference
self.in_loop = False
self.has_return = False
self.subkernel_arg_types = dict()
def _unify(self, typea, typeb, loca, locb, makenotes=None, when=""):
try:
typea.unify(typeb)
except types.UnificationError as e:
printer = types.TypePrinter()
if makenotes:
notes = makenotes(printer, typea, typeb, loca, locb)
else:
notes = [
diagnostic.Diagnostic("note",
"expression of type {typea}",
{"typea": printer.name(typea)},
loca)
]
if locb:
notes.append(
diagnostic.Diagnostic("note",
"expression of type {typeb}",
{"typeb": printer.name(typeb)},
locb))
highlights = [locb] if locb else []
if e.typea.find() == typea.find() and e.typeb.find() == typeb.find() or \
e.typeb.find() == typea.find() and e.typea.find() == typeb.find():
diag = diagnostic.Diagnostic("error",
"cannot unify {typea} with {typeb}{when}",
{"typea": printer.name(typea), "typeb": printer.name(typeb),
"when": when},
loca, highlights, notes)
else: # give more detail
diag = diagnostic.Diagnostic("error",
"cannot unify {typea} with {typeb}{when}: {fraga} is incompatible with {fragb}",
{"typea": printer.name(typea), "typeb": printer.name(typeb),
"fraga": printer.name(e.typea), "fragb": printer.name(e.typeb),
"when": when},
loca, highlights, notes)
self.engine.process(diag)
# makenotes for the case where types of multiple elements are unified
# with the type of parent expression
def _makenotes_elts(self, elts, kind):
def makenotes(printer, typea, typeb, loca, locb):
return [
diagnostic.Diagnostic("note",
"{kind} of type {typea}",
{"kind": kind, "typea": printer.name(elts[0].type)},
elts[0].loc),
diagnostic.Diagnostic("note",
"{kind} of type {typeb}",
{"kind": kind, "typeb": printer.name(typeb)},
locb)
]
return makenotes
def visit_ListT(self, node):
self.generic_visit(node)
elt_type_loc = node.loc
for elt in node.elts:
self._unify(node.type["elt"], elt.type,
elt_type_loc, elt.loc,
self._makenotes_elts(node.elts, "a list element"))
elt_type_loc = elt.loc
def visit_AttributeT(self, node):
self.generic_visit(node)
self._unify_attribute(result_type=node.type, value_node=node.value,
attr_name=node.attr, attr_loc=node.attr_loc,
loc=node.loc)
def _unify_method_self(self, method_type, attr_name, attr_loc, loc, self_loc):
self_type = types.get_method_self(method_type)
function_type = types.get_method_function(method_type)
if len(function_type.args) < 1:
diag = diagnostic.Diagnostic("error",
"function '{attr}{type}' of class '{class}' cannot accept a self argument",
{"attr": attr_name, "type": types.TypePrinter().name(function_type),
"class": self_type.name},
loc)
self.engine.process(diag)
else:
def makenotes(printer, typea, typeb, loca, locb):
if attr_loc is None:
msgb = "reference to an instance with a method '{attr}{typeb}'"
else:
msgb = "reference to a method '{attr}{typeb}'"
return [
diagnostic.Diagnostic("note",
"expression of type {typea}",
{"typea": printer.name(typea)},
loca),
diagnostic.Diagnostic("note",
msgb,
{"attr": attr_name,
"typeb": printer.name(function_type)},
locb)
]
self._unify(self_type, list(function_type.args.values())[0],
self_loc, loc,
makenotes=makenotes,
when=" while inferring the type for self argument")
def _unify_attribute(self, result_type, value_node, attr_name, attr_loc, loc):
object_type = value_node.type.find()
if not types.is_var(object_type):
if attr_name in object_type.attributes:
def makenotes(printer, typea, typeb, loca, locb):
return [
diagnostic.Diagnostic("note",
"expression of type {typea}",
{"typea": printer.name(typea)},
loca),
diagnostic.Diagnostic("note",
"expression of type {typeb}",
{"typeb": printer.name(object_type)},
value_node.loc)
]
attr_type = object_type.attributes[attr_name]
self._unify(result_type, attr_type, loc, None,
makenotes=makenotes, when=" for attribute '{}'".format(attr_name))
elif types.is_instance(object_type) and \
attr_name in object_type.constructor.attributes:
attr_type = object_type.constructor.attributes[attr_name].find()
if types.is_function(attr_type):
# Convert to a method.
attr_type = types.TMethod(object_type, attr_type)
self._unify_method_self(attr_type, attr_name, attr_loc, loc, value_node.loc)
elif types.is_rpc(attr_type) or types.is_subkernel(attr_type):
# Convert to a method. We don't have to bother typechecking
# the self argument, since for RPCs anything goes.
attr_type = types.TMethod(object_type, attr_type)
if not types.is_var(attr_type):
self._unify(result_type, attr_type,
loc, None)
else:
if attr_loc.source_buffer == value_node.loc.source_buffer:
highlights, notes = [value_node.loc], []
else:
# This happens when the object being accessed is embedded
# from the host program.
note = diagnostic.Diagnostic("note",
"object being accessed", {},
value_node.loc)
highlights, notes = [], [note]
diag = diagnostic.Diagnostic("error",
"type {type} does not have an attribute '{attr}'",
{"type": types.TypePrinter().name(object_type), "attr": attr_name},
attr_loc, highlights, notes)
self.engine.process(diag)
def _unify_iterable(self, element, collection):
if builtins.is_bytes(collection.type) or builtins.is_bytearray(collection.type):
self._unify(element.type, builtins.get_iterable_elt(collection.type),
element.loc, None)
elif builtins.is_array(collection.type):
array_type = collection.type.find()
elem_dims = array_type["num_dims"].value - 1
if elem_dims > 0:
elem_type = builtins.TArray(array_type["elt"], types.TValue(elem_dims))
else:
elem_type = array_type["elt"]
self._unify(element.type, elem_type, element.loc, collection.loc)
elif builtins.is_iterable(collection.type) and not builtins.is_str(collection.type):
rhs_type = collection.type.find()
rhs_wrapped_lhs_type = types.TMono(rhs_type.name, {"elt": element.type})
self._unify(rhs_wrapped_lhs_type, rhs_type,
element.loc, collection.loc)
elif not types.is_var(collection.type):
diag = diagnostic.Diagnostic("error",
"type {type} is not iterable",
{"type": types.TypePrinter().name(collection.type)},
collection.loc, [])
self.engine.process(diag)
def visit_Index(self, node):
self.generic_visit(node)
value = node.value
if types.is_tuple(value.type):
for elt in value.type.find().elts:
self._unify(elt, builtins.TInt(),
value.loc, None)
else:
self._unify(value.type, builtins.TInt(),
value.loc, None)
def visit_SliceT(self, node):
self.generic_visit(node)
if (node.lower, node.upper, node.step) == (None, None, None):
self._unify(node.type, builtins.TInt32(),
node.loc, None)
else:
self._unify(node.type, builtins.TInt(),
node.loc, None)
for operand in (node.lower, node.upper, node.step):
if operand is not None:
self._unify(operand.type, node.type,
operand.loc, None)
def visit_ExtSlice(self, node):
diag = diagnostic.Diagnostic("error",
"multi-dimensional slices are not supported", {},
node.loc, [])
self.engine.process(diag)
def visit_SubscriptT(self, node):
self.generic_visit(node)
if types.is_tuple(node.value.type):
if (not isinstance(node.slice, ast.Index) or
not isinstance(node.slice.value, ast.Num)):
diag = diagnostic.Diagnostic(
"error", "tuples can only be indexed by a constant", {},
node.slice.loc, []
)
self.engine.process(diag)
return
tuple_type = node.value.type.find()
index = node.slice.value.n
if index < 0 or index >= len(tuple_type.elts):
diag = diagnostic.Diagnostic(
"error",
"index {index} is out of range for tuple of size {size}",
{"index": index, "size": len(tuple_type.elts)},
node.slice.loc, []
)
self.engine.process(diag)
return
self._unify(node.type, tuple_type.elts[index], node.loc, node.value.loc)
elif isinstance(node.slice, ast.Index):
if types.is_tuple(node.slice.value.type):
if types.is_var(node.value.type):
return
if not builtins.is_array(node.value.type):
diag = diagnostic.Diagnostic(
"error",
"multi-dimensional indexing only supported for arrays, not {type}",
{"type": types.TypePrinter().name(node.value.type)},
node.loc, [])
self.engine.process(diag)
return
num_idxs = len(node.slice.value.type.find().elts)
array_type = node.value.type.find()
num_dims = array_type["num_dims"].value
remaining_dims = num_dims - num_idxs
if remaining_dims < 0:
diag = diagnostic.Diagnostic(
"error",
"too many indices for array of dimension {num_dims}",
{"num_dims": num_dims}, node.slice.loc, [])
self.engine.process(diag)
return
if remaining_dims == 0:
self._unify(node.type, array_type["elt"], node.loc,
node.value.loc)
else:
self._unify(
node.type,
builtins.TArray(array_type["elt"], remaining_dims))
else:
self._unify_iterable(element=node, collection=node.value)
elif isinstance(node.slice, ast.Slice):
if builtins.is_array(node.value.type):
if node.slice.step is not None:
diag = diagnostic.Diagnostic(
"error",
"strided slicing not yet supported for NumPy arrays", {},
node.slice.step.loc, [])
self.engine.process(diag)
return
self._unify(node.type, node.value.type, node.loc, node.value.loc)
else: # ExtSlice
pass # error emitted above
def visit_IfExpT(self, node):
self.generic_visit(node)
self._unify(node.test.type, builtins.TBool(), node.test.loc, None)
self._unify(node.body.type, node.orelse.type,
node.body.loc, node.orelse.loc)
self._unify(node.type, node.body.type,
node.loc, None)
def visit_BoolOpT(self, node):
self.generic_visit(node)
for value in node.values:
self._unify(node.type, value.type,
node.loc, value.loc, self._makenotes_elts(node.values, "an operand"))
def visit_UnaryOpT(self, node):
self.generic_visit(node)
operand_type = node.operand.type.find()
if isinstance(node.op, ast.Not):
self._unify(node.type, builtins.TBool(),
node.loc, None)
elif isinstance(node.op, ast.Invert):
if builtins.is_int(operand_type):
self._unify(node.type, operand_type,
node.loc, None)
elif not types.is_var(operand_type):
diag = diagnostic.Diagnostic("error",
"expected '~' operand to be of integer type, not {type}",
{"type": types.TypePrinter().name(operand_type)},
node.operand.loc)
self.engine.process(diag)
else: # UAdd, USub
if types.is_var(operand_type):
return
if builtins.is_numeric(operand_type):
self._unify(node.type, operand_type, node.loc, None)
return
if builtins.is_array(operand_type):
elt = operand_type.find()["elt"]
if builtins.is_numeric(elt):
self._unify(node.type, operand_type, node.loc, None)
return
if types.is_var(elt):
return
diag = diagnostic.Diagnostic("error",
"expected unary '{op}' operand to be of numeric type, not {type}",
{"op": node.op.loc.source(),
"type": types.TypePrinter().name(operand_type)},
node.operand.loc)
self.engine.process(diag)
def visit_CoerceT(self, node):
self.generic_visit(node)
if builtins.is_numeric(node.type) and builtins.is_numeric(node.value.type):
pass
elif (builtins.is_array(node.type) and builtins.is_array(node.value.type)
and builtins.is_numeric(node.type.find()["elt"])
and builtins.is_numeric(node.value.type.find()["elt"])):
pass
else:
printer = types.TypePrinter()
note = diagnostic.Diagnostic("note",
"expression that required coercion to {typeb}",
{"typeb": printer.name(node.type)},
node.other_value.loc)
diag = diagnostic.Diagnostic("error",
"cannot coerce {typea} to {typeb}",
{"typea": printer.name(node.value.type), "typeb": printer.name(node.type)},
node.loc, notes=[note])
self.engine.process(diag)
def _coerce_one(self, typ, coerced_node, other_node):
if coerced_node.type.find() == typ.find():
return coerced_node
elif isinstance(coerced_node, asttyped.CoerceT):
node = coerced_node
node.type.unify(typ)
node.other_value = other_node
else:
node = asttyped.CoerceT(type=typ, value=coerced_node, other_value=other_node,
loc=coerced_node.loc)
self.visit(node)
return node
def _coerce_numeric(self, nodes, map_return=lambda typ: typ, map_node_type =lambda typ:typ):
# See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.
node_types = []
for node in nodes:
if isinstance(node, asttyped.CoerceT):
# If we already know exactly what we coerce this value to, use that type,
# or we'll get an unification error in case the coerced type is not the same
# as the type of the coerced value.
# Otherwise, use the potentially more specific subtype when considering possible
# coercions, or we may get stuck.
if node.type.fold(False, lambda acc, ty: acc or types.is_var(ty)):
node_types.append(node.value.type)
else:
node_types.append(node.type)
else:
node_types.append(node.type)
node_types = [map_node_type(typ) for typ in node_types]
if any(map(types.is_var, node_types)): # not enough info yet
return
elif not all(map(builtins.is_numeric, node_types)):
err_node = next(filter(lambda node: not builtins.is_numeric(node.type), nodes))
diag = diagnostic.Diagnostic("error",
"cannot coerce {type} to a numeric type",
{"type": types.TypePrinter().name(err_node.type)},
err_node.loc, [])
self.engine.process(diag)
return
elif any(map(builtins.is_float, node_types)):
typ = builtins.TFloat()
elif any(map(builtins.is_int, node_types)):
widths = list(map(builtins.get_int_width, node_types))
if all(widths):
typ = builtins.TInt(types.TValue(max(widths)))
else:
typ = builtins.TInt()
else:
assert False
return map_return(typ)
def _order_by_pred(self, pred, left, right):
if pred(left.type):
return left, right
elif pred(right.type):
return right, left
else:
assert False
def _coerce_binary_broadcast_op(self, left, right, map_return_elt, op_loc):
def num_dims(typ):
if builtins.is_array(typ):
# TODO: If number of dimensions is ever made a non-fixed parameter,
# need to acutally unify num_dims in _coerce_binop/….
return typ.find()["num_dims"].value
return 0
left_dims = num_dims(left.type)
right_dims = num_dims(right.type)
if left_dims != right_dims and left_dims != 0 and right_dims != 0:
# Mismatch (only scalar broadcast supported for now).
note1 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
{"num_dims": left_dims}, left.loc)
note2 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
{"num_dims": right_dims}, right.loc)
diag = diagnostic.Diagnostic(
"error", "dimensions of '{op}' array operands must match",
{"op": op_loc.source()}, op_loc, [left.loc, right.loc], [note1, note2])
self.engine.process(diag)
return
def map_node_type(typ):
if not builtins.is_array(typ):
# This is a single value broadcast across the array.
return typ
return typ.find()["elt"]
# Figure out result type, handling broadcasts.
result_dims = left_dims if left_dims else right_dims
def map_return(typ):
elt = map_return_elt(typ)
result = builtins.TArray(elt=elt, num_dims=result_dims)
left = builtins.TArray(elt=elt, num_dims=left_dims) if left_dims else elt
right = builtins.TArray(elt=elt, num_dims=right_dims) if right_dims else elt
return (result, left, right)
return self._coerce_numeric((left, right),
map_return=map_return,
map_node_type=map_node_type)
def _coerce_binop(self, op, left, right):
if isinstance(op, ast.MatMult):
if types.is_var(left.type) or types.is_var(right.type):
return
def num_dims(operand):
if not builtins.is_array(operand.type):
diag = diagnostic.Diagnostic(
"error",
"expected matrix multiplication operand to be of array type, not {type}",
{
"op": op.loc.source(),
"type": types.TypePrinter().name(operand.type)
}, op.loc, [operand.loc])
self.engine.process(diag)
return
num_dims = operand.type.find()["num_dims"].value
if num_dims not in (1, 2):
diag = diagnostic.Diagnostic(
"error",
"expected matrix multiplication operand to be 1- or 2-dimensional, not {type}",
{
"op": op.loc.source(),
"type": types.TypePrinter().name(operand.type)
}, op.loc, [operand.loc])
self.engine.process(diag)
return
return num_dims
left_dims = num_dims(left)
if not left_dims:
return
right_dims = num_dims(right)
if not right_dims:
return
def map_node_type(typ):
return typ.find()["elt"]
def map_return(typ):
if left_dims == 1:
if right_dims == 1:
result_dims = 0
else:
result_dims = 1
elif right_dims == 1:
result_dims = 1
else:
result_dims = 2
result = typ if result_dims == 0 else builtins.TArray(
typ, result_dims)
return (result, builtins.TArray(typ, left_dims),
builtins.TArray(typ, right_dims))
return self._coerce_numeric((left, right),
map_return=map_return,
map_node_type=map_node_type)
elif builtins.is_array(left.type) or builtins.is_array(right.type):
# Operations on arrays are element-wise (possibly using broadcasting).
# TODO: Allow only for integer arrays.
# allowed_int_array_ops = (ast.BitAnd, ast.BitOr, ast.BitXor, ast.LShift,
# ast.RShift)
allowed_array_ops = (ast.Add, ast.Mult, ast.FloorDiv, ast.Mod,
ast.Pow, ast.Sub, ast.Div)
if not isinstance(op, allowed_array_ops):
diag = diagnostic.Diagnostic(
"error", "operator '{op}' not valid for array types",
{"op": op.loc.source()}, op.loc)
self.engine.process(diag)
return
def map_result(typ):
if isinstance(op, ast.Div):
return builtins.TFloat()
return typ
return self._coerce_binary_broadcast_op(left, right, map_result, op.loc)
elif isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor,
ast.LShift, ast.RShift)):
# bitwise operators require integers
for operand in (left, right):
if not types.is_var(operand.type) and not builtins.is_int(operand.type):
diag = diagnostic.Diagnostic("error",
"expected '{op}' operand to be of integer type, not {type}",
{"op": op.loc.source(),
"type": types.TypePrinter().name(operand.type)},
op.loc, [operand.loc])
self.engine.process(diag)
return
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
elif isinstance(op, ast.Add):
# add works on numbers and also collections
if builtins.is_collection(left.type) or builtins.is_collection(right.type):
collection, other = \
self._order_by_pred(builtins.is_collection, left, right)
if types.is_tuple(collection.type):
pred, kind = types.is_tuple, "tuple"
elif builtins.is_list(collection.type):
pred, kind = builtins.is_list, "list"
else:
assert False
if types.is_var(other.type):
return
if not pred(other.type):
printer = types.TypePrinter()
note1 = diagnostic.Diagnostic("note",
"{kind} of type {typea}",
{"typea": printer.name(collection.type), "kind": kind},
collection.loc)
note2 = diagnostic.Diagnostic("note",
"{typeb}, which cannot be added to a {kind}",
{"typeb": printer.name(other.type), "kind": kind},
other.loc)
diag = diagnostic.Diagnostic("error",
"expected every '+' operand to be a {kind} in this context",
{"kind": kind},
op.loc, [other.loc, collection.loc],
[note1, note2])
self.engine.process(diag)
return
if types.is_tuple(collection.type):
return types.TTuple(left.type.find().elts +
right.type.find().elts), left.type, right.type
elif builtins.is_list(collection.type):
self._unify(left.type, right.type,
left.loc, right.loc)
return left.type, left.type, right.type
elif (builtins.is_str(left.type) or builtins.is_str(right.type) or
builtins.is_bytes(left.type) or builtins.is_bytes(right.type)):
self._unify(left.type, right.type,
left.loc, right.loc)
return left.type, left.type, right.type
else:
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
elif isinstance(op, ast.Mult):
# mult works on numbers and also number & collection
if types.is_tuple(left.type) or types.is_tuple(right.type):
tuple_, other = self._order_by_pred(types.is_tuple, left, right)
diag = diagnostic.Diagnostic("error",
"passing tuples to '*' is not supported", {},
op.loc, [tuple_.loc])
self.engine.process(diag)
return
elif builtins.is_list(left.type) or builtins.is_list(right.type):
list_, other = self._order_by_pred(builtins.is_list, left, right)
if not builtins.is_int(other.type) and not types.is_var(other.type):
printer = types.TypePrinter()
note1 = diagnostic.Diagnostic("note",
"list operand of type {typea}",
{"typea": printer.name(list_.type)},
list_.loc)
note2 = diagnostic.Diagnostic("note",
"operand of type {typeb}, which is not a valid repetition amount",
{"typeb": printer.name(other.type)},
other.loc)
diag = diagnostic.Diagnostic("error",
"expected '*' operands to be a list and an integer in this context", {},
op.loc, [list_.loc, other.loc],
[note1, note2])
self.engine.process(diag)
return
return list_.type, left.type, right.type
else:
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
elif isinstance(op, (ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)):
# numeric operators work on any kind of number
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
elif isinstance(op, ast.Div):
# division always returns a float
return self._coerce_numeric((left, right),
lambda typ: (builtins.TFloat(), builtins.TFloat(), builtins.TFloat()))
else:
diag = diagnostic.Diagnostic("error",
"operator '{op}' is not supported", {"op": op.loc.source()},
op.loc)
self.engine.process(diag)
return
def visit_BinOpT(self, node):
self.generic_visit(node)
coerced = self._coerce_binop(node.op, node.left, node.right)
if coerced:
return_type, left_type, right_type = coerced
node.left = self._coerce_one(left_type, node.left, other_node=node.right)
node.right = self._coerce_one(right_type, node.right, other_node=node.left)
def makenotes(printer, typea, typeb, loca, locb):
def makenote(typ, coerced, loc):
if typ == coerced:
return diagnostic.Diagnostic("note",
"expression of type {type}",
{"type": printer.name(typ)},
loc)
else:
return diagnostic.Diagnostic("note",
"expression of type {typea} (coerced to {typeb})",
{"typea": printer.name(typ),
"typeb": printer.name(coerced)},
loc)
if node.type == return_type:
note = diagnostic.Diagnostic("note",
"expression of type {type}",
{"type": printer.name(typea)},
loca)
else:
note = diagnostic.Diagnostic("note",
"expression of type {typea} (but {typeb} was expected)",
{"typea": printer.name(typea),
"typeb": printer.name(typeb)},
loca)
return [
makenote(node.left.type, left_type, node.left.loc),
makenote(node.right.type, right_type, node.right.loc),
note
]
self._unify(node.type, return_type,
node.loc, None,
makenotes=makenotes)
def visit_CompareT(self, node):
self.generic_visit(node)
pairs = zip([node.left] + node.comparators, node.comparators)
if all(map(lambda op: isinstance(op, (ast.Is, ast.IsNot)), node.ops)):
for left, right in pairs:
self._unify(left.type, right.type,
left.loc, right.loc)
elif all(map(lambda op: isinstance(op, (ast.In, ast.NotIn)), node.ops)):
for left, right in pairs:
self._unify_iterable(element=left, collection=right)
else: # Eq, NotEq, Lt, LtE, Gt, GtE
operands = [node.left] + node.comparators
operand_types = [operand.type for operand in operands]
if any(map(builtins.is_collection, operand_types)):
for left, right in pairs:
self._unify(left.type, right.type,
left.loc, right.loc)
elif any(map(builtins.is_numeric, operand_types)):
typ = self._coerce_numeric(operands)
if typ:
try:
other_node = next(filter(lambda operand: operand.type.find() == typ.find(),
operands))
except StopIteration:
# can't find an argument with an exact type, meaning
# the return value is more generic than any of the inputs, meaning
# the type is known (typ is not None), but its width is not
def wide_enough(opreand):
return types.is_mono(opreand.type) and \
opreand.type.find().name == typ.find().name
other_node = next(filter(wide_enough, operands))
node.left, *node.comparators = \
[self._coerce_one(typ, operand, other_node) for operand in operands]
else:
pass # No coercion required.
self._unify(node.type, builtins.TBool(),
node.loc, None)
def visit_ListCompT(self, node):
if len(node.generators) > 1:
diag = diagnostic.Diagnostic("error",
"multiple for clauses in comprehensions are not supported", {},
node.generators[1].for_loc)
self.engine.process(diag)
self.generic_visit(node)
self._unify(node.type, builtins.TList(node.elt.type),
node.loc, None)
def visit_comprehension(self, node):
if any(node.ifs):
diag = diagnostic.Diagnostic("error",
"if clauses in comprehensions are not supported", {},
node.if_locs[0])
self.engine.process(diag)
self.generic_visit(node)
self._unify_iterable(element=node.target, collection=node.iter)
def visit_builtin_call(self, node):
typ = node.func.type.find()
def valid_form(signature):
return diagnostic.Diagnostic("note",
"{func} can be invoked as: {signature}",
{"func": typ.name, "signature": signature},
node.func.loc)
def diagnose(valid_forms):
printer = types.TypePrinter()
args = [printer.name(arg.type) for arg in node.args]
args += ["%s=%s" % (kw.arg, printer.name(kw.value.type)) for kw in node.keywords]
diag = diagnostic.Diagnostic("error",
"{func} cannot be invoked with the arguments ({args})",
{"func": typ.name, "args": ", ".join(args)},
node.func.loc, notes=valid_forms)
self.engine.process(diag)
def simple_form(info, arg_types=[], return_type=builtins.TNone()):
self._unify(node.type, return_type,
node.loc, None)
if len(node.args) == len(arg_types) and len(node.keywords) == 0:
for index, arg_type in enumerate(arg_types):
self._unify(node.args[index].type, arg_type,
node.args[index].loc, None)
else:
diagnose([ valid_form(info) ])
if types.is_exn_constructor(typ):
valid_forms = lambda: [
valid_form("{exn}() -> {exn}".format(exn=typ.name)),
valid_form("{exn}(message:str) -> {exn}".format(exn=typ.name)),
valid_form("{exn}(message:str, param1:numpy.int64) -> {exn}".format(exn=typ.name)),
valid_form("{exn}(message:str, param1:numpy.int64, "
"param2:numpy.int64) -> {exn}".format(exn=typ.name)),
valid_form("{exn}(message:str, param1:numpy.int64, "
"param2:numpy.int64, param3:numpy.int64) "
"-> {exn}".format(exn=typ.name)),
]
if len(node.args) == 0 and len(node.keywords) == 0:
pass # Default message, zeroes as parameters
elif len(node.args) >= 1 and len(node.args) <= 4 and len(node.keywords) == 0:
message, *params = node.args
self._unify(message.type, builtins.TStr(),
message.loc, None)
for param in params:
self._unify(param.type, builtins.TInt64(),
param.loc, None)
else:
diagnose(valid_forms())
self._unify(node.type, typ.instance,
node.loc, None)
elif types.is_builtin(typ, "bool"):
valid_forms = lambda: [
valid_form("bool() -> bool"),
valid_form("bool(x:'a) -> bool")
]
if len(node.args) == 0 and len(node.keywords) == 0:
pass # False
elif len(node.args) == 1 and len(node.keywords) == 0:
arg, = node.args
pass # anything goes
else:
diagnose(valid_forms())
self._unify(node.type, builtins.TBool(),
node.loc, None)
elif types.is_builtin(typ, "int") or \
types.is_builtin(typ, "int32") or types.is_builtin(typ, "int64"):
if types.is_builtin(typ, "int"):
valid_forms = lambda: [
valid_form("int() -> numpy.int?"),
valid_form("int(x:'a) -> numpy.int? where 'a is numeric")
]
result_typ = builtins.TInt()
elif types.is_builtin(typ, "int32"):
valid_forms = lambda: [
valid_form("numpy.int32() -> numpy.int32"),
valid_form("numpy.int32(x:'a) -> numpy.int32 where 'a is numeric")
]
result_typ = builtins.TInt32()
elif types.is_builtin(typ, "int64"):
valid_forms = lambda: [
valid_form("numpy.int64() -> numpy.int64"),
valid_form("numpy.int64(x:'a) -> numpy.int64 where 'a is numeric")
]
result_typ = builtins.TInt64()
self._unify(node.type, result_typ,
node.loc, None)
if len(node.args) == 0 and len(node.keywords) == 0:
pass # 0
elif len(node.args) == 1 and len(node.keywords) == 0 and \
types.is_var(node.args[0].type):
pass # undetermined yet
elif len(node.args) == 1 and len(node.keywords) == 0 and \
builtins.is_numeric(node.args[0].type):
self._unify(node.type, result_typ,
node.loc, None)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "float"):
valid_forms = lambda: [
valid_form("float() -> float"),
valid_form("float(x:'a) -> float where 'a is numeric")
]
self._unify(node.type, builtins.TFloat(),
node.loc, None)
if len(node.args) == 0 and len(node.keywords) == 0:
pass # 0.0
elif len(node.args) == 1 and len(node.keywords) == 0 and \
types.is_var(node.args[0].type):
pass # undetermined yet
elif len(node.args) == 1 and len(node.keywords) == 0 and \
builtins.is_numeric(node.args[0].type):
pass
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "str"):
diag = diagnostic.Diagnostic("error",
"strings currently cannot be constructed", {},
node.loc)
self.engine.process(diag)
elif types.is_builtin(typ, "array"):
valid_forms = lambda: [
valid_form("array(x:'a) -> array(elt='b) where 'a is iterable"),
valid_form("array(x:'a, dtype:'b) -> array(elt='b) where 'a is iterable")
]
explicit_dtype = None
keywords_acceptable = False
if len(node.keywords) == 0:
keywords_acceptable = True
elif len(node.keywords) == 1:
if node.keywords[0].arg == "dtype":
keywords_acceptable = True
explicit_dtype = node.keywords[0].value
if len(node.args) == 1 and keywords_acceptable:
arg, = node.args
num_empty_dims = is_nested_empty_list(arg)
if num_empty_dims is not None:
# As a special case, following the behaviour of numpy.array (and
# repr() on ndarrays), consider empty lists to be exactly of the
# number of dimensions given, instead of potentially containing an
# unknown number of extra dimensions.
num_dims = num_empty_dims
# The ultimate element type will be TVar initially, but we might be
# able to resolve it from context.
elt = arg.type
for _ in range(num_dims):
assert builtins.is_list(elt)
elt = elt.find()["elt"]
else:
# In the absence of any other information (there currently isn't a way
# to specify any), assume that all iterables are expandable into a
# (runtime-checked) rectangular array of the innermost element type.
elt = arg.type
num_dims = 0
expected_dims = (node.type.find()["num_dims"].value
if builtins.is_array(node.type) else -1)
while True:
if num_dims == expected_dims:
# If we already know the number of dimensions of the result,
# stop so we can disambiguate the (innermost) element type of
# the argument if it is still unknown.
break
if types.is_var(elt):
# Can't make progress here because we don't know how many more
# dimensions might be "hidden" inside.
return
if not builtins.is_iterable(elt) or builtins.is_str(elt):
break
if builtins.is_array(elt):
num_dims += elt.find()["num_dims"].value
else:
num_dims += 1
elt = builtins.get_iterable_elt(elt)
if explicit_dtype is not None:
# TODO: Factor out type detection; support quoted type constructors
# (TList(TInt32), …)?
typ = explicit_dtype.type
if types.is_builtin(typ, "int32"):
elt = builtins.TInt32()
elif types.is_builtin(typ, "int64"):
elt = builtins.TInt64()
elif types.is_constructor(typ):
elt = typ.find().instance
else:
diag = diagnostic.Diagnostic(
"error",
"dtype argument of {builtin}() must be a valid constructor",
{"builtin": typ.find().name},
node.func.loc,
notes=[note])
self.engine.process(diag)
return
if num_dims == 0:
note = diagnostic.Diagnostic(
"note", "this expression has type {type}",
{"type": types.TypePrinter().name(arg.type)}, arg.loc)
diag = diagnostic.Diagnostic(
"error",
"the argument of {builtin}() must be of an iterable type",
{"builtin": typ.find().name},
node.func.loc,
notes=[note])
self.engine.process(diag)
return
self._unify(node.type,
builtins.TArray(elt, types.TValue(num_dims)),
node.loc, arg.loc)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "list"):
valid_forms = lambda: [
valid_form("list() -> list(elt='a)"),
valid_form("list(x:'a) -> list(elt='b) where 'a is iterable")
]
self._unify(node.type, builtins.TList(), node.loc, None)
if len(node.args) == 0 and len(node.keywords) == 0:
pass # []
elif len(node.args) == 1 and len(node.keywords) == 0:
arg, = node.args
if builtins.is_iterable(arg.type):
def makenotes(printer, typea, typeb, loca, locb):
return [
diagnostic.Diagnostic("note",
"iterator returning elements of type {typea}",
{"typea": printer.name(typea)},
loca),
diagnostic.Diagnostic("note",
"iterator returning elements of type {typeb}",
{"typeb": printer.name(typeb)},
locb)
]
self._unify(node.type.find().params["elt"],
arg.type.find().params["elt"],
node.loc, arg.loc, makenotes=makenotes)
elif types.is_var(arg.type):
pass # undetermined yet
else:
note = diagnostic.Diagnostic("note",
"this expression has type {type}",
{"type": types.TypePrinter().name(arg.type)},
arg.loc)
diag = diagnostic.Diagnostic("error",
"the argument of {builtin}() must be of an iterable type",
{"builtin": typ.find().name},
node.func.loc, notes=[note])
self.engine.process(diag)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "range"):
valid_forms = lambda: [
valid_form("range(max:numpy.int?) -> range(elt=numpy.int?)"),
valid_form("range(min:numpy.int?, max:numpy.int?) "
"-> range(elt=numpy.int?)"),
valid_form("range(min:numpy.int?, max:numpy.int?, "
"step:numpy.int?) -> range(elt=numpy.int?)"),
]
range_elt = builtins.TInt(types.TVar())
self._unify(node.type, builtins.TRange(range_elt),
node.loc, None)
if len(node.args) in (1, 2, 3) and len(node.keywords) == 0:
for arg in node.args:
self._unify(arg.type, range_elt,
arg.loc, None)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "len"):
valid_forms = lambda: [
valid_form("len(x:'a) -> numpy.int?"),
]
if len(node.args) == 1 and len(node.keywords) == 0:
arg, = node.args
if builtins.is_range(arg.type):
self._unify(node.type, builtins.get_iterable_elt(arg.type),
node.loc, None)
elif builtins.is_listish(arg.type):
# TODO: should be ssize_t-sized
self._unify(node.type, builtins.TInt32(),
node.loc, None)
elif types.is_var(arg.type):
pass # undetermined yet
else:
note = diagnostic.Diagnostic("note",
"this expression has type {type}",
{"type": types.TypePrinter().name(arg.type)},
arg.loc)
diag = diagnostic.Diagnostic("error",
"the argument of len() must be of an iterable type", {},
node.func.loc, notes=[note])
self.engine.process(diag)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "round"):
valid_forms = lambda: [
valid_form("round(x:float) -> numpy.int?"),
]
self._unify(node.type, builtins.TInt(),
node.loc, None)
if len(node.args) == 1 and len(node.keywords) == 0:
arg, = node.args
self._unify(arg.type, builtins.TFloat(),
arg.loc, None)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "abs"):
fn = typ.name
valid_forms = lambda: [
valid_form("abs(x:numpy.int?) -> numpy.int?"),
valid_form("abs(x:float) -> float")
]
if len(node.args) == 1 and len(node.keywords) == 0:
(arg,) = node.args
if builtins.is_int(arg.type) or builtins.is_float(arg.type):
self._unify(arg.type, node.type,
arg.loc, node.loc)
elif types.is_var(arg.type):
pass # undetermined yet
else:
diag = diagnostic.Diagnostic("error",
"the arguments of abs() must be of a numeric type", {},
node.func.loc)
self.engine.process(diag)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "min") or types.is_builtin(typ, "max"):
fn = typ.name
valid_forms = lambda: [
valid_form("{}(x:numpy.int?, y:numpy.int?) -> numpy.int?".format(fn)),
valid_form("{}(x:float, y:float) -> float".format(fn))
]
if len(node.args) == 2 and len(node.keywords) == 0:
arg0, arg1 = node.args
self._unify(arg0.type, arg1.type,
arg0.loc, arg1.loc)
if builtins.is_int(arg0.type) or builtins.is_float(arg0.type):
self._unify(arg0.type, node.type,
arg0.loc, node.loc)
elif types.is_var(arg0.type):
pass # undetermined yet
else:
note = diagnostic.Diagnostic("note",
"this expression has type {type}",
{"type": types.TypePrinter().name(arg0.type)},
arg0.loc)
diag = diagnostic.Diagnostic("error",
"the arguments of {fn}() must be of a numeric type",
{"fn": fn},
node.func.loc, notes=[note])
self.engine.process(diag)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "print"):
valid_forms = lambda: [
valid_form("print(args...) -> None"),
]
self._unify(node.type, builtins.TNone(),
node.loc, None)
if len(node.keywords) == 0:
# We can print any arguments.
pass
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "make_array"):
valid_forms = lambda: [
valid_form("numpy.full(count:int32, value:'a) -> array(elt='a, num_dims=1)"),
valid_form("numpy.full(shape:(int32,)*'b, value:'a) -> array(elt='a, num_dims='b)"),
]
if len(node.args) == 2 and len(node.keywords) == 0:
arg0, arg1 = node.args
if types.is_var(arg0.type):
return # undetermined yet
elif types.is_tuple(arg0.type):
num_dims = len(arg0.type.find().elts)
self._unify(arg0.type, types.TTuple([builtins.TInt32()] * num_dims),
arg0.loc, None)
else:
num_dims = 1
self._unify(arg0.type, builtins.TInt32(),
arg0.loc, None)
self._unify(node.type, builtins.TArray(num_dims=num_dims),
node.loc, None)
self._unify(arg1.type, node.type.find()["elt"],
arg1.loc, None)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "numpy.transpose"):
valid_forms = lambda: [
valid_form("transpose(x: array(elt='a, num_dims=1)) -> array(elt='a, num_dims=1)"),
valid_form("transpose(x: array(elt='a, num_dims=2)) -> array(elt='a, num_dims=2)")
]
if len(node.args) == 1 and len(node.keywords) == 0:
arg, = node.args
if types.is_var(arg.type):
pass # undetermined yet
elif not builtins.is_array(arg.type):
note = diagnostic.Diagnostic(
"note", "this expression has type {type}",
{"type": types.TypePrinter().name(arg.type)}, arg.loc)
diag = diagnostic.Diagnostic(
"error",
"the argument of {builtin}() must be an array",
{"builtin": typ.find().name},
node.func.loc,
notes=[note])
self.engine.process(diag)
else:
num_dims = arg.type.find()["num_dims"].value
if num_dims not in (1, 2):
note = diagnostic.Diagnostic(
"note", "argument is {num_dims}-dimensional",
{"num_dims": num_dims}, arg.loc)
diag = diagnostic.Diagnostic(
"error",
"{builtin}() is currently only supported for up to "
"two-dimensional arrays", {"builtin": typ.find().name},
node.func.loc,
notes=[note])
self.engine.process(diag)
else:
self._unify(node.type, arg.type, node.loc, None)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "rtio_log"):
valid_forms = lambda: [
valid_form("rtio_log(channel:str, args...) -> None"),
]
self._unify(node.type, builtins.TNone(),
node.loc, None)
if len(node.args) >= 1 and len(node.keywords) == 0:
arg = node.args[0]
self._unify(arg.type, builtins.TStr(),
arg.loc, None)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "now"):
simple_form("now() -> float",
[], builtins.TFloat())
elif types.is_builtin(typ, "delay"):
simple_form("delay(time:float) -> None",
[builtins.TFloat()])
elif types.is_builtin(typ, "at"):
simple_form("at(time:float) -> None",
[builtins.TFloat()])
elif types.is_builtin(typ, "now_mu"):
simple_form("now_mu() -> numpy.int64",
[], builtins.TInt64())
elif types.is_builtin(typ, "delay_mu"):
simple_form("delay_mu(time_mu:numpy.int64) -> None",
[builtins.TInt64()])
elif types.is_builtin(typ, "at_mu"):
simple_form("at_mu(time_mu:numpy.int64) -> None",
[builtins.TInt64()])
elif types.is_constructor(typ):
# An user-defined class.
self._unify(node.type, typ.find().instance,
node.loc, None)
elif types.is_builtin(typ, "kernel"):
# Ignored.
self._unify(node.type, builtins.TNone(),
node.loc, None)
elif types.is_builtin(typ, "subkernel_await"):
valid_forms = lambda: [
valid_form("subkernel_await(f: subkernel) -> f return type"),
valid_form("subkernel_await(f: subkernel, timeout: numpy.int64) -> f return type")
]
if 1 <= len(node.args) <= 2:
arg0 = node.args[0].type
if types.is_var(arg0):
pass # undetermined yet
else:
if types.is_method(arg0):
fn = types.get_method_function(arg0)
elif types.is_function(arg0) or types.is_subkernel(arg0):
fn = arg0
else:
diagnose(valid_forms())
self._unify(node.type, fn.ret,
node.loc, None)
if len(node.args) == 2:
arg1 = node.args[1]
if types.is_var(arg1.type):
pass
elif builtins.is_int(arg1.type):
# promote to TInt64
self._unify(arg1.type, builtins.TInt64(),
arg1.loc, None)
else:
diagnose(valid_forms())
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "subkernel_preload"):
valid_forms = lambda: [
valid_form("subkernel_preload(f: subkernel) -> None")
]
if len(node.args) == 1:
arg0 = node.args[0].type
if types.is_var(arg0):
pass # undetermined yet
else:
if types.is_method(arg0):
fn = types.get_method_function(arg0)
elif types.is_function(arg0) or types.is_subkernel(arg0):
fn = arg0
else:
diagnose(valid_forms())
self._unify(node.type, fn.ret,
node.loc, None)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "subkernel_send"):
valid_forms = lambda: [
valid_form("subkernel_send(dest: numpy.int?, name: str, value: V) -> None"),
]
self._unify(node.type, builtins.TNone(),
node.loc, None)
if len(node.args) == 3:
arg0 = node.args[0]
if types.is_var(arg0.type):
pass # undetermined yet
else:
if builtins.is_int(arg0.type):
self._unify(arg0.type, builtins.TInt8(),
arg0.loc, None)
else:
diagnose(valid_forms())
arg1 = node.args[1]
self._unify(arg1.type, builtins.TStr(),
arg1.loc, None)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "subkernel_recv"):
valid_forms = lambda: [
valid_form("subkernel_recv(name: str, value_type: type) -> value_type"),
valid_form("subkernel_recv(name: str, value_type: type, timeout: numpy.int64) -> value_type"),
]
if 2 <= len(node.args) <= 3:
arg0 = node.args[0]
if types.is_var(arg0.type):
pass
else:
self._unify(arg0.type, builtins.TStr(),
arg0.loc, None)
arg1 = node.args[1]
if types.is_var(arg1.type):
pass
else:
self._unify(node.type, arg1.value,
node.loc, None)
if len(node.args) == 3:
arg2 = node.args[2]
if types.is_var(arg2.type):
pass
elif builtins.is_int(arg2.type):
# promote to TInt64
self._unify(arg2.type, builtins.TInt64(),
arg2.loc, None)
else:
diagnose(valid_forms())
else:
diagnose(valid_forms())
else:
assert False
def visit_CallT(self, node):
self.generic_visit(node)
for (sigil_loc, vararg) in ((node.star_loc, node.starargs),
(node.dstar_loc, node.kwargs)):
if vararg:
diag = diagnostic.Diagnostic("error",
"variadic arguments are not supported", {},
sigil_loc, [vararg.loc])
self.engine.process(diag)
return
typ = node.func.type.find()
if types.is_var(typ):
return # not enough info yet
elif types.is_builtin(typ):
return self.visit_builtin_call(node)
elif types.is_rpc(typ):
self._unify(node.type, typ.ret,
node.loc, None)
return
elif not (types.is_function(typ) or types.is_method(typ)):
diag = diagnostic.Diagnostic("error",
"cannot call this expression of type {type}",
{"type": types.TypePrinter().name(typ)},
node.func.loc, [])
self.engine.process(diag)
return
if types.is_function(typ):
typ_arity = typ.arity()
typ_args = typ.args
typ_optargs = typ.optargs
typ_ret = typ.ret
typ_func = typ
else:
typ_self = types.get_method_self(typ)
typ_func = types.get_method_function(typ)
if types.is_var(typ_func):
return # not enough info yet
elif types.is_rpc(typ_func):
self._unify(node.type, typ_func.ret,
node.loc, None)
return
elif typ_func.arity() == 0:
return # error elsewhere
method_args = list(typ_func.args.items())
self_arg_name, self_arg_type = method_args[0]
self._unify(self_arg_type, typ_self,
node.loc, None)
typ_arity = typ_func.arity() - 1
typ_args = OrderedDict(method_args[1:])
typ_optargs = typ_func.optargs
typ_ret = typ_func.ret
passed_args = dict()
if len(node.args) > typ_arity:
note = diagnostic.Diagnostic("note",
"extraneous argument(s)", {},
node.args[typ_arity].loc.join(node.args[-1].loc))
diag = diagnostic.Diagnostic("error",
"this function of type {type} accepts at most {num} arguments",
{"type": types.TypePrinter().name(node.func.type),
"num": typ_arity},
node.func.loc, [], [note])
self.engine.process(diag)
return
# Array broadcasting for functions explicitly marked as such.
if len(node.args) == typ_arity and types.is_broadcast_across_arrays(typ):
if typ_arity == 1:
arg_type = node.args[0].type.find()
if builtins.is_array(arg_type):
typ_arg, = typ_args.values()
self._unify(typ_arg, arg_type["elt"], node.args[0].loc, None)
self._unify(node.type, builtins.TArray(typ_ret, arg_type["num_dims"]),
node.loc, None)
return
elif typ_arity == 2:
if any(builtins.is_array(arg.type) for arg in node.args):
ret, arg0, arg1 = self._coerce_binary_broadcast_op(
node.args[0], node.args[1], lambda t: typ_ret, node.loc)
node.args[0] = self._coerce_one(arg0, node.args[0],
other_node=node.args[1])
node.args[1] = self._coerce_one(arg1, node.args[1],
other_node=node.args[0])
self._unify(node.type, ret, node.loc, None)
return
if types.is_subkernel(typ_func) and typ_func.sid not in self.subkernel_arg_types:
self.subkernel_arg_types[typ_func.sid] = []
for actualarg, (formalname, formaltyp) in \
zip(node.args, list(typ_args.items()) + list(typ_optargs.items())):
self._unify(actualarg.type, formaltyp,
actualarg.loc, None)
passed_args[formalname] = actualarg.loc
if types.is_subkernel(typ_func):
if types.is_instance(actualarg.type):
# objects cannot be passed to subkernels, as rpc code doesn't support them
diag = diagnostic.Diagnostic("error",
"argument '{name}' of type: {typ} is not supported in subkernels",
{"name": formalname, "typ": actualarg.type},
actualarg.loc, [])
self.engine.process(diag)
self.subkernel_arg_types[typ_func.sid].append((formalname, formaltyp))
for keyword in node.keywords:
if keyword.arg in passed_args:
diag = diagnostic.Diagnostic("error",
"the argument '{name}' has been passed earlier as positional",
{"name": keyword.arg},
keyword.arg_loc, [passed_args[keyword.arg]])
self.engine.process(diag)
return
if keyword.arg in typ_args:
self._unify(keyword.value.type, typ_args[keyword.arg],
keyword.value.loc, None)
elif keyword.arg in typ_optargs:
self._unify(keyword.value.type, typ_optargs[keyword.arg],
keyword.value.loc, None)
else:
note = diagnostic.Diagnostic("note",
"extraneous argument", {},
keyword.loc)
diag = diagnostic.Diagnostic("error",
"this function of type {type} does not accept argument '{name}'",
{"type": types.TypePrinter().name(node.func.type),
"name": keyword.arg},
node.func.loc, [], [note])
self.engine.process(diag)
return
passed_args[keyword.arg] = keyword.arg_loc
for formalname in typ_args:
if formalname not in passed_args and not node.remote_fn:
note = diagnostic.Diagnostic("note",
"the called function is of type {type}",
{"type": types.TypePrinter().name(node.func.type)},
node.func.loc)
diag = diagnostic.Diagnostic("error",
"mandatory argument '{name}' is not passed",
{"name": formalname},
node.begin_loc.join(node.end_loc), [], [note])
self.engine.process(diag)
return
self._unify(node.type, typ_ret,
node.loc, None)
def visit_LambdaT(self, node):
self.generic_visit(node)
signature_type = self._type_from_arguments(node.args, node.body.type)
if signature_type:
self._unify(node.type, signature_type,
node.loc, None)
def visit_Assign(self, node):
self.generic_visit(node)
for target in node.targets:
self._unify(target.type, node.value.type,
target.loc, node.value.loc)
def visit_AugAssign(self, node):
self.generic_visit(node)
coerced = self._coerce_binop(node.op, node.target, node.value)
if coerced:
return_type, target_type, value_type = coerced
if isinstance(node.value, asttyped.CoerceT):
orig_value_type = node.value.value.type
else:
orig_value_type = node.value.type
try:
node.target.type.unify(return_type)
except types.UnificationError as e:
printer = types.TypePrinter()
note = diagnostic.Diagnostic("note",
"expression of type {typec}",
{"typec": printer.name(orig_value_type)},
node.value.loc)
diag = diagnostic.Diagnostic("error",
"the result of this operation has type {typeb}, "
"which cannot be assigned to a left-hand side of type {typea}",
{"typea": printer.name(node.target.type),
"typeb": printer.name(return_type)},
node.op.loc, [node.target.loc], [note])
self.engine.process(diag)
return
try:
node.target.type.unify(target_type)
except types.UnificationError as e:
printer = types.TypePrinter()
note = diagnostic.Diagnostic("note",
"expression of type {typec}",
{"typec": printer.name(orig_value_type)},
node.value.loc)
diag = diagnostic.Diagnostic("error",
"this operation requires the left-hand side of type {typea} "
"to be coerced to {typeb}, which cannot be done",
{"typea": printer.name(node.target.type),
"typeb": printer.name(target_type)},
node.op.loc, [node.target.loc], [note])
self.engine.process(diag)
return
node.value = self._coerce_one(value_type, node.value, other_node=node.target)
def visit_ForT(self, node):
old_in_loop, self.in_loop = self.in_loop, True
self.generic_visit(node)
self.in_loop = old_in_loop
self._unify_iterable(node.target, node.iter)
def visit_While(self, node):
old_in_loop, self.in_loop = self.in_loop, True
self.generic_visit(node)
self.in_loop = old_in_loop
def visit_Break(self, node):
if not self.in_loop:
diag = diagnostic.Diagnostic("error",
"break statement outside of a loop", {},
node.keyword_loc)
self.engine.process(diag)
def visit_Continue(self, node):
if not self.in_loop:
diag = diagnostic.Diagnostic("error",
"continue statement outside of a loop", {},
node.keyword_loc)
self.engine.process(diag)
def visit_withitemT(self, node):
self.generic_visit(node)
typ = node.context_expr.type
if (types.is_builtin(typ, "interleave") or types.is_builtin(typ, "sequential") or
types.is_builtin(typ, "parallel")):
# builtin context managers
if node.optional_vars is not None:
self._unify(node.optional_vars.type, builtins.TNone(),
node.optional_vars.loc, None)
elif types.is_instance(typ) or types.is_constructor(typ):
# user-defined context managers
self._unify_attribute(result_type=node.enter_type, value_node=node.context_expr,
attr_name='__enter__', attr_loc=None, loc=node.loc)
self._unify_attribute(result_type=node.exit_type, value_node=node.context_expr,
attr_name='__exit__', attr_loc=None, loc=node.loc)
printer = types.TypePrinter()
def check_callback(attr_name, typ, arity):
if types.is_var(typ):
return
if not (types.is_method(typ) or types.is_function(typ)):
diag = diagnostic.Diagnostic("error",
"attribute '{attr}' of type {manager_type} must be a function",
{"attr": attr_name,
"manager_type": printer.name(node.context_expr.type)},
node.context_expr.loc)
self.engine.process(diag)
return
if types.is_method(typ):
typ = types.get_method_function(typ).find()
else:
typ = typ.find()
if not (len(typ.args) == arity and len(typ.optargs) == 0):
diag = diagnostic.Diagnostic("error",
"function '{attr}{attr_type}' must accept "
"{arity} positional argument{s} and no optional arguments",
{"attr": attr_name,
"attr_type": printer.name(typ),
"arity": arity, "s": "s" if arity > 1 else ""},
node.context_expr.loc)
self.engine.process(diag)
for formal_arg_name in list(typ.args)[1:]:
formal_arg_type = typ.args[formal_arg_name]
def makenotes(printer, typea, typeb, loca, locb):
return [
diagnostic.Diagnostic("note",
"exception handling via context managers is not supported; "
"the argument '{arg}' of function '{attr}{attr_type}' "
"will always be None",
{"arg": formal_arg_name,
"attr": attr_name,
"attr_type": printer.name(typ)},
loca),
]
self._unify(formal_arg_type, builtins.TNone(),
node.context_expr.loc, None,
makenotes=makenotes)
check_callback('__enter__', node.enter_type, 1)
check_callback('__exit__', node.exit_type, 4)
if node.optional_vars is not None:
if types.is_method(node.exit_type):
var_type = types.get_method_function(node.exit_type).find().ret
else:
var_type = node.exit_type.find().ret
def makenotes(printer, typea, typeb, loca, locb):
return [
diagnostic.Diagnostic("note",
"expression of type {typea}",
{"typea": printer.name(typea)},
loca),
diagnostic.Diagnostic("note",
"context manager with an '__enter__' method returning {typeb}",
{"typeb": printer.name(typeb)},
locb)
]
self._unify(node.optional_vars.type, var_type,
node.optional_vars.loc, node.context_expr.loc,
makenotes=makenotes)
elif not types.is_var(typ):
diag = diagnostic.Diagnostic("error",
"value of type {type} cannot act as a context manager",
{"type": types.TypePrinter().name(typ)},
node.context_expr.loc)
self.engine.process(diag)
def visit_With(self, node):
self.generic_visit(node)
for item_node in node.items:
typ = item_node.context_expr.type.find()
if (types.is_builtin(typ, "parallel") or types.is_builtin(typ, "interleave") or
types.is_builtin(typ, "sequential")) and len(node.items) != 1:
diag = diagnostic.Diagnostic("error",
"the '{kind}' context manager must be the only one in a 'with' statement",
{"kind": typ.name},
node.keyword_loc.join(node.colon_loc))
self.engine.process(diag)
def visit_ExceptHandlerT(self, node):
self.generic_visit(node)
if node.filter is not None:
if not types.is_exn_constructor(node.filter.type):
diag = diagnostic.Diagnostic("error",
"this expression must refer to an exception constructor",
{"type": types.TypePrinter().name(node.filter.type)},
node.filter.loc)
self.engine.process(diag)
else:
def makenotes(printer, typea, typeb, loca, locb):
return [
diagnostic.Diagnostic("note",
"expression of type {typea}",
{"typea": printer.name(typea)},
loca),
diagnostic.Diagnostic("note",
"constructor of an exception of type {typeb}",
{"typeb": printer.name(typeb)},
locb)
]
self._unify(node.name_type, node.filter.type.instance,
node.name_loc, node.filter.loc, makenotes)
def _type_from_arguments(self, node, ret):
self.generic_visit(node)
for (sigil_loc, vararg) in ((node.star_loc, node.vararg),
(node.dstar_loc, node.kwarg)):
if vararg:
diag = diagnostic.Diagnostic("error",
"variadic arguments are not supported", {},
sigil_loc, [vararg.loc])
self.engine.process(diag)
return
def extract_args(arg_nodes):
args = [(arg_node.arg, arg_node.type) for arg_node in arg_nodes]
return OrderedDict(args)
return types.TFunction(extract_args(node.args[:len(node.args) - len(node.defaults)]),
extract_args(node.args[len(node.args) - len(node.defaults):]),
ret)
def visit_arguments(self, node):
self.generic_visit(node)
for arg, default in zip(node.args[len(node.args) - len(node.defaults):], node.defaults):
self._unify(arg.type, default.type,
arg.loc, default.loc)
def visit_FunctionDefT(self, node):
for index, decorator in enumerate(node.decorator_list):
def eval_attr(attr):
if isinstance(attr.value, asttyped.QuoteT):
return getattr(attr.value.value, attr.attr)
return getattr(eval_attr(attr.value), attr.attr)
if isinstance(decorator, asttyped.AttributeT):
decorator = eval_attr(decorator)
if id(decorator) == id(kernel) or \
types.is_builtin(decorator.type, "kernel") or \
isinstance(decorator, asttyped.CallT) and \
types.is_builtin(decorator.func.type, "kernel"):
continue
diag = diagnostic.Diagnostic("error",
"decorators are not supported", {},
node.at_locs[index], [])
self.engine.process(diag)
try:
old_function, self.function = self.function, node
old_in_loop, self.in_loop = self.in_loop, False
old_has_return, self.has_return = self.has_return, False
self.generic_visit(node)
# Lack of return statements is not the only case where the return
# type cannot be inferred. The other one is infinite (possibly mutual)
# recursion. Since Python functions don't have to return a value,
# we ignore that one.
if not self.has_return:
def makenotes(printer, typea, typeb, loca, locb):
return [
diagnostic.Diagnostic("note",
"function with return type {typea}",
{"typea": printer.name(typea)},
node.name_loc),
]
self._unify(node.return_type, builtins.TNone(),
node.name_loc, None, makenotes)
finally:
self.function = old_function
self.in_loop = old_in_loop
self.has_return = old_has_return
signature_type = self._type_from_arguments(node.args, node.return_type)
if signature_type:
self._unify(node.signature_type, signature_type,
node.name_loc, None)
visit_QuotedFunctionDefT = visit_FunctionDefT
def visit_ClassDefT(self, node):
if any(node.decorator_list):
diag = diagnostic.Diagnostic("error",
"decorators are not supported", {},
node.at_locs[0], [node.decorator_list[0].loc])
self.engine.process(diag)
self.generic_visit(node)
def visit_Return(self, node):
if not self.function:
diag = diagnostic.Diagnostic("error",
"return statement outside of a function", {},
node.keyword_loc)
self.engine.process(diag)
return
self.has_return = True
self.generic_visit(node)
def makenotes(printer, typea, typeb, loca, locb):
return [
diagnostic.Diagnostic("note",
"function with return type {typea}",
{"typea": printer.name(typea)},
self.function.name_loc),
diagnostic.Diagnostic("note",
"a statement returning {typeb}",
{"typeb": printer.name(typeb)},
node.loc)
]
if node.value is None:
self._unify(self.function.return_type, builtins.TNone(),
self.function.name_loc, node.loc, makenotes)
else:
self._unify(self.function.return_type, node.value.type,
self.function.name_loc, node.value.loc, makenotes)
def visit_Raise(self, node):
self.generic_visit(node)
if node.exc is not None:
exc_type = node.exc.type
if types.is_exn_constructor(exc_type):
pass # short form
elif not types.is_var(exc_type) and not builtins.is_exception(exc_type):
diag = diagnostic.Diagnostic("error",
"cannot raise a value of type {type}, which is not an exception",
{"type": types.TypePrinter().name(exc_type)},
node.loc)
self.engine.process(diag)
def visit_Assert(self, node):
self.generic_visit(node)
self._unify(node.test.type, builtins.TBool(),
node.test.loc, None)
if node.msg is not None:
if not isinstance(node.msg, asttyped.StrT):
diag = diagnostic.Diagnostic("error",
"assertion message must be a string literal", {},
node.msg.loc)
self.engine.process(diag)