2015-07-03 00:49:52 +08:00
|
|
|
"""
|
|
|
|
:class:`Inferencer` performs unification-based inference on a typedtree.
|
|
|
|
"""
|
|
|
|
|
2015-06-15 16:30:50 +08:00
|
|
|
from collections import OrderedDict
|
2015-07-03 00:35:35 +08:00
|
|
|
from pythonparser import algorithm, diagnostic, ast
|
|
|
|
from .. import asttyped, types, builtins
|
2017-04-12 12:10:08 +08:00
|
|
|
from .typedtree_printer import TypedtreePrinter
|
2021-07-02 16:28:47 +08:00
|
|
|
from artiq.experiment import kernel
|
2015-06-13 16:33:15 +08:00
|
|
|
|
2020-07-26 03:09:42 +08:00
|
|
|
|
compiler: Change type inference rules for empty array() calls
array([...]), the constructor for NumPy arrays, currently has the
status of some weird kind of macro in ARTIQ Python, as it needs
to determine the number of dimensions in the resulting array
type, which is a fixed type parameter on which inference cannot
be performed.
This leads to an ambiguity for empty lists, which could contain
elements of arbitrary type, including other lists (which would
add to the number of dimensions).
Previously, I had chosen to make array([]) to be of completely
indeterminate type for this reason. However, this is different
to how the call behaves in host NumPy, where this is a well-formed
call creating an empty 1D array (or 2D for array([[], []]), etc.).
This commit adds special matching for (recursive lists of) empty
ListT AST nodes to treat them as scalar dimensions, with the
element type still unknown.
This also happens to fix type inference for embedding empty 1D
NumPy arrays from host object attributes, although multi-dimensional
arrays will still require work (see GitHub #1633).
GitHub: Fixes #1626.
2021-03-15 06:36:40 +08:00
|
|
|
def is_nested_empty_list(node):
    """If the passed AST node is an empty list, or a regularly nested list thereof,
    returns the number of nesting layers, or ``None`` otherwise.

    For instance, ``is_nested_empty_list([]) == 1`` and
    ``is_nested_empty_list([[], []]) == 2``, but
    ``is_nested_empty_list([[[]], []]) == None`` as the number of nesting layers doesn't
    match.

    :param node: any AST node; only ``ast.List`` nodes can yield a non-``None`` result.
    :returns: the nesting depth (an ``int`` >= 1), or ``None`` if ``node`` is not a
        list, contains a non-list leaf, or its sub-lists are nested to differing depths.
    """
    if not isinstance(node, ast.List):
        return None
    if not node.elts:
        return 1
    result = is_nested_empty_list(node.elts[0])
    if result is None:
        return None
    # Every remaining sibling must be nested to exactly the same depth as the
    # first one for the list as a whole to be regular.
    for elt in node.elts[1:]:
        if result != is_nested_empty_list(elt):
            return None
    return result + 1
|
|
|
|
|
|
|
|
|
2015-06-13 16:03:33 +08:00
|
|
|
class Inferencer(algorithm.Visitor):
|
2015-06-13 18:08:16 +08:00
|
|
|
"""
|
|
|
|
:class:`Inferencer` infers types by recursively applying the unification
|
|
|
|
algorithm. It does not treat inability to infer a concrete type as an error;
|
|
|
|
the result can still contain type variables.
|
|
|
|
|
|
|
|
:class:`Inferencer` is idempotent, but does not guarantee that it will
|
|
|
|
perform all possible inference in a single pass.
|
|
|
|
"""
|
|
|
|
|
2015-06-13 16:03:33 +08:00
|
|
|
def __init__(self, engine):
|
|
|
|
self.engine = engine
|
2015-06-13 17:07:46 +08:00
|
|
|
self.function = None # currently visited function, for Return inference
|
|
|
|
self.in_loop = False
|
2015-07-04 05:58:48 +08:00
|
|
|
self.has_return = False
|
2023-10-05 14:35:50 +08:00
|
|
|
self.subkernel_arg_types = dict()
|
2015-06-13 16:03:33 +08:00
|
|
|
|
2015-08-19 13:39:22 +08:00
|
|
|
def _unify(self, typea, typeb, loca, locb, makenotes=None, when=""):
|
2015-06-13 16:03:33 +08:00
|
|
|
try:
|
|
|
|
typea.unify(typeb)
|
|
|
|
except types.UnificationError as e:
|
|
|
|
printer = types.TypePrinter()
|
|
|
|
|
|
|
|
if makenotes:
|
|
|
|
notes = makenotes(printer, typea, typeb, loca, locb)
|
|
|
|
else:
|
|
|
|
notes = [
|
|
|
|
diagnostic.Diagnostic("note",
|
|
|
|
"expression of type {typea}",
|
|
|
|
{"typea": printer.name(typea)},
|
2015-06-15 16:30:50 +08:00
|
|
|
loca)
|
2015-06-13 16:03:33 +08:00
|
|
|
]
|
2015-06-15 16:30:50 +08:00
|
|
|
if locb:
|
|
|
|
notes.append(
|
|
|
|
diagnostic.Diagnostic("note",
|
|
|
|
"expression of type {typeb}",
|
|
|
|
{"typeb": printer.name(typeb)},
|
|
|
|
locb))
|
2015-06-13 16:03:33 +08:00
|
|
|
|
|
|
|
highlights = [locb] if locb else []
|
2015-08-19 13:39:22 +08:00
|
|
|
if e.typea.find() == typea.find() and e.typeb.find() == typeb.find() or \
|
|
|
|
e.typeb.find() == typea.find() and e.typea.find() == typeb.find():
|
2015-06-13 16:03:33 +08:00
|
|
|
diag = diagnostic.Diagnostic("error",
|
2015-08-19 13:39:22 +08:00
|
|
|
"cannot unify {typea} with {typeb}{when}",
|
|
|
|
{"typea": printer.name(typea), "typeb": printer.name(typeb),
|
|
|
|
"when": when},
|
2015-06-13 16:03:33 +08:00
|
|
|
loca, highlights, notes)
|
|
|
|
else: # give more detail
|
|
|
|
diag = diagnostic.Diagnostic("error",
|
2015-08-19 13:39:22 +08:00
|
|
|
"cannot unify {typea} with {typeb}{when}: {fraga} is incompatible with {fragb}",
|
2015-06-13 16:03:33 +08:00
|
|
|
{"typea": printer.name(typea), "typeb": printer.name(typeb),
|
2015-08-19 13:39:22 +08:00
|
|
|
"fraga": printer.name(e.typea), "fragb": printer.name(e.typeb),
|
|
|
|
"when": when},
|
2015-06-13 16:03:33 +08:00
|
|
|
loca, highlights, notes)
|
|
|
|
self.engine.process(diag)
|
|
|
|
|
|
|
|
# makenotes for the case where types of multiple elements are unified
|
|
|
|
# with the type of parent expression
|
|
|
|
def _makenotes_elts(self, elts, kind):
|
|
|
|
def makenotes(printer, typea, typeb, loca, locb):
|
|
|
|
return [
|
|
|
|
diagnostic.Diagnostic("note",
|
|
|
|
"{kind} of type {typea}",
|
|
|
|
{"kind": kind, "typea": printer.name(elts[0].type)},
|
|
|
|
elts[0].loc),
|
|
|
|
diagnostic.Diagnostic("note",
|
|
|
|
"{kind} of type {typeb}",
|
|
|
|
{"kind": kind, "typeb": printer.name(typeb)},
|
|
|
|
locb)
|
|
|
|
]
|
|
|
|
return makenotes
|
|
|
|
|
2015-06-11 11:34:22 +08:00
|
|
|
def visit_ListT(self, node):
|
2015-06-14 17:07:13 +08:00
|
|
|
self.generic_visit(node)
|
2015-08-07 18:56:18 +08:00
|
|
|
elt_type_loc = node.loc
|
2015-06-11 11:34:22 +08:00
|
|
|
for elt in node.elts:
|
2015-07-03 00:35:35 +08:00
|
|
|
self._unify(node.type["elt"], elt.type,
|
2015-08-07 18:56:18 +08:00
|
|
|
elt_type_loc, elt.loc,
|
|
|
|
self._makenotes_elts(node.elts, "a list element"))
|
|
|
|
elt_type_loc = elt.loc
|
2015-06-11 11:34:22 +08:00
|
|
|
|
2015-06-13 18:50:56 +08:00
|
|
|
def visit_AttributeT(self, node):
|
2015-06-14 17:07:13 +08:00
|
|
|
self.generic_visit(node)
|
2016-01-04 22:13:05 +08:00
|
|
|
self._unify_attribute(result_type=node.type, value_node=node.value,
|
|
|
|
attr_name=node.attr, attr_loc=node.attr_loc,
|
|
|
|
loc=node.loc)
|
|
|
|
|
2016-05-09 20:25:47 +08:00
|
|
|
def _unify_method_self(self, method_type, attr_name, attr_loc, loc, self_loc):
|
|
|
|
self_type = types.get_method_self(method_type)
|
|
|
|
function_type = types.get_method_function(method_type)
|
|
|
|
|
|
|
|
if len(function_type.args) < 1:
|
|
|
|
diag = diagnostic.Diagnostic("error",
|
|
|
|
"function '{attr}{type}' of class '{class}' cannot accept a self argument",
|
|
|
|
{"attr": attr_name, "type": types.TypePrinter().name(function_type),
|
|
|
|
"class": self_type.name},
|
|
|
|
loc)
|
|
|
|
self.engine.process(diag)
|
|
|
|
else:
|
|
|
|
def makenotes(printer, typea, typeb, loca, locb):
|
|
|
|
if attr_loc is None:
|
|
|
|
msgb = "reference to an instance with a method '{attr}{typeb}'"
|
|
|
|
else:
|
|
|
|
msgb = "reference to a method '{attr}{typeb}'"
|
|
|
|
|
|
|
|
return [
|
|
|
|
diagnostic.Diagnostic("note",
|
|
|
|
"expression of type {typea}",
|
|
|
|
{"typea": printer.name(typea)},
|
|
|
|
loca),
|
|
|
|
diagnostic.Diagnostic("note",
|
|
|
|
msgb,
|
|
|
|
{"attr": attr_name,
|
|
|
|
"typeb": printer.name(function_type)},
|
|
|
|
locb)
|
|
|
|
]
|
|
|
|
|
|
|
|
self._unify(self_type, list(function_type.args.values())[0],
|
|
|
|
self_loc, loc,
|
|
|
|
makenotes=makenotes,
|
|
|
|
when=" while inferring the type for self argument")
|
|
|
|
|
2016-01-04 22:13:05 +08:00
|
|
|
def _unify_attribute(self, result_type, value_node, attr_name, attr_loc, loc):
|
|
|
|
object_type = value_node.type.find()
|
2015-06-13 18:50:56 +08:00
|
|
|
if not types.is_var(object_type):
|
2016-01-04 22:13:05 +08:00
|
|
|
if attr_name in object_type.attributes:
|
2015-08-28 16:23:15 +08:00
|
|
|
def makenotes(printer, typea, typeb, loca, locb):
|
|
|
|
return [
|
|
|
|
diagnostic.Diagnostic("note",
|
|
|
|
"expression of type {typea}",
|
|
|
|
{"typea": printer.name(typea)},
|
|
|
|
loca),
|
|
|
|
diagnostic.Diagnostic("note",
|
|
|
|
"expression of type {typeb}",
|
|
|
|
{"typeb": printer.name(object_type)},
|
2016-01-04 22:13:05 +08:00
|
|
|
value_node.loc)
|
2015-08-28 16:23:15 +08:00
|
|
|
]
|
|
|
|
|
2016-01-04 22:13:05 +08:00
|
|
|
attr_type = object_type.attributes[attr_name]
|
|
|
|
self._unify(result_type, attr_type, loc, None,
|
|
|
|
makenotes=makenotes, when=" for attribute '{}'".format(attr_name))
|
2015-08-15 23:04:12 +08:00
|
|
|
elif types.is_instance(object_type) and \
|
2016-01-04 22:13:05 +08:00
|
|
|
attr_name in object_type.constructor.attributes:
|
|
|
|
attr_type = object_type.constructor.attributes[attr_name].find()
|
2015-08-19 13:39:22 +08:00
|
|
|
if types.is_function(attr_type):
|
|
|
|
# Convert to a method.
|
2015-12-02 21:48:14 +08:00
|
|
|
attr_type = types.TMethod(object_type, attr_type)
|
2016-05-09 20:25:47 +08:00
|
|
|
self._unify_method_self(attr_type, attr_name, attr_loc, loc, value_node.loc)
|
2023-10-05 14:35:50 +08:00
|
|
|
elif types.is_rpc(attr_type) or types.is_subkernel(attr_type):
|
2016-04-26 06:05:32 +08:00
|
|
|
# Convert to a method. We don't have to bother typechecking
|
|
|
|
# the self argument, since for RPCs anything goes.
|
|
|
|
attr_type = types.TMethod(object_type, attr_type)
|
2015-08-28 08:46:50 +08:00
|
|
|
|
|
|
|
if not types.is_var(attr_type):
|
2016-01-04 22:13:05 +08:00
|
|
|
self._unify(result_type, attr_type,
|
|
|
|
loc, None)
|
2015-06-13 18:50:56 +08:00
|
|
|
else:
|
2016-01-05 00:11:03 +08:00
|
|
|
if attr_loc.source_buffer == value_node.loc.source_buffer:
|
2016-01-04 22:13:05 +08:00
|
|
|
highlights, notes = [value_node.loc], []
|
2015-08-27 18:01:04 +08:00
|
|
|
else:
|
|
|
|
# This happens when the object being accessed is embedded
|
|
|
|
# from the host program.
|
|
|
|
note = diagnostic.Diagnostic("note",
|
|
|
|
"object being accessed", {},
|
2016-01-04 22:13:05 +08:00
|
|
|
value_node.loc)
|
2015-08-27 18:01:04 +08:00
|
|
|
highlights, notes = [], [note]
|
|
|
|
|
2015-06-13 18:50:56 +08:00
|
|
|
diag = diagnostic.Diagnostic("error",
|
|
|
|
"type {type} does not have an attribute '{attr}'",
|
2016-01-04 22:13:05 +08:00
|
|
|
{"type": types.TypePrinter().name(object_type), "attr": attr_name},
|
2016-01-05 00:11:03 +08:00
|
|
|
attr_loc, highlights, notes)
|
2015-06-13 18:50:56 +08:00
|
|
|
self.engine.process(diag)
|
|
|
|
|
2015-06-26 23:53:20 +08:00
|
|
|
def _unify_iterable(self, element, collection):
|
2017-06-09 14:59:30 +08:00
|
|
|
if builtins.is_bytes(collection.type) or builtins.is_bytearray(collection.type):
|
|
|
|
self._unify(element.type, builtins.get_iterable_elt(collection.type),
|
2017-06-09 14:50:48 +08:00
|
|
|
element.loc, None)
|
2020-07-26 08:33:52 +08:00
|
|
|
elif builtins.is_array(collection.type):
|
|
|
|
array_type = collection.type.find()
|
|
|
|
elem_dims = array_type["num_dims"].value - 1
|
|
|
|
if elem_dims > 0:
|
|
|
|
elem_type = builtins.TArray(array_type["elt"], types.TValue(elem_dims))
|
|
|
|
else:
|
|
|
|
elem_type = array_type["elt"]
|
|
|
|
self._unify(element.type, elem_type, element.loc, collection.loc)
|
2017-06-09 15:29:08 +08:00
|
|
|
elif builtins.is_iterable(collection.type) and not builtins.is_str(collection.type):
|
2015-06-26 23:53:20 +08:00
|
|
|
rhs_type = collection.type.find()
|
|
|
|
rhs_wrapped_lhs_type = types.TMono(rhs_type.name, {"elt": element.type})
|
|
|
|
self._unify(rhs_wrapped_lhs_type, rhs_type,
|
|
|
|
element.loc, collection.loc)
|
|
|
|
elif not types.is_var(collection.type):
|
|
|
|
diag = diagnostic.Diagnostic("error",
|
|
|
|
"type {type} is not iterable",
|
|
|
|
{"type": types.TypePrinter().name(collection.type)},
|
|
|
|
collection.loc, [])
|
|
|
|
self.engine.process(diag)
|
2015-06-15 03:48:04 +08:00
|
|
|
|
2015-07-16 09:21:21 +08:00
|
|
|
def visit_Index(self, node):
|
2015-07-22 23:34:52 +08:00
|
|
|
self.generic_visit(node)
|
2015-07-16 09:21:21 +08:00
|
|
|
value = node.value
|
|
|
|
if types.is_tuple(value.type):
|
2020-08-09 10:28:36 +08:00
|
|
|
for elt in value.type.find().elts:
|
|
|
|
self._unify(elt, builtins.TInt(),
|
|
|
|
value.loc, None)
|
2015-07-16 09:21:21 +08:00
|
|
|
else:
|
|
|
|
self._unify(value.type, builtins.TInt(),
|
|
|
|
value.loc, None)
|
|
|
|
|
2015-07-16 19:53:24 +08:00
|
|
|
def visit_SliceT(self, node):
|
2021-03-15 02:46:28 +08:00
|
|
|
self.generic_visit(node)
|
2016-01-16 11:09:03 +08:00
|
|
|
if (node.lower, node.upper, node.step) == (None, None, None):
|
|
|
|
self._unify(node.type, builtins.TInt32(),
|
|
|
|
node.loc, None)
|
|
|
|
else:
|
|
|
|
self._unify(node.type, builtins.TInt(),
|
|
|
|
node.loc, None)
|
|
|
|
for operand in (node.lower, node.upper, node.step):
|
|
|
|
if operand is not None:
|
|
|
|
self._unify(operand.type, node.type,
|
|
|
|
operand.loc, None)
|
2015-07-16 09:21:21 +08:00
|
|
|
|
|
|
|
def visit_ExtSlice(self, node):
|
|
|
|
diag = diagnostic.Diagnostic("error",
|
|
|
|
"multi-dimensional slices are not supported", {},
|
|
|
|
node.loc, [])
|
|
|
|
self.engine.process(diag)
|
|
|
|
|
2015-06-11 11:34:22 +08:00
|
|
|
def visit_SubscriptT(self, node):
|
2015-06-14 17:07:13 +08:00
|
|
|
self.generic_visit(node)
|
2023-09-12 21:43:38 +08:00
|
|
|
|
|
|
|
if types.is_tuple(node.value.type):
|
|
|
|
if (not isinstance(node.slice, ast.Index) or
|
|
|
|
not isinstance(node.slice.value, ast.Num)):
|
|
|
|
diag = diagnostic.Diagnostic(
|
|
|
|
"error", "tuples can only be indexed by a constant", {},
|
|
|
|
node.slice.loc, []
|
|
|
|
)
|
|
|
|
self.engine.process(diag)
|
|
|
|
return
|
|
|
|
|
|
|
|
tuple_type = node.value.type.find()
|
|
|
|
index = node.slice.value.n
|
|
|
|
if index < 0 or index >= len(tuple_type.elts):
|
|
|
|
diag = diagnostic.Diagnostic(
|
|
|
|
"error",
|
|
|
|
"index {index} is out of range for tuple of size {size}",
|
|
|
|
{"index": index, "size": len(tuple_type.elts)},
|
|
|
|
node.slice.loc, []
|
|
|
|
)
|
|
|
|
self.engine.process(diag)
|
|
|
|
return
|
|
|
|
|
|
|
|
self._unify(node.type, tuple_type.elts[index], node.loc, node.value.loc)
|
|
|
|
elif isinstance(node.slice, ast.Index):
|
2020-08-09 10:28:36 +08:00
|
|
|
if types.is_tuple(node.slice.value.type):
|
2020-10-20 07:28:01 +08:00
|
|
|
if types.is_var(node.value.type):
|
|
|
|
return
|
2020-08-09 10:28:36 +08:00
|
|
|
if not builtins.is_array(node.value.type):
|
|
|
|
diag = diagnostic.Diagnostic(
|
|
|
|
"error",
|
2020-08-10 06:14:56 +08:00
|
|
|
"multi-dimensional indexing only supported for arrays, not {type}",
|
2020-08-09 10:28:36 +08:00
|
|
|
{"type": types.TypePrinter().name(node.value.type)},
|
|
|
|
node.loc, [])
|
|
|
|
self.engine.process(diag)
|
|
|
|
return
|
|
|
|
num_idxs = len(node.slice.value.type.find().elts)
|
|
|
|
array_type = node.value.type.find()
|
|
|
|
num_dims = array_type["num_dims"].value
|
|
|
|
remaining_dims = num_dims - num_idxs
|
|
|
|
if remaining_dims < 0:
|
|
|
|
diag = diagnostic.Diagnostic(
|
|
|
|
"error",
|
|
|
|
"too many indices for array of dimension {num_dims}",
|
|
|
|
{"num_dims": num_dims}, node.slice.loc, [])
|
|
|
|
self.engine.process(diag)
|
|
|
|
return
|
|
|
|
if remaining_dims == 0:
|
|
|
|
self._unify(node.type, array_type["elt"], node.loc,
|
|
|
|
node.value.loc)
|
|
|
|
else:
|
|
|
|
self._unify(
|
|
|
|
node.type,
|
|
|
|
builtins.TArray(array_type["elt"], remaining_dims))
|
|
|
|
else:
|
|
|
|
self._unify_iterable(element=node, collection=node.value)
|
2015-07-16 09:21:21 +08:00
|
|
|
elif isinstance(node.slice, ast.Slice):
|
2021-03-15 03:57:01 +08:00
|
|
|
if builtins.is_array(node.value.type):
|
|
|
|
if node.slice.step is not None:
|
|
|
|
diag = diagnostic.Diagnostic(
|
|
|
|
"error",
|
|
|
|
"strided slicing not yet supported for NumPy arrays", {},
|
|
|
|
node.slice.step.loc, [])
|
|
|
|
self.engine.process(diag)
|
|
|
|
return
|
2020-08-09 10:28:36 +08:00
|
|
|
self._unify(node.type, node.value.type, node.loc, node.value.loc)
|
|
|
|
else: # ExtSlice
|
|
|
|
pass # error emitted above
|
2015-06-11 11:34:22 +08:00
|
|
|
|
|
|
|
def visit_IfExpT(self, node):
|
2015-06-14 17:07:13 +08:00
|
|
|
self.generic_visit(node)
|
2021-03-20 08:27:25 +08:00
|
|
|
self._unify(node.test.type, builtins.TBool(), node.test.loc, None)
|
2015-06-11 11:34:22 +08:00
|
|
|
self._unify(node.body.type, node.orelse.type,
|
|
|
|
node.body.loc, node.orelse.loc)
|
2015-06-15 16:30:50 +08:00
|
|
|
self._unify(node.type, node.body.type,
|
|
|
|
node.loc, None)
|
2015-06-11 11:34:22 +08:00
|
|
|
|
|
|
|
def visit_BoolOpT(self, node):
|
2015-06-14 17:07:13 +08:00
|
|
|
self.generic_visit(node)
|
2015-06-11 09:20:33 +08:00
|
|
|
for value in node.values:
|
2015-07-22 23:35:18 +08:00
|
|
|
self._unify(node.type, value.type,
|
|
|
|
node.loc, value.loc, self._makenotes_elts(node.values, "an operand"))
|
2015-06-11 08:22:20 +08:00
|
|
|
|
2015-06-12 13:59:41 +08:00
|
|
|
def visit_UnaryOpT(self, node):
|
2015-06-14 17:07:13 +08:00
|
|
|
self.generic_visit(node)
|
2015-06-13 18:45:09 +08:00
|
|
|
operand_type = node.operand.type.find()
|
2015-06-12 13:59:41 +08:00
|
|
|
if isinstance(node.op, ast.Not):
|
2015-06-15 16:30:50 +08:00
|
|
|
self._unify(node.type, builtins.TBool(),
|
|
|
|
node.loc, None)
|
2015-06-13 18:45:09 +08:00
|
|
|
elif isinstance(node.op, ast.Invert):
|
|
|
|
if builtins.is_int(operand_type):
|
2015-06-15 16:30:50 +08:00
|
|
|
self._unify(node.type, operand_type,
|
|
|
|
node.loc, None)
|
2015-06-12 13:59:41 +08:00
|
|
|
elif not types.is_var(operand_type):
|
|
|
|
diag = diagnostic.Diagnostic("error",
|
2015-06-14 17:07:13 +08:00
|
|
|
"expected '~' operand to be of integer type, not {type}",
|
2015-06-12 13:59:41 +08:00
|
|
|
{"type": types.TypePrinter().name(operand_type)},
|
|
|
|
node.operand.loc)
|
|
|
|
self.engine.process(diag)
|
2015-06-13 18:45:09 +08:00
|
|
|
else: # UAdd, USub
|
2020-07-30 07:09:12 +08:00
|
|
|
if types.is_var(operand_type):
|
|
|
|
return
|
|
|
|
|
2015-06-13 18:45:09 +08:00
|
|
|
if builtins.is_numeric(operand_type):
|
2020-07-30 07:09:12 +08:00
|
|
|
self._unify(node.type, operand_type, node.loc, None)
|
|
|
|
return
|
|
|
|
|
|
|
|
if builtins.is_array(operand_type):
|
|
|
|
elt = operand_type.find()["elt"]
|
|
|
|
if builtins.is_numeric(elt):
|
|
|
|
self._unify(node.type, operand_type, node.loc, None)
|
|
|
|
return
|
|
|
|
if types.is_var(elt):
|
|
|
|
return
|
|
|
|
|
|
|
|
diag = diagnostic.Diagnostic("error",
|
|
|
|
"expected unary '{op}' operand to be of numeric type, not {type}",
|
|
|
|
{"op": node.op.loc.source(),
|
|
|
|
"type": types.TypePrinter().name(operand_type)},
|
|
|
|
node.operand.loc)
|
|
|
|
self.engine.process(diag)
|
2015-06-13 14:28:40 +08:00
|
|
|
|
2015-06-14 17:07:13 +08:00
|
|
|
def visit_CoerceT(self, node):
|
|
|
|
self.generic_visit(node)
|
2015-07-16 19:55:23 +08:00
|
|
|
if builtins.is_numeric(node.type) and builtins.is_numeric(node.value.type):
|
2015-06-14 17:07:13 +08:00
|
|
|
pass
|
2020-07-29 05:44:27 +08:00
|
|
|
elif (builtins.is_array(node.type) and builtins.is_array(node.value.type)
|
|
|
|
and builtins.is_numeric(node.type.find()["elt"])
|
|
|
|
and builtins.is_numeric(node.value.type.find()["elt"])):
|
|
|
|
pass
|
2015-06-14 17:07:13 +08:00
|
|
|
else:
|
|
|
|
printer = types.TypePrinter()
|
|
|
|
note = diagnostic.Diagnostic("note",
|
|
|
|
"expression that required coercion to {typeb}",
|
|
|
|
{"typeb": printer.name(node.type)},
|
2015-07-16 19:55:23 +08:00
|
|
|
node.other_value.loc)
|
2015-06-14 17:07:13 +08:00
|
|
|
diag = diagnostic.Diagnostic("error",
|
|
|
|
"cannot coerce {typea} to {typeb}",
|
2015-07-16 19:55:23 +08:00
|
|
|
{"typea": printer.name(node.value.type), "typeb": printer.name(node.type)},
|
2015-06-14 17:07:13 +08:00
|
|
|
node.loc, notes=[note])
|
|
|
|
self.engine.process(diag)
|
|
|
|
|
|
|
|
def _coerce_one(self, typ, coerced_node, other_node):
|
|
|
|
if coerced_node.type.find() == typ.find():
|
|
|
|
return coerced_node
|
2015-06-15 03:48:04 +08:00
|
|
|
elif isinstance(coerced_node, asttyped.CoerceT):
|
2015-07-03 00:49:52 +08:00
|
|
|
node = coerced_node
|
2017-03-03 00:15:37 +08:00
|
|
|
node.type.unify(typ)
|
|
|
|
node.other_value = other_node
|
2015-06-14 17:07:13 +08:00
|
|
|
else:
|
2015-07-16 19:55:23 +08:00
|
|
|
node = asttyped.CoerceT(type=typ, value=coerced_node, other_value=other_node,
|
2015-06-14 17:07:13 +08:00
|
|
|
loc=coerced_node.loc)
|
2015-06-15 03:48:04 +08:00
|
|
|
self.visit(node)
|
|
|
|
return node
|
2015-06-14 17:07:13 +08:00
|
|
|
|
2020-07-29 05:44:27 +08:00
|
|
|
def _coerce_numeric(self, nodes, map_return=lambda typ: typ, map_node_type =lambda typ:typ):
|
2015-06-14 17:07:13 +08:00
|
|
|
# See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.
|
2015-07-14 11:42:09 +08:00
|
|
|
node_types = []
|
|
|
|
for node in nodes:
|
|
|
|
if isinstance(node, asttyped.CoerceT):
|
2018-04-22 02:24:00 +08:00
|
|
|
# If we already know exactly what we coerce this value to, use that type,
|
|
|
|
# or we'll get an unification error in case the coerced type is not the same
|
|
|
|
# as the type of the coerced value.
|
|
|
|
# Otherwise, use the potentially more specific subtype when considering possible
|
|
|
|
# coercions, or we may get stuck.
|
|
|
|
if node.type.fold(False, lambda acc, ty: acc or types.is_var(ty)):
|
|
|
|
node_types.append(node.value.type)
|
|
|
|
else:
|
|
|
|
node_types.append(node.type)
|
2015-07-14 11:42:09 +08:00
|
|
|
else:
|
|
|
|
node_types.append(node.type)
|
2020-07-29 05:44:27 +08:00
|
|
|
node_types = [map_node_type(typ) for typ in node_types]
|
2015-06-15 03:48:04 +08:00
|
|
|
if any(map(types.is_var, node_types)): # not enough info yet
|
2015-06-14 18:10:32 +08:00
|
|
|
return
|
2015-06-15 03:48:04 +08:00
|
|
|
elif not all(map(builtins.is_numeric, node_types)):
|
|
|
|
err_node = next(filter(lambda node: not builtins.is_numeric(node.type), nodes))
|
2015-06-14 17:07:13 +08:00
|
|
|
diag = diagnostic.Diagnostic("error",
|
2015-06-15 03:48:04 +08:00
|
|
|
"cannot coerce {type} to a numeric type",
|
|
|
|
{"type": types.TypePrinter().name(err_node.type)},
|
|
|
|
err_node.loc, [])
|
2015-06-14 17:07:13 +08:00
|
|
|
self.engine.process(diag)
|
2015-06-14 18:10:32 +08:00
|
|
|
return
|
2015-06-15 03:48:04 +08:00
|
|
|
elif any(map(builtins.is_float, node_types)):
|
|
|
|
typ = builtins.TFloat()
|
|
|
|
elif any(map(builtins.is_int, node_types)):
|
2015-07-03 00:49:52 +08:00
|
|
|
widths = list(map(builtins.get_int_width, node_types))
|
2015-06-15 03:48:04 +08:00
|
|
|
if all(widths):
|
|
|
|
typ = builtins.TInt(types.TValue(max(widths)))
|
|
|
|
else:
|
|
|
|
typ = builtins.TInt()
|
|
|
|
else:
|
|
|
|
assert False
|
2015-06-14 17:07:13 +08:00
|
|
|
|
2015-06-15 03:48:04 +08:00
|
|
|
return map_return(typ)
|
2015-06-14 17:07:13 +08:00
|
|
|
|
|
|
|
def _order_by_pred(self, pred, left, right):
|
|
|
|
if pred(left.type):
|
|
|
|
return left, right
|
|
|
|
elif pred(right.type):
|
|
|
|
return right, left
|
|
|
|
else:
|
|
|
|
assert False
|
|
|
|
|
2020-11-11 05:24:04 +08:00
|
|
|
def _coerce_binary_broadcast_op(self, left, right, map_return_elt, op_loc):
|
|
|
|
def num_dims(typ):
|
|
|
|
if builtins.is_array(typ):
|
|
|
|
# TODO: If number of dimensions is ever made a non-fixed parameter,
|
|
|
|
# need to acutally unify num_dims in _coerce_binop/….
|
|
|
|
return typ.find()["num_dims"].value
|
|
|
|
return 0
|
|
|
|
|
|
|
|
left_dims = num_dims(left.type)
|
|
|
|
right_dims = num_dims(right.type)
|
|
|
|
if left_dims != right_dims and left_dims != 0 and right_dims != 0:
|
|
|
|
# Mismatch (only scalar broadcast supported for now).
|
|
|
|
note1 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
|
|
|
|
{"num_dims": left_dims}, left.loc)
|
|
|
|
note2 = diagnostic.Diagnostic("note", "operand of dimension {num_dims}",
|
|
|
|
{"num_dims": right_dims}, right.loc)
|
|
|
|
diag = diagnostic.Diagnostic(
|
|
|
|
"error", "dimensions of '{op}' array operands must match",
|
|
|
|
{"op": op_loc.source()}, op_loc, [left.loc, right.loc], [note1, note2])
|
|
|
|
self.engine.process(diag)
|
|
|
|
return
|
|
|
|
|
|
|
|
def map_node_type(typ):
|
|
|
|
if not builtins.is_array(typ):
|
|
|
|
# This is a single value broadcast across the array.
|
|
|
|
return typ
|
|
|
|
return typ.find()["elt"]
|
|
|
|
|
|
|
|
# Figure out result type, handling broadcasts.
|
|
|
|
result_dims = left_dims if left_dims else right_dims
|
|
|
|
def map_return(typ):
|
|
|
|
elt = map_return_elt(typ)
|
|
|
|
result = builtins.TArray(elt=elt, num_dims=result_dims)
|
|
|
|
left = builtins.TArray(elt=elt, num_dims=left_dims) if left_dims else elt
|
|
|
|
right = builtins.TArray(elt=elt, num_dims=right_dims) if right_dims else elt
|
|
|
|
return (result, left, right)
|
|
|
|
|
|
|
|
return self._coerce_numeric((left, right),
|
|
|
|
map_return=map_return,
|
|
|
|
map_node_type=map_node_type)
|
|
|
|
|
2015-06-14 18:10:32 +08:00
|
|
|
def _coerce_binop(self, op, left, right):
|
2020-08-03 00:52:15 +08:00
|
|
|
if isinstance(op, ast.MatMult):
|
|
|
|
if types.is_var(left.type) or types.is_var(right.type):
|
|
|
|
return
|
|
|
|
|
|
|
|
def num_dims(operand):
|
|
|
|
if not builtins.is_array(operand.type):
|
|
|
|
diag = diagnostic.Diagnostic(
|
|
|
|
"error",
|
|
|
|
"expected matrix multiplication operand to be of array type, not {type}",
|
|
|
|
{
|
|
|
|
"op": op.loc.source(),
|
|
|
|
"type": types.TypePrinter().name(operand.type)
|
|
|
|
}, op.loc, [operand.loc])
|
|
|
|
self.engine.process(diag)
|
|
|
|
return
|
|
|
|
num_dims = operand.type.find()["num_dims"].value
|
|
|
|
if num_dims not in (1, 2):
|
|
|
|
diag = diagnostic.Diagnostic(
|
|
|
|
"error",
|
|
|
|
"expected matrix multiplication operand to be 1- or 2-dimensional, not {type}",
|
|
|
|
{
|
|
|
|
"op": op.loc.source(),
|
|
|
|
"type": types.TypePrinter().name(operand.type)
|
|
|
|
}, op.loc, [operand.loc])
|
|
|
|
self.engine.process(diag)
|
|
|
|
return
|
|
|
|
return num_dims
|
|
|
|
|
|
|
|
left_dims = num_dims(left)
|
|
|
|
if not left_dims:
|
|
|
|
return
|
|
|
|
right_dims = num_dims(right)
|
|
|
|
if not right_dims:
|
|
|
|
return
|
|
|
|
|
|
|
|
def map_node_type(typ):
|
|
|
|
return typ.find()["elt"]
|
|
|
|
|
|
|
|
def map_return(typ):
|
|
|
|
if left_dims == 1:
|
|
|
|
if right_dims == 1:
|
|
|
|
result_dims = 0
|
|
|
|
else:
|
|
|
|
result_dims = 1
|
|
|
|
elif right_dims == 1:
|
|
|
|
result_dims = 1
|
|
|
|
else:
|
|
|
|
result_dims = 2
|
|
|
|
result = typ if result_dims == 0 else builtins.TArray(
|
|
|
|
typ, result_dims)
|
|
|
|
return (result, builtins.TArray(typ, left_dims),
|
|
|
|
builtins.TArray(typ, right_dims))
|
|
|
|
|
|
|
|
return self._coerce_numeric((left, right),
|
|
|
|
map_return=map_return,
|
|
|
|
map_node_type=map_node_type)
|
|
|
|
elif builtins.is_array(left.type) or builtins.is_array(right.type):
|
2020-07-29 05:44:27 +08:00
|
|
|
# Operations on arrays are element-wise (possibly using broadcasting).
|
|
|
|
|
2020-08-03 00:52:15 +08:00
|
|
|
# TODO: Allow only for integer arrays.
|
2020-07-29 05:44:27 +08:00
|
|
|
# allowed_int_array_ops = (ast.BitAnd, ast.BitOr, ast.BitXor, ast.LShift,
|
|
|
|
# ast.RShift)
|
|
|
|
allowed_array_ops = (ast.Add, ast.Mult, ast.FloorDiv, ast.Mod,
|
|
|
|
ast.Pow, ast.Sub, ast.Div)
|
|
|
|
if not isinstance(op, allowed_array_ops):
|
|
|
|
diag = diagnostic.Diagnostic(
|
|
|
|
"error", "operator '{op}' not valid for array types",
|
|
|
|
{"op": op.loc.source()}, op.loc)
|
|
|
|
self.engine.process(diag)
|
|
|
|
return
|
|
|
|
|
2020-11-11 05:24:04 +08:00
|
|
|
def map_result(typ):
|
|
|
|
if isinstance(op, ast.Div):
|
|
|
|
return builtins.TFloat()
|
|
|
|
return typ
|
|
|
|
return self._coerce_binary_broadcast_op(left, right, map_result, op.loc)
|
2020-07-29 05:44:27 +08:00
|
|
|
elif isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor,
|
2015-06-15 03:48:04 +08:00
|
|
|
ast.LShift, ast.RShift)):
|
2015-06-14 17:07:13 +08:00
|
|
|
# bitwise operators require integers
|
2015-06-14 18:10:32 +08:00
|
|
|
for operand in (left, right):
|
2015-06-14 17:07:13 +08:00
|
|
|
if not types.is_var(operand.type) and not builtins.is_int(operand.type):
|
|
|
|
diag = diagnostic.Diagnostic("error",
|
|
|
|
"expected '{op}' operand to be of integer type, not {type}",
|
2015-06-14 18:10:32 +08:00
|
|
|
{"op": op.loc.source(),
|
2015-06-14 17:07:13 +08:00
|
|
|
"type": types.TypePrinter().name(operand.type)},
|
2015-06-14 18:10:32 +08:00
|
|
|
op.loc, [operand.loc])
|
2015-06-14 17:07:13 +08:00
|
|
|
self.engine.process(diag)
|
|
|
|
return
|
|
|
|
|
2015-06-15 03:48:04 +08:00
|
|
|
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
|
2015-06-14 18:10:32 +08:00
|
|
|
elif isinstance(op, ast.Add):
|
2015-06-14 17:07:13 +08:00
|
|
|
# add works on numbers and also collections
|
2015-06-14 18:10:32 +08:00
|
|
|
if builtins.is_collection(left.type) or builtins.is_collection(right.type):
|
2015-06-14 17:07:13 +08:00
|
|
|
collection, other = \
|
2015-06-14 18:10:32 +08:00
|
|
|
self._order_by_pred(builtins.is_collection, left, right)
|
2015-06-14 17:07:13 +08:00
|
|
|
if types.is_tuple(collection.type):
|
|
|
|
pred, kind = types.is_tuple, "tuple"
|
|
|
|
elif builtins.is_list(collection.type):
|
|
|
|
pred, kind = builtins.is_list, "list"
|
|
|
|
else:
|
|
|
|
assert False
|
2015-12-30 15:46:54 +08:00
|
|
|
|
|
|
|
if types.is_var(other.type):
|
|
|
|
return
|
|
|
|
|
2015-06-14 17:07:13 +08:00
|
|
|
if not pred(other.type):
|
|
|
|
printer = types.TypePrinter()
|
|
|
|
note1 = diagnostic.Diagnostic("note",
|
|
|
|
"{kind} of type {typea}",
|
|
|
|
{"typea": printer.name(collection.type), "kind": kind},
|
|
|
|
collection.loc)
|
|
|
|
note2 = diagnostic.Diagnostic("note",
|
|
|
|
"{typeb}, which cannot be added to a {kind}",
|
|
|
|
{"typeb": printer.name(other.type), "kind": kind},
|
|
|
|
other.loc)
|
|
|
|
diag = diagnostic.Diagnostic("error",
|
|
|
|
"expected every '+' operand to be a {kind} in this context",
|
|
|
|
{"kind": kind},
|
2015-06-14 18:10:32 +08:00
|
|
|
op.loc, [other.loc, collection.loc],
|
2015-06-14 17:07:13 +08:00
|
|
|
[note1, note2])
|
|
|
|
self.engine.process(diag)
|
|
|
|
return
|
|
|
|
|
|
|
|
if types.is_tuple(collection.type):
|
2015-06-14 18:10:32 +08:00
|
|
|
return types.TTuple(left.type.find().elts +
|
|
|
|
right.type.find().elts), left.type, right.type
|
2015-06-14 17:07:13 +08:00
|
|
|
elif builtins.is_list(collection.type):
|
2015-06-14 18:10:32 +08:00
|
|
|
self._unify(left.type, right.type,
|
|
|
|
left.loc, right.loc)
|
|
|
|
return left.type, left.type, right.type
|
2017-06-09 14:00:57 +08:00
|
|
|
elif (builtins.is_str(left.type) or builtins.is_str(right.type) or
|
|
|
|
builtins.is_bytes(left.type) or builtins.is_bytes(right.type)):
|
2016-08-08 12:05:52 +08:00
|
|
|
self._unify(left.type, right.type,
|
|
|
|
left.loc, right.loc)
|
|
|
|
return left.type, left.type, right.type
|
2015-06-14 17:07:13 +08:00
|
|
|
else:
|
2015-06-15 03:48:04 +08:00
|
|
|
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
|
2015-06-14 18:10:32 +08:00
|
|
|
elif isinstance(op, ast.Mult):
|
2015-06-14 17:07:13 +08:00
|
|
|
# mult works on numbers and also number & collection
|
2015-06-14 18:10:32 +08:00
|
|
|
if types.is_tuple(left.type) or types.is_tuple(right.type):
|
|
|
|
tuple_, other = self._order_by_pred(types.is_tuple, left, right)
|
2015-06-14 17:07:13 +08:00
|
|
|
diag = diagnostic.Diagnostic("error",
|
2015-06-30 01:12:09 +08:00
|
|
|
"passing tuples to '*' is not supported", {},
|
2015-06-14 18:10:32 +08:00
|
|
|
op.loc, [tuple_.loc])
|
2015-06-14 17:07:13 +08:00
|
|
|
self.engine.process(diag)
|
2015-06-14 18:10:32 +08:00
|
|
|
return
|
|
|
|
elif builtins.is_list(left.type) or builtins.is_list(right.type):
|
|
|
|
list_, other = self._order_by_pred(builtins.is_list, left, right)
|
2016-06-16 21:32:14 +08:00
|
|
|
if not builtins.is_int(other.type) and not types.is_var(other.type):
|
2015-06-14 17:07:13 +08:00
|
|
|
printer = types.TypePrinter()
|
|
|
|
note1 = diagnostic.Diagnostic("note",
|
|
|
|
"list operand of type {typea}",
|
|
|
|
{"typea": printer.name(list_.type)},
|
|
|
|
list_.loc)
|
|
|
|
note2 = diagnostic.Diagnostic("note",
|
|
|
|
"operand of type {typeb}, which is not a valid repetition amount",
|
|
|
|
{"typeb": printer.name(other.type)},
|
|
|
|
other.loc)
|
|
|
|
diag = diagnostic.Diagnostic("error",
|
|
|
|
"expected '*' operands to be a list and an integer in this context", {},
|
2015-06-14 18:10:32 +08:00
|
|
|
op.loc, [list_.loc, other.loc],
|
2015-06-14 17:07:13 +08:00
|
|
|
[note1, note2])
|
|
|
|
self.engine.process(diag)
|
|
|
|
return
|
2015-06-14 18:10:32 +08:00
|
|
|
|
|
|
|
return list_.type, left.type, right.type
|
2015-06-14 17:07:13 +08:00
|
|
|
else:
|
2015-06-15 03:48:04 +08:00
|
|
|
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
|
2015-07-21 09:54:34 +08:00
|
|
|
elif isinstance(op, (ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)):
|
2015-06-14 17:07:13 +08:00
|
|
|
# numeric operators work on any kind of number
|
2015-06-15 03:48:04 +08:00
|
|
|
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
|
2015-07-21 09:54:34 +08:00
|
|
|
elif isinstance(op, ast.Div):
|
|
|
|
# division always returns a float
|
2015-07-22 23:34:52 +08:00
|
|
|
return self._coerce_numeric((left, right),
|
|
|
|
lambda typ: (builtins.TFloat(), builtins.TFloat(), builtins.TFloat()))
|
2020-08-03 00:52:15 +08:00
|
|
|
else:
|
2015-06-14 17:07:13 +08:00
|
|
|
diag = diagnostic.Diagnostic("error",
|
2015-06-14 18:10:32 +08:00
|
|
|
"operator '{op}' is not supported", {"op": op.loc.source()},
|
|
|
|
op.loc)
|
2015-06-14 17:07:13 +08:00
|
|
|
self.engine.process(diag)
|
|
|
|
return
|
|
|
|
|
2015-06-14 18:10:32 +08:00
|
|
|
def visit_BinOpT(self, node):
    """Type a binary operation and insert coercions for its operands.

    After the children are visited, ``self._coerce_binop`` supplies the
    result type and the types each operand must be coerced to.  A falsy
    answer means there is nothing (more) to do here.  Otherwise both operands
    are replaced by ``self._coerce_one`` wrappers and ``node.type`` is unified
    with the result type; if that unification fails, the diagnostic carries a
    note per operand plus a note about the whole expression.
    """
    self.generic_visit(node)

    coerced = self._coerce_binop(node.op, node.left, node.right)
    if not coerced:
        return

    return_type, left_type, right_type = coerced
    node.left = self._coerce_one(left_type, node.left, other_node=node.right)
    node.right = self._coerce_one(right_type, node.right, other_node=node.left)

    def describe_operand(printer, actual, target, loc):
        # Only mention the coercion when the operand type differs from the
        # type it had to be coerced to.
        if actual == target:
            return diagnostic.Diagnostic("note",
                "expression of type {type}",
                {"type": printer.name(actual)},
                loc)
        return diagnostic.Diagnostic("note",
            "expression of type {typea} (coerced to {typeb})",
            {"typea": printer.name(actual),
             "typeb": printer.name(target)},
            loc)

    def makenotes(printer, typea, typeb, loca, locb):
        # The summary note is built first so that the printer assigns type
        # variable names in the same order as before.
        if node.type == return_type:
            summary = diagnostic.Diagnostic("note",
                "expression of type {type}",
                {"type": printer.name(typea)},
                loca)
        else:
            summary = diagnostic.Diagnostic("note",
                "expression of type {typea} (but {typeb} was expected)",
                {"typea": printer.name(typea),
                 "typeb": printer.name(typeb)},
                loca)

        left_note = describe_operand(printer, node.left.type, left_type, node.left.loc)
        right_note = describe_operand(printer, node.right.type, right_type, node.right.loc)
        return [left_note, right_note, summary]

    self._unify(node.type, return_type,
                node.loc, None,
                makenotes=makenotes)
