artiq/artiq/compiler/embedding.py
pca006132 4ede58e44b compiler: reduce calls to TypedTreeHasher
We need to check if our inference reached a fixed point. This is checked
using hash of the types in the AST, which is very slow. This patch
avoids computing the hash if we can make sure that the AST is definitely
changed, which is when we parse a new function.

For some simple programs with many functions, this can significantly
reduce the compile time by up to ~30%.
2021-07-07 09:22:16 +08:00

1204 lines
52 KiB
Python

"""
The :class:`Stitcher` class allows to transparently combine compiled
Python code and Python code executed on the host system: it resolves
the references to the host objects and translates the functions
annotated as ``@kernel`` when they are referenced.
"""
import sys, os, re, linecache, inspect, textwrap, types as pytypes, numpy
from collections import OrderedDict, defaultdict
from pythonparser import ast, algorithm, source, diagnostic, parse_buffer
from pythonparser import lexer as source_lexer, parser as source_parser
from Levenshtein import ratio as similarity, jaro_winkler
from ..language import core as language_core
from . import types, builtins, asttyped, math_fns, prelude
from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer, TypedtreePrinter
from .transforms.asttyped_rewriter import LocalExtractor
class SpecializedFunction:
def __init__(self, instance_type, host_function):
self.instance_type = instance_type
self.host_function = host_function
def __eq__(self, other):
if isinstance(other, tuple):
return (self.instance_type == other[0] or
self.host_function == other[1])
else:
return (self.instance_type == other.instance_type or
self.host_function == other.host_function)
def __ne__(self, other):
return not self == other
def __hash__(self):
return hash((self.instance_type, self.host_function))
class EmbeddingMap:
def __init__(self):
self.object_current_key = 0
self.object_forward_map = {}
self.object_reverse_map = {}
self.module_map = {}
self.type_map = {}
self.function_map = {}
# Modules
def store_module(self, module, module_type):
self.module_map[module] = module_type
def retrieve_module(self, module):
return self.module_map[module]
def has_module(self, module):
return module in self.module_map
# Types
def store_type(self, host_type, instance_type, constructor_type):
self._rename_type(instance_type)
self.type_map[host_type] = (instance_type, constructor_type)
def retrieve_type(self, host_type):
return self.type_map[host_type]
def has_type(self, host_type):
return host_type in self.type_map
def _rename_type(self, new_instance_type):
# Generally, user-defined types that have exact same name (which is to say, classes
# defined inside functions) do not pose a problem to the compiler. The two places which
# cannot handle this are:
# 1. {TInstance,TConstructor}.__hash__
# 2. LLVM type names
# Since handling #2 requires renaming on ARTIQ side anyway, it's more straightforward
# to do it once when embedding (since non-embedded code cannot define classes in
# functions). Also, easier to debug.
n = 0
for host_type in self.type_map:
instance_type, constructor_type = self.type_map[host_type]
if instance_type.name == new_instance_type.name:
n += 1
new_instance_type.name = "{}.{}".format(new_instance_type.name, n)
def attribute_count(self):
count = 0
for host_type in self.type_map:
instance_type, constructor_type = self.type_map[host_type]
count += len(instance_type.attributes)
count += len(constructor_type.attributes)
return count
# Functions
def store_function(self, function, ir_function_name):
self.function_map[function] = ir_function_name
def retrieve_function(self, function):
return self.function_map[function]
def specialize_function(self, instance_type, host_function):
return SpecializedFunction(instance_type, host_function)
# Objects
def store_object(self, obj_ref):
obj_id = id(obj_ref)
if obj_id in self.object_reverse_map:
return self.object_reverse_map[obj_id]
self.object_current_key += 1
self.object_forward_map[self.object_current_key] = obj_ref
self.object_reverse_map[obj_id] = self.object_current_key
return self.object_current_key
def retrieve_object(self, obj_key):
return self.object_forward_map[obj_key]
def iter_objects(self):
for obj_id in self.object_forward_map.keys():
obj_ref = self.object_forward_map[obj_id]
if isinstance(obj_ref, (pytypes.FunctionType, pytypes.MethodType,
pytypes.BuiltinFunctionType, pytypes.ModuleType,
SpecializedFunction)):
continue
elif isinstance(obj_ref, type):
_, obj_typ = self.type_map[obj_ref]
else:
obj_typ, _ = self.type_map[type(obj_ref)]
yield obj_id, obj_ref, obj_typ
def has_rpc(self):
return any(filter(lambda x: inspect.isfunction(x) or inspect.ismethod(x),
self.object_forward_map.values()))
class ASTSynthesizer:
def __init__(self, embedding_map, value_map, quote_function=None, expanded_from=None):
self.source = ""
self.source_buffer = source.Buffer(self.source, "<synthesized>")
self.embedding_map = embedding_map
self.value_map = value_map
self.quote_function = quote_function
self.expanded_from = expanded_from
self.diagnostics = []
def finalize(self):
self.source_buffer.source = self.source
return self.source_buffer
def _add(self, fragment):
range_from = len(self.source)
self.source += fragment
range_to = len(self.source)
return source.Range(self.source_buffer, range_from, range_to,
expanded_from=self.expanded_from)
def fast_quote_list(self, value):
elts = [None] * len(value)
is_T = False
if len(value) > 0:
v = value[0]
is_T = True
if isinstance(v, int):
T = int
elif isinstance(v, float):
T = float
elif isinstance(v, numpy.int32):
T = numpy.int32
elif isinstance(v, numpy.int64):
T = numpy.int64
else:
is_T = False
if is_T:
for v in value:
if not isinstance(v, T):
is_T = False
break
if is_T:
is_int = T != float
if T == int:
typ = builtins.TInt()
elif T == float:
typ = builtins.TFloat()
elif T == numpy.int32:
typ = builtins.TInt32()
elif T == numpy.int64:
typ = builtins.TInt64()
else:
assert False
text = [repr(elt) for elt in value]
start = len(self.source)
self.source += ", ".join(text)
if is_int:
for i, (v, t) in enumerate(zip(value, text)):
l = len(t)
elts[i] = asttyped.NumT(
n=int(v), ctx=None, type=typ,
loc=source.Range(
self.source_buffer, start, start + l,
expanded_from=self.expanded_from))
start += l + 2
else:
for i, (v, t) in enumerate(zip(value, text)):
l = len(t)
elts[i] = asttyped.NumT(
n=v, ctx=None, type=typ,
loc=source.Range(
self.source_buffer, start, start + l,
expanded_from=self.expanded_from))
start += l + 2
else:
for index, elt in enumerate(value):
elts[index] = self.quote(elt)
if index < len(value) - 1:
self._add(", ")
return elts
def quote(self, value):
"""Construct an AST fragment equal to `value`."""
if value is None:
typ = builtins.TNone()
return asttyped.NameConstantT(value=value, type=typ,
loc=self._add(repr(value)))
elif isinstance(value, (bool, numpy.bool_)):
typ = builtins.TBool()
coerced = bool(value)
return asttyped.NameConstantT(value=coerced, type=typ,
loc=self._add(repr(coerced)))
elif value is float:
typ = builtins.fn_float()
return asttyped.NameConstantT(value=None, type=typ,
loc=self._add("float"))
elif value is numpy.int32:
typ = builtins.fn_int32()
return asttyped.NameConstantT(value=None, type=typ,
loc=self._add("numpy.int32"))
elif value is numpy.int64:
typ = builtins.fn_int64()
return asttyped.NameConstantT(value=None, type=typ,
loc=self._add("numpy.int64"))
elif value is numpy.array:
typ = builtins.fn_array()
return asttyped.NameConstantT(value=None, type=typ,
loc=self._add("numpy.array"))
elif value is numpy.full:
typ = builtins.fn_make_array()
return asttyped.NameConstantT(value=None, type=typ,
loc=self._add("numpy.full"))
elif isinstance(value, (int, float)):
if isinstance(value, int):
typ = builtins.TInt()
elif isinstance(value, float):
typ = builtins.TFloat()
return asttyped.NumT(n=value, ctx=None, type=typ,
loc=self._add(repr(value)))
elif isinstance(value, numpy.int32):
typ = builtins.TInt32()
return asttyped.NumT(n=int(value), ctx=None, type=typ,
loc=self._add(repr(value)))
elif isinstance(value, numpy.int64):
typ = builtins.TInt64()
return asttyped.NumT(n=int(value), ctx=None, type=typ,
loc=self._add(repr(value)))
elif isinstance(value, str):
return asttyped.StrT(s=value, ctx=None, type=builtins.TStr(),
loc=self._add(repr(value)))
elif isinstance(value, bytes):
return asttyped.StrT(s=value, ctx=None, type=builtins.TBytes(),
loc=self._add(repr(value)))
elif isinstance(value, bytearray):
quote_loc = self._add('`')
repr_loc = self._add(repr(value))
unquote_loc = self._add('`')
loc = quote_loc.join(unquote_loc)
return asttyped.QuoteT(value=value, type=builtins.TByteArray(), loc=loc)
elif isinstance(value, list):
begin_loc = self._add("[")
elts = self.fast_quote_list(value)
end_loc = self._add("]")
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(),
begin_loc=begin_loc, end_loc=end_loc,
loc=begin_loc.join(end_loc))
elif isinstance(value, tuple):
begin_loc = self._add("(")
elts = self.fast_quote_list(value)
end_loc = self._add(")")
return asttyped.TupleT(elts=elts, ctx=None,
type=types.TTuple([e.type for e in elts]),
begin_loc=begin_loc, end_loc=end_loc,
loc=begin_loc.join(end_loc))
elif isinstance(value, numpy.ndarray):
return self.call(numpy.array, [list(value)], {})
elif inspect.isfunction(value) or inspect.ismethod(value) or \
isinstance(value, pytypes.BuiltinFunctionType) or \
isinstance(value, SpecializedFunction) or \
isinstance(value, numpy.ufunc):
if inspect.ismethod(value):
quoted_self = self.quote(value.__self__)
function_type = self.quote_function(value.__func__, self.expanded_from)
method_type = types.TMethod(quoted_self.type, function_type)
dot_loc = self._add('.')
name_loc = self._add(value.__func__.__name__)
loc = quoted_self.loc.join(name_loc)
return asttyped.QuoteT(value=value, type=method_type,
self_loc=quoted_self.loc, loc=loc)
else: # function
function_type = self.quote_function(value, self.expanded_from)
quote_loc = self._add('`')
repr_loc = self._add(repr(value))
unquote_loc = self._add('`')
loc = quote_loc.join(unquote_loc)
return asttyped.QuoteT(value=value, type=function_type, loc=loc)
elif isinstance(value, pytypes.ModuleType):
if self.embedding_map.has_module(value):
module_type = self.embedding_map.retrieve_module(value)
else:
module_type = types.TModule(value.__name__, OrderedDict())
module_type.attributes['__objectid__'] = builtins.TInt32()
self.embedding_map.store_module(value, module_type)
quote_loc = self._add('`')
repr_loc = self._add(repr(value))
unquote_loc = self._add('`')
loc = quote_loc.join(unquote_loc)
self.value_map[module_type].append((value, loc))
return asttyped.QuoteT(value=value, type=module_type, loc=loc)
else:
quote_loc = self._add('`')
repr_loc = self._add(repr(value))
unquote_loc = self._add('`')
loc = quote_loc.join(unquote_loc)
if isinstance(value, type):
typ = value
else:
typ = type(value)
if self.embedding_map.has_type(typ):
instance_type, constructor_type = self.embedding_map.retrieve_type(typ)
if hasattr(value, 'kernel_invariants') and \
value.kernel_invariants != instance_type.constant_attributes:
attr_diff = value.kernel_invariants.difference(
instance_type.constant_attributes)
if len(attr_diff) > 0:
diag = diagnostic.Diagnostic("warning",
"object {value} of type {typ} declares attribute(s) {attrs} as "
"kernel invariant, but other objects of the same type do not; "
"the invariant annotation on this object will be ignored",
{"value": repr(value),
"typ": types.TypePrinter().name(instance_type, max_depth=0),
"attrs": ", ".join(["'{}'".format(attr) for attr in attr_diff])},
loc)
self.diagnostics.append(diag)
attr_diff = instance_type.constant_attributes.difference(
value.kernel_invariants)
if len(attr_diff) > 0:
diag = diagnostic.Diagnostic("warning",
"object {value} of type {typ} does not declare attribute(s) {attrs} as "
"kernel invariant, but other objects of the same type do; "
"the invariant annotation on other objects will be ignored",
{"value": repr(value),
"typ": types.TypePrinter().name(instance_type, max_depth=0),
"attrs": ", ".join(["'{}'".format(attr) for attr in attr_diff])},
loc)
self.diagnostics.append(diag)
value.kernel_invariants = value.kernel_invariants.intersection(
instance_type.constant_attributes)
else:
if issubclass(typ, BaseException):
if hasattr(typ, 'artiq_builtin'):
exception_id = 0
else:
exception_id = self.embedding_map.store_object(typ)
instance_type = builtins.TException("{}.{}".format(typ.__module__,
typ.__qualname__),
id=exception_id)
constructor_type = types.TExceptionConstructor(instance_type)
else:
instance_type = types.TInstance("{}.{}".format(typ.__module__, typ.__qualname__),
OrderedDict())
instance_type.attributes['__objectid__'] = builtins.TInt32()
constructor_type = types.TConstructor(instance_type)
constructor_type.attributes['__objectid__'] = builtins.TInt32()
instance_type.constructor = constructor_type
self.embedding_map.store_type(typ, instance_type, constructor_type)
if hasattr(value, 'kernel_invariants'):
assert isinstance(value.kernel_invariants, set)
instance_type.constant_attributes = value.kernel_invariants
if isinstance(value, type):
self.value_map[constructor_type].append((value, loc))
return asttyped.QuoteT(value=value, type=constructor_type,
loc=loc)
else:
self.value_map[instance_type].append((value, loc))
return asttyped.QuoteT(value=value, type=instance_type,
loc=loc)
def call(self, callee, args, kwargs, callback=None):
"""
Construct an AST fragment calling a function specified by
an AST node `function_node`, with given arguments.
"""
if callback is not None:
callback_node = self.quote(callback)
cb_begin_loc = self._add("(")
callee_node = self.quote(callee)
arg_nodes = []
kwarg_nodes = []
kwarg_locs = []
begin_loc = self._add("(")
for index, arg in enumerate(args):
arg_nodes.append(self.quote(arg))
if index < len(args) - 1:
self._add(", ")
if any(args) and any(kwargs):
self._add(", ")
for index, kw in enumerate(kwargs):
arg_loc = self._add(kw)
equals_loc = self._add("=")
kwarg_locs.append((arg_loc, equals_loc))
kwarg_nodes.append(self.quote(kwargs[kw]))
if index < len(kwargs) - 1:
self._add(", ")
end_loc = self._add(")")
if callback is not None:
cb_end_loc = self._add(")")
node = asttyped.CallT(
func=callee_node,
args=arg_nodes,
keywords=[ast.keyword(arg=kw, value=value,
arg_loc=arg_loc, equals_loc=equals_loc,
loc=arg_loc.join(value.loc))
for kw, value, (arg_loc, equals_loc)
in zip(kwargs, kwarg_nodes, kwarg_locs)],
starargs=None, kwargs=None,
type=types.TVar(), iodelay=None, arg_exprs={},
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
loc=callee_node.loc.join(end_loc))
if callback is not None:
node = asttyped.CallT(
func=callback_node,
args=[node], keywords=[], starargs=None, kwargs=None,
type=builtins.TNone(), iodelay=None, arg_exprs={},
begin_loc=cb_begin_loc, end_loc=cb_end_loc, star_loc=None, dstar_loc=None,
loc=callback_node.loc.join(cb_end_loc))
return node
def suggest_identifier(id, names):
sorted_names = sorted(names, key=lambda other: jaro_winkler(id, other), reverse=True)
if len(sorted_names) > 0:
if jaro_winkler(id, sorted_names[0]) > 0.0 and similarity(id, sorted_names[0]) > 0.5:
return sorted_names[0]
class StitchingASTTypedRewriter(ASTTypedRewriter):
def __init__(self, engine, prelude, globals, host_environment, quote):
super().__init__(engine, prelude)
self.globals = globals
self.env_stack.append(self.globals)
self.host_environment = host_environment
self.quote = quote
def visit_arg(self, node):
typ = self._find_name(node.arg, node.loc)
# ignore annotations; these are handled in _quote_function
return asttyped.argT(type=typ,
arg=node.arg, annotation=None,
arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc)
def visit_quoted_function(self, node, function):
extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
extractor.visit(node)
# We quote the defaults so they end up in the global data in LLVM IR.
# This way there is no "life before main", i.e. they do not have to be
# constructed before the main translated call executes; but the Python
# semantics is kept.
defaults = function.__defaults__ or ()
quoted_defaults = []
for default, default_node in zip(defaults, node.args.defaults):
quoted_defaults.append(self.quote(default, default_node.loc))
node.args.defaults = quoted_defaults
node = asttyped.QuotedFunctionDefT(
typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
signature_type=types.TVar(), return_type=types.TVar(),
name=node.name, args=node.args, returns=node.returns,
body=node.body, decorator_list=node.decorator_list,
keyword_loc=node.keyword_loc, name_loc=node.name_loc,
arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs,
loc=node.loc)
try:
self.env_stack.append(node.typing_env)
return self.generic_visit(node)
finally:
self.env_stack.pop()
def visit_Name(self, node):
typ = super()._try_find_name(node.id)
if typ is not None:
# Value from device environment.
return asttyped.NameT(type=typ, id=node.id, ctx=node.ctx,
loc=node.loc)
else:
# Try to find this value in the host environment and quote it.
if node.id == "print":
return self.quote(print, node.loc)
elif node.id in self.host_environment:
return self.quote(self.host_environment[node.id], node.loc)
else:
names = set()
names.update(self.host_environment.keys())
for typing_env in reversed(self.env_stack):
names.update(typing_env.keys())
suggestion = suggest_identifier(node.id, names)
if suggestion is not None:
diag = diagnostic.Diagnostic("fatal",
"name '{name}' is not bound to anything; did you mean '{suggestion}'?",
{"name": node.id, "suggestion": suggestion},
node.loc)
self.engine.process(diag)
else:
diag = diagnostic.Diagnostic("fatal",
"name '{name}' is not bound to anything", {"name": node.id},
node.loc)
self.engine.process(diag)
class StitchingInferencer(Inferencer):
def __init__(self, engine, value_map, quote):
super().__init__(engine)
self.value_map = value_map
self.quote = quote
self.attr_type_cache = {}
def _compute_attr_type(self, object_value, object_type, object_loc, attr_name, loc):
if not hasattr(object_value, attr_name):
if attr_name.startswith('_'):
names = set(filter(lambda name: not name.startswith('_'),
dir(object_value)))
else:
names = set(dir(object_value))
suggestion = suggest_identifier(attr_name, names)
note = diagnostic.Diagnostic("note",
"attribute accessed here", {},
loc)
if suggestion is not None:
diag = diagnostic.Diagnostic("error",
"host object does not have an attribute '{attr}'; "
"did you mean '{suggestion}'?",
{"attr": attr_name, "suggestion": suggestion},
object_loc, notes=[note])
else:
diag = diagnostic.Diagnostic("error",
"host object does not have an attribute '{attr}'",
{"attr": attr_name},
object_loc, notes=[note])
self.engine.process(diag)
return
# Figure out the ARTIQ type of the value of the attribute.
# We do this by quoting it, as if to serialize. This has some
# overhead (i.e. synthesizing a source buffer), but has the advantage
# of having the host-to-ARTIQ mapping code in only one place and
# also immediately getting proper diagnostics on type errors.
attr_value = getattr(object_value, attr_name)
if (inspect.ismethod(attr_value) and
types.is_instance(object_type) and
# Check that the method is indeed defined on the class,
# and not just this instance. The check is written in
# the inverted form and not as hasattr(type(attr_value))
# since the method may as well be defined on a superclass.
attr_name not in object_value.__dict__):
# In cases like:
# class c:
# @kernel
# def f(self): pass
# we want f to be defined on the class, not on the instance.
attributes = object_type.constructor.attributes
attr_value = SpecializedFunction(object_type, attr_value.__func__)
else:
attributes = object_type.attributes
attr_value_type = None
if isinstance(attr_value, list):
# Fast path for lists of scalars.
IS_FLOAT = 1
IS_INT32 = 2
IS_INT64 = 4
state = 0
for elt in attr_value:
if elt.__class__ == float:
state |= IS_FLOAT
elif elt.__class__ == int:
if -2**31 < elt < 2**31-1:
state |= IS_INT32
elif -2**63 < elt < 2**63-1:
state |= IS_INT64
else:
state = -1
break
else:
state = -1
if state == IS_FLOAT:
attr_value_type = builtins.TList(builtins.TFloat())
elif state == IS_INT32:
attr_value_type = builtins.TList(builtins.TInt32())
elif state == IS_INT64:
attr_value_type = builtins.TList(builtins.TInt64())
if attr_value_type is None:
note = diagnostic.Diagnostic("note",
"while inferring a type for an attribute '{attr}' of a host object",
{"attr": attr_name},
loc)
with self.engine.context(note):
# Slow path. We don't know what exactly is the attribute value,
# so we quote it only for the error message that may possibly result.
ast = self.quote(attr_value, object_loc.expanded_from)
Inferencer(engine=self.engine).visit(ast)
IntMonomorphizer(engine=self.engine).visit(ast)
attr_value_type = ast.type
return attributes, attr_value_type
def _unify_attribute(self, result_type, value_node, attr_name, attr_loc, loc):
# The inferencer can only observe types, not values; however,
# when we work with host objects, we have to get the values
# somewhere, since host interpreter does not have types.
# Since we have categorized every host object we quoted according to
# its type, we now interrogate every host object we have to ensure
# that we can successfully serialize the value of the attribute we
# are now adding at the code generation stage.
object_type = value_node.type.find()
for object_value, object_loc in self.value_map[object_type]:
attr_type_key = (id(object_value), attr_name)
try:
attributes, attr_value_type = self.attr_type_cache[attr_type_key]
except KeyError:
attributes, attr_value_type = \
self._compute_attr_type(object_value, object_type, object_loc, attr_name, loc)
self.attr_type_cache[attr_type_key] = attributes, attr_value_type
if attr_name not in attributes:
# We just figured out what the type should be. Add it.
attributes[attr_name] = attr_value_type
else:
# Does this conflict with an earlier guess?
try:
attributes[attr_name].unify(attr_value_type)
except types.UnificationError as e:
printer = types.TypePrinter()
diag = diagnostic.Diagnostic("error",
"host object has an attribute '{attr}' of type {typea}, which is"
" different from previously inferred type {typeb} for the same attribute",
{"typea": printer.name(attr_value_type),
"typeb": printer.name(attributes[attr_name]),
"attr": attr_name},
object_loc)
self.engine.process(diag)
super()._unify_attribute(result_type, value_node, attr_name, attr_loc, loc)
def visit_QuoteT(self, node):
if inspect.ismethod(node.value):
if types.is_rpc(types.get_method_function(node.type)):
return
self._unify_method_self(method_type=node.type,
attr_name=node.value.__func__.__name__,
attr_loc=None,
loc=node.loc,
self_loc=node.self_loc)
class TypedtreeHasher(algorithm.Visitor):
def generic_visit(self, node):
def freeze(obj):
if isinstance(obj, ast.AST):
return self.visit(obj)
elif isinstance(obj, list):
return hash(tuple(freeze(elem) for elem in obj))
elif isinstance(obj, types.Type):
return hash(obj.find())
else:
# We don't care; only types change during inference.
pass
fields = node._fields
if hasattr(node, '_types'):
fields = fields + node._types
return hash(tuple(freeze(getattr(node, field_name)) for field_name in fields))
class Stitcher:
def __init__(self, core, dmgr, engine=None, print_as_rpc=True):
self.core = core
self.dmgr = dmgr
if engine is None:
self.engine = diagnostic.Engine(all_errors_are_fatal=True)
else:
self.engine = engine
self.name = ""
self.typedtree = []
self.inject_at = 0
self.globals = {}
# We don't want some things from the prelude as they are provided in
# the host Python namespace and gain special meaning when quoted.
self.prelude = prelude.globals()
if print_as_rpc:
self.prelude.pop("print")
self.prelude.pop("array")
self.functions = {}
self.embedding_map = EmbeddingMap()
self.value_map = defaultdict(lambda: [])
self.definitely_changed = False
def stitch_call(self, function, args, kwargs, callback=None):
# We synthesize source code for the initial call so that
# diagnostics would have something meaningful to display to the user.
synthesizer = self._synthesizer(self._function_loc(function.artiq_embedded.function))
call_node = synthesizer.call(function, args, kwargs, callback)
synthesizer.finalize()
self.typedtree.append(call_node)
def finalize(self):
inferencer = StitchingInferencer(engine=self.engine,
value_map=self.value_map,
quote=self._quote)
typedtree_hasher = TypedtreeHasher()
# Iterate inference to fixed point.
old_typedtree_hash = None
old_attr_count = None
while True:
inferencer.visit(self.typedtree)
if self.definitely_changed:
changed = True
self.definitely_changed = False
else:
typedtree_hash = typedtree_hasher.visit(self.typedtree)
attr_count = self.embedding_map.attribute_count()
changed = old_attr_count != attr_count or \
old_typedtree_hash != typedtree_hash
old_typedtree_hash = typedtree_hash
old_attr_count = attr_count
if not changed:
break
# After we've discovered every referenced attribute, check if any kernel_invariant
# specifications refers to ones we didn't encounter.
for host_type in self.embedding_map.type_map:
instance_type, constructor_type = self.embedding_map.type_map[host_type]
if not hasattr(instance_type, "constant_attributes"):
# Exceptions lack user-definable attributes.
continue
for attribute in instance_type.constant_attributes:
if attribute in instance_type.attributes:
# Fast path; if the ARTIQ Python type has the attribute, then every observed
# value is guaranteed to have it too.
continue
for value, loc in self.value_map[instance_type]:
if hasattr(value, attribute):
continue
diag = diagnostic.Diagnostic("warning",
"object {value} of type {typ} declares attribute '{attr}' as "
"kernel invariant, but the instance referenced here does not "
"have this attribute",
{"value": repr(value),
"typ": types.TypePrinter().name(instance_type, max_depth=0),
"attr": attribute},
loc)
self.engine.process(diag)
# After we have found all functions, synthesize a module to hold them.
source_buffer = source.Buffer("", "<synthesized>")
self.typedtree = asttyped.ModuleT(
typing_env=self.globals, globals_in_scope=set(),
body=self.typedtree, loc=source.Range(source_buffer, 0, 0))
def _inject(self, node):
self.typedtree.insert(self.inject_at, node)
self.inject_at += 1
def _synthesizer(self, expanded_from=None):
return ASTSynthesizer(expanded_from=expanded_from,
embedding_map=self.embedding_map,
value_map=self.value_map,
quote_function=self._quote_function)
def _function_loc(self, function):
if isinstance(function, SpecializedFunction):
function = function.host_function
if hasattr(function, 'artiq_embedded') and function.artiq_embedded.function:
function = function.artiq_embedded.function
if isinstance(function, str):
return source.Range(source.Buffer(function, "<string>"), 0, 0)
filename = function.__code__.co_filename
line = function.__code__.co_firstlineno
name = function.__code__.co_name
source_line = linecache.getline(filename, line).lstrip()
while source_line.startswith("@") or source_line == "":
line += 1
source_line = linecache.getline(filename, line).lstrip()
if "<lambda>" in function.__qualname__:
column = 0 # can't get column of lambda
else:
column = re.search("def", source_line).start(0)
source_buffer = source.Buffer(source_line, filename, line)
return source.Range(source_buffer, column, column)
def _call_site_note(self, call_loc, fn_kind):
if call_loc:
if fn_kind == 'syscall':
return [diagnostic.Diagnostic("note",
"in system call here", {},
call_loc)]
elif fn_kind == 'rpc':
return [diagnostic.Diagnostic("note",
"in function called remotely here", {},
call_loc)]
elif fn_kind == 'kernel':
return [diagnostic.Diagnostic("note",
"in kernel function here", {},
call_loc)]
else:
assert False
else:
return []
def _type_of_param(self, function, loc, param, fn_kind):
if param.annotation is not inspect.Parameter.empty:
# Type specified explicitly.
return self._extract_annot(function, param.annotation,
"argument '{}'".format(param.name), loc,
fn_kind)
elif fn_kind == 'syscall':
# Syscalls must be entirely annotated.
diag = diagnostic.Diagnostic("error",
"system call argument '{argument}' must have a type annotation",
{"argument": param.name},
self._function_loc(function),
notes=self._call_site_note(loc, fn_kind))
self.engine.process(diag)
elif fn_kind == 'rpc' and param.default is not inspect.Parameter.empty:
notes = []
notes.append(diagnostic.Diagnostic("note",
"expanded from here while trying to infer a type for an"
" unannotated optional argument '{argument}' from its default value",
{"argument": param.name},
self._function_loc(function)))
if loc is not None:
notes.append(self._call_site_note(loc, fn_kind))
with self.engine.context(*notes):
# Try and infer the type from the default value.
# This is tricky, because the default value might not have
# a well-defined type in APython.
# In this case, we bail out, but mention why we do it.
ast = self._quote(param.default, None)
Inferencer(engine=self.engine).visit(ast)
IntMonomorphizer(engine=self.engine).visit(ast)
return ast.type
else:
# Let the rest of the program decide.
return types.TVar()
def _quote_embedded_function(self, function, flags):
# we are now parsing new functions... definitely changed the type
self.definitely_changed = True
if isinstance(function, SpecializedFunction):
host_function = function.host_function
else:
host_function = function
if not hasattr(host_function, "artiq_embedded"):
raise ValueError("{} is not an embedded function".format(repr(host_function)))
# Extract function source.
embedded_function = host_function.artiq_embedded.function
if isinstance(embedded_function, str):
# This is a function to be eval'd from the given source code in string form.
# Mangle the host function's id() into the fully qualified name to make sure
# there are no collisions.
source_code = embedded_function
embedded_function = host_function
filename = "<string>"
module_name = "__eval_{}".format(id(host_function))
first_line = 1
else:
source_code = inspect.getsource(embedded_function)
filename = embedded_function.__code__.co_filename
module_name = embedded_function.__globals__['__name__']
first_line = embedded_function.__code__.co_firstlineno
# Extract function annotation.
signature = inspect.signature(embedded_function)
loc = self._function_loc(embedded_function)
arg_types = OrderedDict()
optarg_types = OrderedDict()
for param in signature.parameters.values():
if param.kind == inspect.Parameter.VAR_POSITIONAL or \
param.kind == inspect.Parameter.VAR_KEYWORD:
diag = diagnostic.Diagnostic("error",
"variadic arguments are not supported; '{argument}' is variadic",
{"argument": param.name},
self._function_loc(function),
notes=self._call_site_note(loc, fn_kind='kernel'))
self.engine.process(diag)
arg_type = self._type_of_param(function, loc, param, fn_kind='kernel')
if param.default is inspect.Parameter.empty:
arg_types[param.name] = arg_type
else:
optarg_types[param.name] = arg_type
if signature.return_annotation is not inspect.Signature.empty:
ret_type = self._extract_annot(function, signature.return_annotation,
"return type", loc, fn_kind='kernel')
else:
ret_type = types.TVar()
# Extract function environment.
host_environment = dict()
host_environment.update(embedded_function.__globals__)
cells = embedded_function.__closure__
cell_names = embedded_function.__code__.co_freevars
host_environment.update({var: cells[index] for index, var in enumerate(cell_names)})
# Find out how indented we are.
initial_whitespace = re.search(r"^\s*", source_code).group(0)
initial_indent = len(initial_whitespace.expandtabs())
# Parse.
source_buffer = source.Buffer(source_code, filename, first_line)
lexer = source_lexer.Lexer(source_buffer, version=sys.version_info[0:2],
diagnostic_engine=self.engine)
lexer.indent = [(initial_indent,
source.Range(source_buffer, 0, len(initial_whitespace)),
initial_whitespace)]
parser = source_parser.Parser(lexer, version=sys.version_info[0:2],
diagnostic_engine=self.engine)
function_node = parser.file_input().body[0]
# Mangle the name, since we put everything into a single module.
full_function_name = "{}.{}".format(module_name, host_function.__qualname__)
if isinstance(function, SpecializedFunction):
instance_type = function.instance_type
function_node.name = "_Z{}{}I{}{}Ezz".format(len(full_function_name), full_function_name,
len(instance_type.name), instance_type.name)
else:
function_node.name = "_Z{}{}zz".format(len(full_function_name), full_function_name)
# Record the function in the function map so that LLVM IR generator
# can handle quoting it.
self.embedding_map.store_function(function, function_node.name)
# Fill in the function type before typing it to handle recursive
# invocations.
self.functions[function] = types.TFunction(arg_types, optarg_types, ret_type)
# Rewrite into typed form.
asttyped_rewriter = StitchingASTTypedRewriter(
engine=self.engine, prelude=self.prelude,
globals=self.globals, host_environment=host_environment,
quote=self._quote)
function_node = asttyped_rewriter.visit_quoted_function(function_node, embedded_function)
function_node.flags = flags
# Add it into our typedtree so that it gets inferenced and codegen'd.
self._inject(function_node)
# Tie the typing knot.
self.functions[function].unify(function_node.signature_type)
return function_node
def _extract_annot(self, function, annot, kind, call_loc, fn_kind):
if annot is None:
annot = builtins.TNone()
if isinstance(function, SpecializedFunction):
host_function = function.host_function
else:
host_function = function
if hasattr(host_function, 'artiq_embedded'):
embedded_function = host_function.artiq_embedded.function
else:
embedded_function = host_function
if isinstance(embedded_function, str):
embedded_function = host_function
if isinstance(annot, str):
try:
annot = eval(annot, embedded_function.__globals__)
except Exception:
diag = diagnostic.Diagnostic(
"error",
"type annotation for {kind}, {annot}, cannot be evaluated",
{"kind": kind, "annot": repr(annot)},
self._function_loc(function),
notes=self._call_site_note(call_loc, fn_kind))
self.engine.process(diag)
if not isinstance(annot, types.Type):
diag = diagnostic.Diagnostic("error",
"type annotation for {kind}, '{annot}', is not an ARTIQ type",
{"kind": kind, "annot": repr(annot)},
self._function_loc(function),
notes=self._call_site_note(call_loc, fn_kind))
self.engine.process(diag)
return types.TVar()
else:
return annot
def _quote_syscall(self, function, loc):
signature = inspect.signature(function)
arg_types = OrderedDict()
optarg_types = OrderedDict()
for param in signature.parameters.values():
if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
diag = diagnostic.Diagnostic("error",
"system calls must only use positional arguments; '{argument}' isn't",
{"argument": param.name},
self._function_loc(function),
notes=self._call_site_note(loc, fn_kind='syscall'))
self.engine.process(diag)
if param.default is inspect.Parameter.empty:
arg_types[param.name] = self._type_of_param(function, loc, param,
fn_kind='syscall')
else:
diag = diagnostic.Diagnostic("error",
"system call argument '{argument}' must not have a default value",
{"argument": param.name},
self._function_loc(function),
notes=self._call_site_note(loc, fn_kind='syscall'))
self.engine.process(diag)
if signature.return_annotation is not inspect.Signature.empty:
ret_type = self._extract_annot(function, signature.return_annotation,
"return type", loc, fn_kind='syscall')
else:
diag = diagnostic.Diagnostic("error",
"system call must have a return type annotation", {},
self._function_loc(function),
notes=self._call_site_note(loc, fn_kind='syscall'))
self.engine.process(diag)
ret_type = types.TVar()
function_type = types.TExternalFunction(arg_types, ret_type,
name=function.artiq_embedded.syscall,
flags=function.artiq_embedded.flags)
self.functions[function] = function_type
return function_type
def _quote_rpc(self, function, loc):
if isinstance(function, SpecializedFunction):
host_function = function.host_function
else:
host_function = function
ret_type = builtins.TNone()
if isinstance(host_function, pytypes.BuiltinFunctionType):
pass
elif (isinstance(host_function, pytypes.FunctionType) or \
isinstance(host_function, pytypes.MethodType)):
if isinstance(host_function, pytypes.FunctionType):
signature = inspect.signature(host_function)
else:
# inspect bug?
signature = inspect.signature(host_function.__func__)
if signature.return_annotation is not inspect.Signature.empty:
ret_type = self._extract_annot(host_function, signature.return_annotation,
"return type", loc, fn_kind='rpc')
else:
assert False
is_async = False
if hasattr(host_function, "artiq_embedded") and \
"async" in host_function.artiq_embedded.flags:
is_async = True
if not builtins.is_none(ret_type) and is_async:
note = diagnostic.Diagnostic("note",
"function called here", {},
loc)
diag = diagnostic.Diagnostic("fatal",
"functions that return a value cannot be defined as async RPCs", {},
self._function_loc(host_function.artiq_embedded.function),
notes=[note])
self.engine.process(diag)
function_type = types.TRPC(ret_type,
service=self.embedding_map.store_object(host_function),
is_async=is_async)
self.functions[function] = function_type
return function_type
def _quote_function(self, function, loc):
if isinstance(function, SpecializedFunction):
host_function = function.host_function
else:
host_function = function
if function in self.functions:
return self.functions[function]
math_type = math_fns.match(function)
if math_type is not None:
self.functions[function] = math_type
elif not hasattr(host_function, "artiq_embedded") or \
(host_function.artiq_embedded.core_name is None and
host_function.artiq_embedded.portable is False and
host_function.artiq_embedded.syscall is None and
host_function.artiq_embedded.forbidden is False):
self._quote_rpc(function, loc)
elif host_function.artiq_embedded.function is not None:
if host_function.__name__ == "<lambda>":
note = diagnostic.Diagnostic("note",
"lambda created here", {},
self._function_loc(host_function.artiq_embedded.function))
diag = diagnostic.Diagnostic("fatal",
"lambdas cannot be used as kernel functions", {},
loc,
notes=[note])
self.engine.process(diag)
core_name = host_function.artiq_embedded.core_name
if core_name is not None and self.dmgr.get(core_name) != self.core:
note = diagnostic.Diagnostic("note",
"called from this function", {},
loc)
diag = diagnostic.Diagnostic("fatal",
"this function runs on a different core device '{name}'",
{"name": host_function.artiq_embedded.core_name},
self._function_loc(host_function.artiq_embedded.function),
notes=[note])
self.engine.process(diag)
self._quote_embedded_function(function,
flags=host_function.artiq_embedded.flags)
elif host_function.artiq_embedded.syscall is not None:
# Insert a storage-less global whose type instructs the compiler
# to perform a system call instead of a regular call.
self._quote_syscall(function, loc)
elif host_function.artiq_embedded.forbidden is not None:
diag = diagnostic.Diagnostic("fatal",
"this function cannot be called as an RPC", {},
self._function_loc(host_function),
notes=self._call_site_note(loc, fn_kind='rpc'))
self.engine.process(diag)
else:
assert False
return self.functions[function]
def _quote(self, value, loc):
synthesizer = self._synthesizer(loc)
node = synthesizer.quote(value)
synthesizer.finalize()
if len(synthesizer.diagnostics) > 0:
for warning in synthesizer.diagnostics:
self.engine.process(warning)
return node