|
2015-06-14 18:10:32 +08:00
|
|
|
|
2015-06-15 03:48:04 +08:00
|
|
|
def visit_CompareT(self, node):
    """Type a (possibly chained) comparison; the result is always ``bool``.

    * ``is`` / ``is not`` chains: each adjacent pair of operands is unified.
    * ``in`` / ``not in`` chains: each left operand must be an element of
      the right operand's iterable type.
    * Other comparisons: adjacent operands are unified when any operand is a
      collection.  Otherwise, when any operand is numeric, all operands are
      coerced to the common numeric type from ``self._coerce_numeric``.
    """
    self.generic_visit(node)

    operands = [node.left] + node.comparators
    # ``a < b < c`` compares (a, b) and then (b, c).
    adjacent = zip(operands, node.comparators)

    def only(*op_classes):
        return all(isinstance(op, op_classes) for op in node.ops)

    if only(ast.Is, ast.IsNot):
        for lhs, rhs in adjacent:
            self._unify(lhs.type, rhs.type,
                        lhs.loc, rhs.loc)
    elif only(ast.In, ast.NotIn):
        for lhs, rhs in adjacent:
            self._unify_iterable(element=lhs, collection=rhs)
    else:  # Eq, NotEq, Lt, LtE, Gt, GtE
        operand_types = [operand.type for operand in operands]
        if any(builtins.is_collection(t) for t in operand_types):
            for lhs, rhs in adjacent:
                self._unify(lhs.type, rhs.type,
                            lhs.loc, rhs.loc)
        elif any(builtins.is_numeric(t) for t in operand_types):
            typ = self._coerce_numeric(operands)
            if typ:
                other_node = next((operand for operand in operands
                                   if operand.type.find() == typ.find()), None)
                if other_node is None:
                    # No operand has exactly the common type, so it is more
                    # generic than every input: its kind (typ) is known but
                    # its width is not.  Use a monomorphic operand of the
                    # same kind instead (this still raises StopIteration if
                    # there is none).
                    other_node = next(
                        operand for operand in operands
                        if types.is_mono(operand.type) and
                        operand.type.find().name == typ.find().name)
                coerced_operands = [self._coerce_one(typ, operand, other_node)
                                    for operand in operands]
                node.left, *node.comparators = coerced_operands
        # else: no coercion required.

    self._unify(node.type, builtins.TBool(),
                node.loc, None)
|
2015-06-15 03:48:04 +08:00
|
|
|
|
2015-06-15 13:40:37 +08:00
|
|
|
def visit_ListCompT(self, node):
    """Type a list comprehension as ``list(elt=<type of the element expr>)``.

    Only a single ``for`` clause is supported.  A second clause is reported
    as an error, at its ``for`` keyword location, before the children are
    visited.
    """
    extra_generators = node.generators[1:]
    if extra_generators:
        self.engine.process(diagnostic.Diagnostic("error",
            "multiple for clauses in comprehensions are not supported", {},
            extra_generators[0].for_loc))

    self.generic_visit(node)

    element_list = builtins.TList(node.elt.type)
    self._unify(node.type, element_list,
                node.loc, None)
|
2015-06-15 13:40:37 +08:00
|
|
|
|
|
|
|
def visit_comprehension(self, node):
    """Type one ``for target in iter`` clause of a comprehension.

    ``if`` filters are unsupported and are reported at ``node.if_locs[0]``.
    The clause's target is then unified with the element type of its
    iterable.
    """
    has_filters = any(node.ifs)
    if has_filters:
        self.engine.process(diagnostic.Diagnostic("error",
            "if clauses in comprehensions are not supported", {},
            node.if_locs[0]))

    self.generic_visit(node)
    self._unify_iterable(element=node.target, collection=node.iter)
|
2015-06-15 13:40:37 +08:00
|
|
|
|
2015-06-24 17:16:17 +08:00
|
|
|
def visit_builtin_call(self, node):
    """Infer the type of a call whose callee has a builtin type.

    Each builtin has its own branch below.  A branch unifies ``node.type``
    (and, where applicable, the argument types) with that builtin's
    signature, or reports a diagnostic listing the valid call forms.  When an
    argument type is still a type variable, the branch leaves it
    unconstrained ("undetermined yet") rather than guessing.
    """
    typ = node.func.type.find()

    def valid_form(signature):
        return diagnostic.Diagnostic("note",
            "{func} can be invoked as: {signature}",
            {"func": typ.name, "signature": signature},
            node.func.loc)

    def diagnose(valid_forms):
        printer = types.TypePrinter()
        args = [printer.name(arg.type) for arg in node.args]
        args += ["%s=%s" % (kw.arg, printer.name(kw.value.type)) for kw in node.keywords]

        diag = diagnostic.Diagnostic("error",
            "{func} cannot be invoked with the arguments ({args})",
            {"func": typ.name, "args": ", ".join(args)},
            node.func.loc, notes=valid_forms)
        self.engine.process(diag)

    def simple_form(info, arg_types=(), return_type=None):
        # Fixed-signature builtin: positional arguments only, each unified with
        # the matching entry of `arg_types`; the result type defaults to None.
        if return_type is None:
            return_type = builtins.TNone()
        self._unify(node.type, return_type,
                    node.loc, None)

        if len(node.args) == len(arg_types) and len(node.keywords) == 0:
            for arg, arg_type in zip(node.args, arg_types):
                self._unify(arg.type, arg_type,
                            arg.loc, None)
        else:
            diagnose([valid_form(info)])

    if types.is_exn_constructor(typ):
        valid_forms = lambda: [
            valid_form("{exn}() -> {exn}".format(exn=typ.name)),
            valid_form("{exn}(message:str) -> {exn}".format(exn=typ.name)),
            valid_form("{exn}(message:str, param1:numpy.int64) -> {exn}".format(exn=typ.name)),
            valid_form("{exn}(message:str, param1:numpy.int64, "
                       "param2:numpy.int64) -> {exn}".format(exn=typ.name)),
            valid_form("{exn}(message:str, param1:numpy.int64, "
                       "param2:numpy.int64, param3:numpy.int64) "
                       "-> {exn}".format(exn=typ.name)),
        ]

        if len(node.args) == 0 and len(node.keywords) == 0:
            pass # Default message, zeroes as parameters
        elif 1 <= len(node.args) <= 4 and len(node.keywords) == 0:
            message, *params = node.args

            self._unify(message.type, builtins.TStr(),
                        message.loc, None)
            for param in params:
                self._unify(param.type, builtins.TInt64(),
                            param.loc, None)
        else:
            diagnose(valid_forms())

        self._unify(node.type, typ.instance,
                    node.loc, None)
    elif types.is_builtin(typ, "bool"):
        valid_forms = lambda: [
            valid_form("bool() -> bool"),
            valid_form("bool(x:'a) -> bool")
        ]

        if len(node.args) == 0 and len(node.keywords) == 0:
            pass # False
        elif len(node.args) == 1 and len(node.keywords) == 0:
            pass # anything goes
        else:
            diagnose(valid_forms())

        self._unify(node.type, builtins.TBool(),
                    node.loc, None)
    elif types.is_builtin(typ, "int") or \
            types.is_builtin(typ, "int32") or types.is_builtin(typ, "int64"):
        if types.is_builtin(typ, "int"):
            valid_forms = lambda: [
                valid_form("int() -> numpy.int?"),
                valid_form("int(x:'a) -> numpy.int? where 'a is numeric")
            ]
            result_typ = builtins.TInt()
        elif types.is_builtin(typ, "int32"):
            valid_forms = lambda: [
                valid_form("numpy.int32() -> numpy.int32"),
                valid_form("numpy.int32(x:'a) -> numpy.int32 where 'a is numeric")
            ]
            result_typ = builtins.TInt32()
        elif types.is_builtin(typ, "int64"):
            valid_forms = lambda: [
                valid_form("numpy.int64() -> numpy.int64"),
                valid_form("numpy.int64(x:'a) -> numpy.int64 where 'a is numeric")
            ]
            result_typ = builtins.TInt64()

        self._unify(node.type, result_typ,
                    node.loc, None)

        if len(node.args) == 0 and len(node.keywords) == 0:
            pass # 0
        elif len(node.args) == 1 and len(node.keywords) == 0 and \
                types.is_var(node.args[0].type):
            pass # undetermined yet
        elif len(node.args) == 1 and len(node.keywords) == 0 and \
                builtins.is_numeric(node.args[0].type):
            pass # node.type was already unified with result_typ above
        else:
            diagnose(valid_forms())
    elif types.is_builtin(typ, "float"):
        valid_forms = lambda: [
            valid_form("float() -> float"),
            valid_form("float(x:'a) -> float where 'a is numeric")
        ]

        self._unify(node.type, builtins.TFloat(),
                    node.loc, None)

        if len(node.args) == 0 and len(node.keywords) == 0:
            pass # 0.0
        elif len(node.args) == 1 and len(node.keywords) == 0 and \
                types.is_var(node.args[0].type):
            pass # undetermined yet
        elif len(node.args) == 1 and len(node.keywords) == 0 and \
                builtins.is_numeric(node.args[0].type):
            pass
        else:
            diagnose(valid_forms())
    elif types.is_builtin(typ, "str"):
        diag = diagnostic.Diagnostic("error",
            "strings currently cannot be constructed", {},
            node.loc)
        self.engine.process(diag)
    elif types.is_builtin(typ, "array"):
        valid_forms = lambda: [
            valid_form("array(x:'a) -> array(elt='b) where 'a is iterable"),
            valid_form("array(x:'a, dtype:'b) -> array(elt='b) where 'a is iterable")
        ]

        explicit_dtype = None
        keywords_acceptable = False
        if len(node.keywords) == 0:
            keywords_acceptable = True
        elif len(node.keywords) == 1:
            if node.keywords[0].arg == "dtype":
                keywords_acceptable = True
                explicit_dtype = node.keywords[0].value
        if len(node.args) == 1 and keywords_acceptable:
            arg, = node.args

            num_empty_dims = is_nested_empty_list(arg)
            if num_empty_dims is not None:
                # As a special case, following the behaviour of numpy.array (and
                # repr() on ndarrays), consider empty lists to be exactly of the
                # number of dimensions given, instead of potentially containing an
                # unknown number of extra dimensions.
                num_dims = num_empty_dims

                # The ultimate element type will be TVar initially, but we might be
                # able to resolve it from context.
                elt = arg.type
                for _ in range(num_dims):
                    assert builtins.is_list(elt)
                    elt = elt.find()["elt"]
            else:
                # In the absence of any other information (there currently isn't a way
                # to specify any), assume that all iterables are expandable into a
                # (runtime-checked) rectangular array of the innermost element type.
                elt = arg.type
                num_dims = 0
                expected_dims = (node.type.find()["num_dims"].value
                                 if builtins.is_array(node.type) else -1)
                while True:
                    if num_dims == expected_dims:
                        # If we already know the number of dimensions of the result,
                        # stop so we can disambiguate the (innermost) element type of
                        # the argument if it is still unknown.
                        break
                    if types.is_var(elt):
                        # Can't make progress here because we don't know how many more
                        # dimensions might be "hidden" inside.
                        return
                    if not builtins.is_iterable(elt) or builtins.is_str(elt):
                        break
                    if builtins.is_array(elt):
                        num_dims += elt.find()["num_dims"].value
                    else:
                        num_dims += 1
                    elt = builtins.get_iterable_elt(elt)

            if explicit_dtype is not None:
                # TODO: Factor out type detection; support quoted type constructors
                # (TList(TInt32), …)?
                # Use a separate name so that `typ` keeps referring to the
                # `array` builtin for the diagnostics below (and in valid_form).
                dtype_typ = explicit_dtype.type
                if types.is_builtin(dtype_typ, "int32"):
                    elt = builtins.TInt32()
                elif types.is_builtin(dtype_typ, "int64"):
                    elt = builtins.TInt64()
                elif types.is_constructor(dtype_typ):
                    elt = dtype_typ.find().instance
                else:
                    note = diagnostic.Diagnostic(
                        "note", "this expression has type {type}",
                        {"type": types.TypePrinter().name(dtype_typ)},
                        explicit_dtype.loc)
                    diag = diagnostic.Diagnostic(
                        "error",
                        "dtype argument of {builtin}() must be a valid constructor",
                        {"builtin": typ.find().name},
                        node.func.loc,
                        notes=[note])
                    self.engine.process(diag)
                    return

            if num_dims == 0:
                note = diagnostic.Diagnostic(
                    "note", "this expression has type {type}",
                    {"type": types.TypePrinter().name(arg.type)}, arg.loc)
                diag = diagnostic.Diagnostic(
                    "error",
                    "the argument of {builtin}() must be of an iterable type",
                    {"builtin": typ.find().name},
                    node.func.loc,
                    notes=[note])
                self.engine.process(diag)
                return

            self._unify(node.type,
                        builtins.TArray(elt, types.TValue(num_dims)),
                        node.loc, arg.loc)
        else:
            diagnose(valid_forms())
    elif types.is_builtin(typ, "list"):
        valid_forms = lambda: [
            valid_form("list() -> list(elt='a)"),
            valid_form("list(x:'a) -> list(elt='b) where 'a is iterable")
        ]

        self._unify(node.type, builtins.TList(), node.loc, None)

        if len(node.args) == 0 and len(node.keywords) == 0:
            pass # []
        elif len(node.args) == 1 and len(node.keywords) == 0:
            arg, = node.args

            if builtins.is_iterable(arg.type):
                def makenotes(printer, typea, typeb, loca, locb):
                    return [
                        diagnostic.Diagnostic("note",
                            "iterator returning elements of type {typea}",
                            {"typea": printer.name(typea)},
                            loca),
                        diagnostic.Diagnostic("note",
                            "iterator returning elements of type {typeb}",
                            {"typeb": printer.name(typeb)},
                            locb)
                    ]
                self._unify(node.type.find().params["elt"],
                            arg.type.find().params["elt"],
                            node.loc, arg.loc, makenotes=makenotes)
            elif types.is_var(arg.type):
                pass # undetermined yet
            else:
                note = diagnostic.Diagnostic("note",
                    "this expression has type {type}",
                    {"type": types.TypePrinter().name(arg.type)},
                    arg.loc)
                diag = diagnostic.Diagnostic("error",
                    "the argument of {builtin}() must be of an iterable type",
                    {"builtin": typ.find().name},
                    node.func.loc, notes=[note])
                self.engine.process(diag)
        else:
            diagnose(valid_forms())
    elif types.is_builtin(typ, "range"):
        valid_forms = lambda: [
            valid_form("range(max:numpy.int?) -> range(elt=numpy.int?)"),
            valid_form("range(min:numpy.int?, max:numpy.int?) "
                       "-> range(elt=numpy.int?)"),
            valid_form("range(min:numpy.int?, max:numpy.int?, "
                       "step:numpy.int?) -> range(elt=numpy.int?)"),
        ]

        range_elt = builtins.TInt(types.TVar())
        self._unify(node.type, builtins.TRange(range_elt),
                    node.loc, None)

        if len(node.args) in (1, 2, 3) and len(node.keywords) == 0:
            for arg in node.args:
                self._unify(arg.type, range_elt,
                            arg.loc, None)
        else:
            diagnose(valid_forms())
    elif types.is_builtin(typ, "len"):
        valid_forms = lambda: [
            valid_form("len(x:'a) -> numpy.int?"),
        ]

        if len(node.args) == 1 and len(node.keywords) == 0:
            arg, = node.args

            if builtins.is_range(arg.type):
                self._unify(node.type, builtins.get_iterable_elt(arg.type),
                            node.loc, None)
            elif builtins.is_listish(arg.type):
                # TODO: should be ssize_t-sized
                self._unify(node.type, builtins.TInt32(),
                            node.loc, None)
            elif types.is_var(arg.type):
                pass # undetermined yet
            else:
                note = diagnostic.Diagnostic("note",
                    "this expression has type {type}",
                    {"type": types.TypePrinter().name(arg.type)},
                    arg.loc)
                diag = diagnostic.Diagnostic("error",
                    "the argument of len() must be of an iterable type", {},
                    node.func.loc, notes=[note])
                self.engine.process(diag)
        else:
            diagnose(valid_forms())
    elif types.is_builtin(typ, "round"):
        valid_forms = lambda: [
            valid_form("round(x:float) -> numpy.int?"),
        ]

        self._unify(node.type, builtins.TInt(),
                    node.loc, None)

        if len(node.args) == 1 and len(node.keywords) == 0:
            arg, = node.args

            self._unify(arg.type, builtins.TFloat(),
                        arg.loc, None)
        else:
            diagnose(valid_forms())
    elif types.is_builtin(typ, "abs"):
        valid_forms = lambda: [
            valid_form("abs(x:numpy.int?) -> numpy.int?"),
            valid_form("abs(x:float) -> float")
        ]

        if len(node.args) == 1 and len(node.keywords) == 0:
            (arg,) = node.args
            if builtins.is_int(arg.type) or builtins.is_float(arg.type):
                self._unify(arg.type, node.type,
                            arg.loc, node.loc)
            elif types.is_var(arg.type):
                pass # undetermined yet
            else:
                diag = diagnostic.Diagnostic("error",
                    "the arguments of abs() must be of a numeric type", {},
                    node.func.loc)
                self.engine.process(diag)
        else:
            diagnose(valid_forms())
    elif types.is_builtin(typ, "min") or types.is_builtin(typ, "max"):
        fn = typ.name

        valid_forms = lambda: [
            valid_form("{}(x:numpy.int?, y:numpy.int?) -> numpy.int?".format(fn)),
            valid_form("{}(x:float, y:float) -> float".format(fn))
        ]

        if len(node.args) == 2 and len(node.keywords) == 0:
            arg0, arg1 = node.args

            self._unify(arg0.type, arg1.type,
                        arg0.loc, arg1.loc)

            if builtins.is_int(arg0.type) or builtins.is_float(arg0.type):
                self._unify(arg0.type, node.type,
                            arg0.loc, node.loc)
            elif types.is_var(arg0.type):
                pass # undetermined yet
            else:
                note = diagnostic.Diagnostic("note",
                    "this expression has type {type}",
                    {"type": types.TypePrinter().name(arg0.type)},
                    arg0.loc)
                diag = diagnostic.Diagnostic("error",
                    "the arguments of {fn}() must be of a numeric type",
                    {"fn": fn},
                    node.func.loc, notes=[note])
                self.engine.process(diag)
        else:
            diagnose(valid_forms())
    elif types.is_builtin(typ, "print"):
        valid_forms = lambda: [
            valid_form("print(args...) -> None"),
        ]

        self._unify(node.type, builtins.TNone(),
                    node.loc, None)

        if len(node.keywords) == 0:
            # We can print any arguments.
            pass
        else:
            diagnose(valid_forms())
    elif types.is_builtin(typ, "make_array"):
        valid_forms = lambda: [
            valid_form("numpy.full(count:int32, value:'a) -> array(elt='a, num_dims=1)"),
            valid_form("numpy.full(shape:(int32,)*'b, value:'a) -> array(elt='a, num_dims='b)"),
        ]

        if len(node.args) == 2 and len(node.keywords) == 0:
            arg0, arg1 = node.args

            if types.is_var(arg0.type):
                return # undetermined yet
            elif types.is_tuple(arg0.type):
                num_dims = len(arg0.type.find().elts)
                self._unify(arg0.type, types.TTuple([builtins.TInt32()] * num_dims),
                            arg0.loc, None)
            else:
                num_dims = 1
                self._unify(arg0.type, builtins.TInt32(),
                            arg0.loc, None)

            self._unify(node.type, builtins.TArray(num_dims=num_dims),
                        node.loc, None)
            self._unify(arg1.type, node.type.find()["elt"],
                        arg1.loc, None)
        else:
            diagnose(valid_forms())
    elif types.is_builtin(typ, "numpy.transpose"):
        valid_forms = lambda: [
            valid_form("transpose(x: array(elt='a, num_dims=1)) -> array(elt='a, num_dims=1)"),
            valid_form("transpose(x: array(elt='a, num_dims=2)) -> array(elt='a, num_dims=2)")
        ]

        if len(node.args) == 1 and len(node.keywords) == 0:
            arg, = node.args

            if types.is_var(arg.type):
                pass # undetermined yet
            elif not builtins.is_array(arg.type):
                note = diagnostic.Diagnostic(
                    "note", "this expression has type {type}",
                    {"type": types.TypePrinter().name(arg.type)}, arg.loc)
                diag = diagnostic.Diagnostic(
                    "error",
                    "the argument of {builtin}() must be an array",
                    {"builtin": typ.find().name},
                    node.func.loc,
                    notes=[note])
                self.engine.process(diag)
            else:
                num_dims = arg.type.find()["num_dims"].value
                if num_dims not in (1, 2):
                    note = diagnostic.Diagnostic(
                        "note", "argument is {num_dims}-dimensional",
                        {"num_dims": num_dims}, arg.loc)
                    diag = diagnostic.Diagnostic(
                        "error",
                        "{builtin}() is currently only supported for up to "
                        "two-dimensional arrays", {"builtin": typ.find().name},
                        node.func.loc,
                        notes=[note])
                    self.engine.process(diag)
                else:
                    self._unify(node.type, arg.type, node.loc, None)
        else:
            diagnose(valid_forms())
    elif types.is_builtin(typ, "rtio_log"):
        valid_forms = lambda: [
            valid_form("rtio_log(channel:str, args...) -> None"),
        ]

        self._unify(node.type, builtins.TNone(),
                    node.loc, None)

        if len(node.args) >= 1 and len(node.keywords) == 0:
            arg = node.args[0]

            self._unify(arg.type, builtins.TStr(),
                        arg.loc, None)
        else:
            diagnose(valid_forms())
    elif types.is_builtin(typ, "now"):
        simple_form("now() -> float",
                    (), builtins.TFloat())
    elif types.is_builtin(typ, "delay"):
        simple_form("delay(time:float) -> None",
                    [builtins.TFloat()])
    elif types.is_builtin(typ, "at"):
        simple_form("at(time:float) -> None",
                    [builtins.TFloat()])
    elif types.is_builtin(typ, "now_mu"):
        simple_form("now_mu() -> numpy.int64",
                    (), builtins.TInt64())
    elif types.is_builtin(typ, "delay_mu"):
        simple_form("delay_mu(time_mu:numpy.int64) -> None",
                    [builtins.TInt64()])
    elif types.is_builtin(typ, "at_mu"):
        simple_form("at_mu(time_mu:numpy.int64) -> None",
                    [builtins.TInt64()])
    elif types.is_constructor(typ):
        # An user-defined class.
        self._unify(node.type, typ.find().instance,
                    node.loc, None)
    elif types.is_builtin(typ, "kernel"):
        # Ignored.
        self._unify(node.type, builtins.TNone(),
                    node.loc, None)
    elif types.is_builtin(typ, "subkernel_await"):
        valid_forms = lambda: [
            valid_form("subkernel_await(f: subkernel) -> f return type"),
            valid_form("subkernel_await(f: subkernel, timeout: numpy.int64) -> f return type")
        ]
        if 1 <= len(node.args) <= 2:
            arg0 = node.args[0].type
            if types.is_var(arg0):
                pass # undetermined yet
            else:
                if types.is_method(arg0):
                    fn = types.get_method_function(arg0)
                elif types.is_function(arg0) or types.is_subkernel(arg0):
                    fn = arg0
                else:
                    # `fn` is unbound here; stop instead of dereferencing it.
                    diagnose(valid_forms())
                    return
                self._unify(node.type, fn.ret,
                            node.loc, None)
            if len(node.args) == 2:
                arg1 = node.args[1]
                if types.is_var(arg1.type):
                    pass
                elif builtins.is_int(arg1.type):
                    # promote to TInt64
                    self._unify(arg1.type, builtins.TInt64(),
                                arg1.loc, None)
                else:
                    diagnose(valid_forms())
        else:
            diagnose(valid_forms())
    elif types.is_builtin(typ, "subkernel_preload"):
        valid_forms = lambda: [
            valid_form("subkernel_preload(f: subkernel) -> None")
        ]
        # Preloading produces no value, as the documented form states; the
        # subkernel's own return type is irrelevant to this call.
        self._unify(node.type, builtins.TNone(),
                    node.loc, None)
        if len(node.args) == 1:
            arg0 = node.args[0].type
            if types.is_var(arg0):
                pass # undetermined yet
            elif not (types.is_method(arg0) or types.is_function(arg0) or
                      types.is_subkernel(arg0)):
                diagnose(valid_forms())
        else:
            diagnose(valid_forms())
    elif types.is_builtin(typ, "subkernel_send"):
        valid_forms = lambda: [
            valid_form("subkernel_send(dest: numpy.int?, name: str, value: V) -> None"),
        ]
        self._unify(node.type, builtins.TNone(),
                    node.loc, None)
        if len(node.args) == 3:
            arg0 = node.args[0]
            if types.is_var(arg0.type):
                pass # undetermined yet
            else:
                if builtins.is_int(arg0.type):
                    self._unify(arg0.type, builtins.TInt8(),
                                arg0.loc, None)
                else:
                    diagnose(valid_forms())
            arg1 = node.args[1]
            self._unify(arg1.type, builtins.TStr(),
                        arg1.loc, None)
        else:
            diagnose(valid_forms())
    elif types.is_builtin(typ, "subkernel_recv"):
        valid_forms = lambda: [
            valid_form("subkernel_recv(name: str, value_type: type) -> value_type"),
            valid_form("subkernel_recv(name: str, value_type: type, timeout: numpy.int64) -> value_type"),
        ]
        if 2 <= len(node.args) <= 3:
            arg0 = node.args[0]
            if types.is_var(arg0.type):
                pass
            else:
                self._unify(arg0.type, builtins.TStr(),
                            arg0.loc, None)
            arg1 = node.args[1]
            if types.is_var(arg1.type):
                pass
            else:
                self._unify(node.type, arg1.value,
                            node.loc, None)
            if len(node.args) == 3:
                arg2 = node.args[2]
                if types.is_var(arg2.type):
                    pass
                elif builtins.is_int(arg2.type):
                    # promote to TInt64
                    self._unify(arg2.type, builtins.TInt64(),
                                arg2.loc, None)
                else:
                    diagnose(valid_forms())
        else:
            diagnose(valid_forms())
    else:
        assert False
|
2015-06-24 17:16:17 +08:00
|
|
|
|
2015-06-15 21:55:13 +08:00
|
|
|
def visit_CallT(self, node):
    """Infer the type of a call expression and check its arguments.

    Callees whose type is still a type variable are left for a later pass,
    builtins are handled by ``visit_builtin_call``, and RPCs only have their
    return type propagated to the call. For functions and methods:

    * a method's first formal argument is unified with the bound value;
    * positional and keyword arguments are unified with the formal types;
    * errors are reported for variadic, extraneous, duplicated, unknown, or
      missing arguments;
    * functions marked as broadcasting across arrays are applied
      element-wise when called with array arguments;
    * argument names/types of subkernel calls are recorded in
      ``self.subkernel_arg_types``.
    """
    self.generic_visit(node)

    for (sigil_loc, vararg) in ((node.star_loc, node.starargs),
                                (node.dstar_loc, node.kwargs)):
        if vararg:
            diag = diagnostic.Diagnostic("error",
                "variadic arguments are not supported", {},
                sigil_loc, [vararg.loc])
            self.engine.process(diag)
            return

    typ = node.func.type.find()

    if types.is_var(typ):
        return # not enough info yet
    elif types.is_builtin(typ):
        return self.visit_builtin_call(node)
    elif types.is_rpc(typ):
        self._unify(node.type, typ.ret,
                    node.loc, None)
        return
    elif not (types.is_function(typ) or types.is_method(typ)):
        diag = diagnostic.Diagnostic("error",
            "cannot call this expression of type {type}",
            {"type": types.TypePrinter().name(typ)},
            node.func.loc, [])
        self.engine.process(diag)
        return

    if types.is_function(typ):
        typ_arity = typ.arity()
        typ_args = typ.args
        typ_optargs = typ.optargs
        typ_ret = typ.ret
        typ_func = typ
    else:
        typ_self = types.get_method_self(typ)
        typ_func = types.get_method_function(typ)
        if types.is_var(typ_func):
            return # not enough info yet
        elif types.is_rpc(typ_func):
            self._unify(node.type, typ_func.ret,
                        node.loc, None)
            return
        elif typ_func.arity() == 0:
            return # error elsewhere

        method_args = list(typ_func.args.items())

        # The first formal argument of the method receives the bound value.
        _self_arg_name, self_arg_type = method_args[0]
        self._unify(self_arg_type, typ_self,
                    node.loc, None)

        # The caller only supplies the remaining arguments.
        typ_arity = typ_func.arity() - 1
        typ_args = OrderedDict(method_args[1:])
        typ_optargs = typ_func.optargs
        typ_ret = typ_func.ret

    # Formal argument name -> location where it was supplied; used to detect
    # duplicate and missing arguments.
    passed_args = dict()

    if len(node.args) > typ_arity:
        note = diagnostic.Diagnostic("note",
            "extraneous argument(s)", {},
            node.args[typ_arity].loc.join(node.args[-1].loc))
        diag = diagnostic.Diagnostic("error",
            "this function of type {type} accepts at most {num} arguments",
            {"type": types.TypePrinter().name(node.func.type),
             "num": typ_arity},
            node.func.loc, [], [note])
        self.engine.process(diag)
        return

    # Array broadcasting for functions explicitly marked as such.
    if len(node.args) == typ_arity and types.is_broadcast_across_arrays(typ):
        if typ_arity == 1:
            arg_type = node.args[0].type.find()
            if builtins.is_array(arg_type):
                typ_arg, = typ_args.values()
                self._unify(typ_arg, arg_type["elt"], node.args[0].loc, None)
                self._unify(node.type, builtins.TArray(typ_ret, arg_type["num_dims"]),
                            node.loc, None)
                return
        elif typ_arity == 2:
            if any(builtins.is_array(arg.type) for arg in node.args):
                ret, arg0, arg1 = self._coerce_binary_broadcast_op(
                    node.args[0], node.args[1], lambda t: typ_ret, node.loc)
                node.args[0] = self._coerce_one(arg0, node.args[0],
                                                other_node=node.args[1])
                node.args[1] = self._coerce_one(arg1, node.args[1],
                                                other_node=node.args[0])
                self._unify(node.type, ret, node.loc, None)
                return

    if types.is_subkernel(typ_func) and typ_func.sid not in self.subkernel_arg_types:
        self.subkernel_arg_types[typ_func.sid] = []

    for actualarg, (formalname, formaltyp) in \
            zip(node.args, list(typ_args.items()) + list(typ_optargs.items())):
        self._unify(actualarg.type, formaltyp,
                    actualarg.loc, None)
        passed_args[formalname] = actualarg.loc
        if types.is_subkernel(typ_func):
            if types.is_instance(actualarg.type):
                # objects cannot be passed to subkernels, as rpc code doesn't support them
                diag = diagnostic.Diagnostic("error",
                    "argument '{name}' of type: {typ} is not supported in subkernels",
                    # Render the type by name like every other diagnostic,
                    # rather than interpolating the raw Type object.
                    {"name": formalname,
                     "typ": types.TypePrinter().name(actualarg.type)},
                    actualarg.loc, [])
                self.engine.process(diag)
            self.subkernel_arg_types[typ_func.sid].append((formalname, formaltyp))

    for keyword in node.keywords:
        if keyword.arg in passed_args:
            diag = diagnostic.Diagnostic("error",
                "the argument '{name}' has been passed earlier as positional",
                {"name": keyword.arg},
                keyword.arg_loc, [passed_args[keyword.arg]])
            self.engine.process(diag)
            return

        if keyword.arg in typ_args:
            self._unify(keyword.value.type, typ_args[keyword.arg],
                        keyword.value.loc, None)
        elif keyword.arg in typ_optargs:
            self._unify(keyword.value.type, typ_optargs[keyword.arg],
                        keyword.value.loc, None)
        else:
            note = diagnostic.Diagnostic("note",
                "extraneous argument", {},
                keyword.loc)
            diag = diagnostic.Diagnostic("error",
                "this function of type {type} does not accept argument '{name}'",
                {"type": types.TypePrinter().name(node.func.type),
                 "name": keyword.arg},
                node.func.loc, [], [note])
            self.engine.process(diag)
            return
        passed_args[keyword.arg] = keyword.arg_loc

    for formalname in typ_args:
        # Remote (subkernel) calls are exempt from the mandatory-argument check.
        if formalname not in passed_args and not node.remote_fn:
            note = diagnostic.Diagnostic("note",
                "the called function is of type {type}",
                {"type": types.TypePrinter().name(node.func.type)},
                node.func.loc)
            diag = diagnostic.Diagnostic("error",
                "mandatory argument '{name}' is not passed",
                {"name": formalname},
                node.begin_loc.join(node.end_loc), [], [note])
            self.engine.process(diag)
            return

    self._unify(node.type, typ_ret,
                node.loc, None)
|
|
|
|
|
2015-06-15 16:30:50 +08:00
|
|
|
def visit_LambdaT(self, node):
    """Give a lambda the function type derived from its arguments and body.

    Children are visited first. The signature is built from the argument list
    and the body's type. If no signature can be built (a falsy result), the
    lambda's type is left untouched.
    """
    self.generic_visit(node)

    lambda_signature = self._type_from_arguments(node.args, node.body.type)
    if not lambda_signature:
        return
    self._unify(node.type, lambda_signature, node.loc, None)
|
|
|
|
|
2015-05-29 14:53:24 +08:00
|
|
|
def visit_Assign(self, node):
    """Unify every assignment target's type with the assigned value's type."""
    self.generic_visit(node)

    rhs = node.value
    for lhs in node.targets:
        self._unify(lhs.type, rhs.type, lhs.loc, rhs.loc)
|
2015-05-29 14:53:24 +08:00
|
|
|
|
2015-06-02 13:53:11 +08:00
|
|
|
def visit_AugAssign(self, node):
    """Type-check an augmented assignment ``target op= value``.

    ``_coerce_binop`` yields the operation's result type together with the
    types the target and value must be coerced to. If it yields nothing,
    this method does nothing further. The target must accept the result
    type and then its coerced type. The first unification that fails is
    reported as an error and stops processing of the node. When both
    succeed, the value node is replaced by its coercion to the required
    value type.
    """
    self.generic_visit(node)

    coerced = self._coerce_binop(node.op, node.target, node.value)
    if not coerced:
        return
    result_typ, lhs_typ, rhs_typ = coerced

    # Show the value's type as written, not the type of an inserted coercion.
    if isinstance(node.value, asttyped.CoerceT):
        written_value_typ = node.value.value.type
    else:
        written_value_typ = node.value.type

    requirements = (
        (result_typ,
         "the result of this operation has type {typeb}, "
         "which cannot be assigned to a left-hand side of type {typea}"),
        (lhs_typ,
         "this operation requires the left-hand side of type {typea} "
         "to be coerced to {typeb}, which cannot be done"),
    )
    for required_typ, reason in requirements:
        try:
            node.target.type.unify(required_typ)
        except types.UnificationError:
            printer = types.TypePrinter()
            note = diagnostic.Diagnostic("note",
                "expression of type {typec}",
                {"typec": printer.name(written_value_typ)},
                node.value.loc)
            diag = diagnostic.Diagnostic("error", reason,
                {"typea": printer.name(node.target.type),
                 "typeb": printer.name(required_typ)},
                node.op.loc, [node.target.loc], [note])
            self.engine.process(diag)
            return

    node.value = self._coerce_one(rhs_typ, node.value, other_node=node.target)
|
2015-06-02 13:53:11 +08:00
|
|
|
|
2015-12-16 15:33:15 +08:00
|
|
|
def visit_ForT(self, node):
    """Visit a ``for`` loop body in loop context, then type its iteration.

    ``self.in_loop`` is set while the children are visited, so that
    ``break``/``continue`` are accepted. It is restored before the target is
    unified with the element type of the iterable.
    """
    enclosing_in_loop = self.in_loop
    self.in_loop = True
    self.generic_visit(node)
    self.in_loop = enclosing_in_loop

    self._unify_iterable(node.target, node.iter)
|
2015-06-04 19:12:41 +08:00
|
|
|
|
2015-06-13 17:07:46 +08:00
|
|
|
def visit_While(self, node):
    """Visit a ``while`` loop with ``self.in_loop`` set, restoring it afterwards."""
    enclosing_in_loop = self.in_loop
    self.in_loop = True
    self.generic_visit(node)
    self.in_loop = enclosing_in_loop
|
|
|
|
|
|
|
|
def visit_Break(self, node):
    """Report an error for a ``break`` statement outside of any loop."""
    if self.in_loop:
        return
    self.engine.process(diagnostic.Diagnostic("error",
        "break statement outside of a loop", {},
        node.keyword_loc))
|
|
|
|
|
|
|
|
def visit_Continue(self, node):
    """Report an error for a ``continue`` statement outside of any loop."""
    if self.in_loop:
        return
    self.engine.process(diagnostic.Diagnostic("error",
        "continue statement outside of a loop", {},
        node.keyword_loc))
|
|
|
|
|
2016-01-05 00:11:03 +08:00
|
|
|
def visit_withitemT(self, node):
    """Type-check one item of a ``with`` statement.

    * Builtin ``interleave``/``sequential``/``parallel`` context managers
      bind ``None`` to any ``as`` target.
    * For user-defined context managers (instances or constructors),
      ``__enter__`` and ``__exit__`` are looked up. They must be functions
      taking exactly 1 and 4 positional arguments respectively (counting the
      first, ``self``-like argument) and no optional ones. Every argument
      after the first is pinned to ``None``, because exception handling via
      context managers is not supported. The ``as`` target, if present,
      receives the return type of ``__enter__``.
    * Any other already-resolved type is reported as unusable as a context
      manager.
    """
    self.generic_visit(node)

    typ = node.context_expr.type
    if (types.is_builtin(typ, "interleave") or types.is_builtin(typ, "sequential") or
            types.is_builtin(typ, "parallel")):
        # builtin context managers
        if node.optional_vars is not None:
            self._unify(node.optional_vars.type, builtins.TNone(),
                        node.optional_vars.loc, None)
    elif types.is_instance(typ) or types.is_constructor(typ):
        # user-defined context managers
        self._unify_attribute(result_type=node.enter_type, value_node=node.context_expr,
                              attr_name='__enter__', attr_loc=None, loc=node.loc)
        self._unify_attribute(result_type=node.exit_type, value_node=node.context_expr,
                              attr_name='__exit__', attr_loc=None, loc=node.loc)

        printer = types.TypePrinter()

        def check_callback(attr_name, callback_type, arity):
            # Nothing to check until the attribute's type is known.
            if types.is_var(callback_type):
                return

            if not (types.is_method(callback_type) or types.is_function(callback_type)):
                diag = diagnostic.Diagnostic("error",
                    "attribute '{attr}' of type {manager_type} must be a function",
                    {"attr": attr_name,
                     "manager_type": printer.name(node.context_expr.type)},
                    node.context_expr.loc)
                self.engine.process(diag)
                return

            if types.is_method(callback_type):
                fn_type = types.get_method_function(callback_type).find()
            else:
                fn_type = callback_type.find()

            if not (len(fn_type.args) == arity and len(fn_type.optargs) == 0):
                diag = diagnostic.Diagnostic("error",
                    "function '{attr}{attr_type}' must accept "
                    "{arity} positional argument{s} and no optional arguments",
                    {"attr": attr_name,
                     "attr_type": printer.name(fn_type),
                     "arity": arity, "s": "s" if arity > 1 else ""},
                    node.context_expr.loc)
                self.engine.process(diag)

            # Arguments after the first can never receive an exception here,
            # so they are always None.
            for formal_arg_name in list(fn_type.args)[1:]:
                formal_arg_type = fn_type.args[formal_arg_name]
                def makenotes(printer, typea, typeb, loca, locb):
                    return [
                        diagnostic.Diagnostic("note",
                            "exception handling via context managers is not supported; "
                            "the argument '{arg}' of function '{attr}{attr_type}' "
                            "will always be None",
                            {"arg": formal_arg_name,
                             "attr": attr_name,
                             "attr_type": printer.name(fn_type)},
                            loca),
                    ]

                self._unify(formal_arg_type, builtins.TNone(),
                            node.context_expr.loc, None,
                            makenotes=makenotes)

        check_callback('__enter__', node.enter_type, 1)
        check_callback('__exit__', node.exit_type, 4)

        if node.optional_vars is not None:
            # ``with x as y`` binds y to the value returned by x.__enter__(),
            # so the target's type comes from __enter__, not __exit__.
            enter_fn_type = node.enter_type.find()
            if types.is_method(enter_fn_type):
                enter_fn_type = types.get_method_function(enter_fn_type).find()

            # While __enter__ is still unresolved (or was already reported as
            # not being a function) there is no return type to propagate.
            if types.is_function(enter_fn_type):
                def makenotes(printer, typea, typeb, loca, locb):
                    return [
                        diagnostic.Diagnostic("note",
                            "expression of type {typea}",
                            {"typea": printer.name(typea)},
                            loca),
                        diagnostic.Diagnostic("note",
                            "context manager with an '__enter__' method returning {typeb}",
                            {"typeb": printer.name(typeb)},
                            locb)
                    ]

                self._unify(node.optional_vars.type, enter_fn_type.ret,
                            node.optional_vars.loc, node.context_expr.loc,
                            makenotes=makenotes)

    elif not types.is_var(typ):
        diag = diagnostic.Diagnostic("error",
            "value of type {type} cannot act as a context manager",
            {"type": types.TypePrinter().name(typ)},
            node.context_expr.loc)
        self.engine.process(diag)
|
|
|
|
|
2016-01-04 21:26:03 +08:00
|
|
|
def visit_With(self, node):
    """Require builtin timing context managers to stand alone in a ``with``.

    Every item whose context expression is the builtin ``parallel``,
    ``interleave`` or ``sequential`` gets an error if the statement has more
    than one item.
    """
    self.generic_visit(node)

    exclusive_kinds = ("parallel", "interleave", "sequential")
    for item in node.items:
        item_type = item.context_expr.type.find()
        is_exclusive = any(types.is_builtin(item_type, kind) for kind in exclusive_kinds)
        if not is_exclusive or len(node.items) == 1:
            continue
        self.engine.process(diagnostic.Diagnostic("error",
            "the '{kind}' context manager must be the only one in a 'with' statement",
            {"kind": item_type.name},
            node.keyword_loc.join(node.colon_loc)))
|
|
|
|
|
2015-06-29 05:31:06 +08:00
|
|
|
def visit_ExceptHandlerT(self, node):
    """Type the name bound by an ``except`` clause from its filter.

    A bare ``except:`` (no filter) needs nothing further. Otherwise the
    filter must be an exception constructor, and the bound name gets the
    type of that constructor's instances.
    """
    self.generic_visit(node)

    exc_filter = node.filter
    if exc_filter is None:
        return

    if types.is_exn_constructor(exc_filter.type):
        def makenotes(printer, typea, typeb, loca, locb):
            return [
                diagnostic.Diagnostic("note",
                    "expression of type {typea}",
                    {"typea": printer.name(typea)},
                    loca),
                diagnostic.Diagnostic("note",
                    "constructor of an exception of type {typeb}",
                    {"typeb": printer.name(typeb)},
                    locb)
            ]

        self._unify(node.name_type, exc_filter.type.instance,
                    node.name_loc, exc_filter.loc, makenotes)
    else:
        self.engine.process(diagnostic.Diagnostic("error",
            "this expression must refer to an exception constructor",
            {"type": types.TypePrinter().name(exc_filter.type)},
            exc_filter.loc))
|
2015-06-29 05:31:06 +08:00
|
|
|
|
2015-06-15 16:30:50 +08:00
|
|
|
def _type_from_arguments(self, node, ret):
    """Build a function type from an ``arguments`` node and a return type.

    Arguments without defaults become mandatory arguments. The trailing
    arguments that have defaults become optional arguments. If ``*args`` or
    ``**kwargs`` is present, an error is reported and ``None`` is returned.
    """
    self.generic_visit(node)

    variadics = ((node.star_loc, node.vararg),
                 (node.dstar_loc, node.kwarg))
    for sigil_loc, vararg in variadics:
        if not vararg:
            continue
        self.engine.process(diagnostic.Diagnostic("error",
            "variadic arguments are not supported", {},
            sigil_loc, [vararg.loc]))
        return

    first_optional = len(node.args) - len(node.defaults)
    mandatory = OrderedDict((arg_node.arg, arg_node.type)
                            for arg_node in node.args[:first_optional])
    optional = OrderedDict((arg_node.arg, arg_node.type)
                           for arg_node in node.args[first_optional:])
    return types.TFunction(mandatory, optional, ret)
|
|
|
|
|
|
|
|
def visit_arguments(self, node):
    """Unify each defaulted argument's type with its default value's type.

    Defaults belong to the trailing ``len(node.defaults)`` arguments.
    """
    self.generic_visit(node)

    first_defaulted = len(node.args) - len(node.defaults)
    for arg_node, default_node in zip(node.args[first_defaulted:], node.defaults):
        self._unify(arg_node.type, default_node.type,
                    arg_node.loc, default_node.loc)
|
|
|
|
|
2015-06-13 17:07:46 +08:00
|
|
|
def visit_FunctionDefT(self, node):
    """Type-check a function definition and unify its signature type.

    Decorators are rejected with an error unless they denote the ``kernel``
    decorator. That means one of:

    * a node typed as the builtin ``kernel``;
    * a call whose callee is typed as the builtin ``kernel``;
    * an attribute chain on a quoted host object that resolves to
      ``kernel``.

    The body is visited with this function as the current function, outside
    of any loop. If it contains no ``return`` statement, the return type is
    unified with ``None``. Finally, the signature built from the arguments
    and return type is unified with ``node.signature_type``.
    """
    # Marks attribute chains that cannot be resolved to a host-side value.
    unresolved = object()

    def eval_attr(attr):
        # Resolve ``<quoted host object>.a.b...`` to the host value. Returns
        # ``unresolved`` if the chain is not rooted in a quoted object or an
        # attribute along it is missing, instead of raising.
        base = attr.value
        if isinstance(base, asttyped.QuoteT):
            obj = base.value
        elif isinstance(base, asttyped.AttributeT):
            obj = eval_attr(base)
        else:
            return unresolved
        return getattr(obj, attr.attr, unresolved)

    def is_kernel_decorator(decorator):
        if isinstance(decorator, asttyped.AttributeT) and eval_attr(decorator) is kernel:
            return True
        # Inspect the AST node's own type. The host value an attribute chain
        # evaluates to has no ``.type``.
        if types.is_builtin(decorator.type, "kernel"):
            return True
        return (isinstance(decorator, asttyped.CallT) and
                types.is_builtin(decorator.func.type, "kernel"))

    for index, decorator in enumerate(node.decorator_list):
        if is_kernel_decorator(decorator):
            continue

        diag = diagnostic.Diagnostic("error",
            "decorators are not supported", {},
            node.at_locs[index], [])
        self.engine.process(diag)

    try:
        old_function, self.function = self.function, node
        old_in_loop, self.in_loop = self.in_loop, False
        old_has_return, self.has_return = self.has_return, False

        self.generic_visit(node)

        # Lack of return statements is not the only case where the return
        # type cannot be inferred. The other one is infinite (possibly mutual)
        # recursion. Since Python functions don't have to return a value,
        # we ignore that one.
        if not self.has_return:
            def makenotes(printer, typea, typeb, loca, locb):
                return [
                    diagnostic.Diagnostic("note",
                        "function with return type {typea}",
                        {"typea": printer.name(typea)},
                        node.name_loc),
                ]
            self._unify(node.return_type, builtins.TNone(),
                        node.name_loc, None, makenotes)
    finally:
        self.function = old_function
        self.in_loop = old_in_loop
        self.has_return = old_has_return

    signature_type = self._type_from_arguments(node.args, node.return_type)
    if signature_type:
        self._unify(node.signature_type, signature_type,
                    node.name_loc, None)

visit_QuotedFunctionDefT = visit_FunctionDefT
|
|
|
|
|
2015-08-15 21:45:16 +08:00
|
|
|
def visit_ClassDefT(self, node):
    """Reject class decorators, then visit the class body.

    Only the first decorator is reported, pointing at its ``@`` sigil.
    """
    decorators = node.decorator_list
    if any(decorators):
        first_decorator = decorators[0]
        self.engine.process(diagnostic.Diagnostic("error",
            "decorators are not supported", {},
            node.at_locs[0], [first_decorator.loc]))

    self.generic_visit(node)
|
|
|
|
|
2015-06-11 11:34:22 +08:00
|
|
|
def visit_Return(self, node):
    """Unify the enclosing function's return type with a ``return``'s value.

    A ``return`` outside of any function is an error. Otherwise the function
    is marked as having a return statement. A bare ``return`` contributes
    ``None``, and ``return expr`` contributes the expression's type.
    """
    if not self.function:
        self.engine.process(diagnostic.Diagnostic("error",
            "return statement outside of a function", {},
            node.keyword_loc))
        return

    self.has_return = True

    self.generic_visit(node)

    def makenotes(printer, typea, typeb, loca, locb):
        return [
            diagnostic.Diagnostic("note",
                "function with return type {typea}",
                {"typea": printer.name(typea)},
                self.function.name_loc),
            diagnostic.Diagnostic("note",
                "a statement returning {typeb}",
                {"typeb": printer.name(typeb)},
                node.loc)
        ]

    if node.value is None:
        returned_type, returned_loc = builtins.TNone(), node.loc
    else:
        returned_type, returned_loc = node.value.type, node.value.loc

    self._unify(self.function.return_type, returned_type,
                self.function.name_loc, returned_loc, makenotes)
|
2015-07-22 07:58:59 +08:00
|
|
|
|
2015-07-25 10:37:37 +08:00
|
|
|
def visit_Raise(self, node):
    """Check that a ``raise`` statement raises an exception.

    A re-raise (no expression) needs no check. Raising an exception
    constructor (the short form ``raise Cls``) is accepted, as is raising an
    exception value. A value whose type is still a type variable is left
    for later; any other type is an error.
    """
    self.generic_visit(node)

    if node.exc is None:
        return

    exc_type = node.exc.type
    if types.is_exn_constructor(exc_type):
        return # short form
    if types.is_var(exc_type) or builtins.is_exception(exc_type):
        return

    self.engine.process(diagnostic.Diagnostic("error",
        "cannot raise a value of type {type}, which is not an exception",
        {"type": types.TypePrinter().name(exc_type)},
        node.loc))
|
2015-07-25 10:37:37 +08:00
|
|
|
|
2015-07-22 07:58:59 +08:00
|
|
|
def visit_Assert(self, node):
|
|
|
|
self.generic_visit(node)
|
|
|
|
self._unify(node.test.type, builtins.TBool(),
|
|
|
|
node.test.loc, None)
|
|
|
|
if node.msg is not None:
|
|
|
|
if not isinstance(node.msg, asttyped.StrT):
|
|
|
|
diag = diagnostic.Diagnostic("error",
|
|
|
|
"assertion message must be a string literal", {},
|
|
|
|
node.msg.loc)
|
|
|
|
self.engine.process(diag)
|