diff --git a/artiq/transforms/inline.py b/artiq/transforms/inline.py index c3151b54c..d74253838 100644 --- a/artiq/transforms/inline.py +++ b/artiq/transforms/inline.py @@ -1,17 +1,357 @@ -from collections import namedtuple -from fractions import Fraction import inspect import textwrap import ast +import types import builtins -from copy import deepcopy +from fractions import Fraction +from collections import OrderedDict -from artiq.transforms.tools import eval_ast, value_to_ast, NotASTRepresentable from artiq.language import core as core_language from artiq.language import units +from artiq.transforms.tools import value_to_ast, NotASTRepresentable -class _HostObjectMapper: +def new_mangled_name(in_use_names, name): + mangled_name = name + i = 2 + while mangled_name in in_use_names: + mangled_name = name + str(i) + i += 1 + in_use_names.add(mangled_name) + return mangled_name + + +class MangledName: + def __init__(self, s): + self.s = s + + +class AttributeInfo: + def __init__(self, obj, mangled_name, read_write): + self.obj = obj + self.mangled_name = mangled_name + self.read_write = read_write + + +embeddable_funcs = ( + core_language.delay, core_language.at, core_language.now, + core_language.time_to_cycles, core_language.cycles_to_time, + core_language.syscall, + range, bool, int, float, round, + core_language.int64, core_language.round64, core_language.array, + Fraction, units.Quantity, core_language.EncodedException +) + + +def is_embeddable(func): + for ef in embeddable_funcs: + if func is ef: + return True + return False + + +def is_inlinable(core, func): + if hasattr(func, "k_function_info"): + if func.k_function_info.core_name == "": + return True # portable function + if getattr(func.__self__, func.k_function_info.core_name) is core: + return True # kernel function for the same core device + return False + + +class GlobalNamespace: + def __init__(self, func): + self.func_gd = inspect.getmodule(func).__dict__ + + def __getitem__(self, item): + try: + return self.func_gd[item] + except KeyError: + return getattr(builtins, item) + + +def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers, + func, args): + global_namespace = GlobalNamespace(func) + func_tr = Function(core, + global_namespace, attribute_namespace, in_use_names, + retval_name, mappers) + func_def = ast.parse(textwrap.dedent(inspect.getsource(func))).body[0] + + param_init = [] + # initialize arguments + for arg_ast, arg_value in zip(func_def.args.args, args): + if isinstance(arg_value, ast.AST): + value = arg_value + else: + try: + value = ast.copy_location(value_to_ast(arg_value), func_def) + except NotASTRepresentable: + value = None + if value is None: + # static object + func_tr.local_namespace[arg_ast.arg] = arg_value + else: + # set parameter value with "name = value" + # assignment at beginning of function + new_name = new_mangled_name(in_use_names, arg_ast.arg) + func_tr.local_namespace[arg_ast.arg] = MangledName(new_name) + target = ast.copy_location(ast.Name(new_name, ast.Store()), + func_def) + assign = ast.copy_location(ast.Assign([target], value), + func_def) + param_init.append(assign) + + func_def = func_tr.code_visit(func_def) + func_def.body[0:0] = param_init + return func_def + + +class Function: + def __init__(self, core, + global_namespace, attribute_namespace, in_use_names, + retval_name, mappers): + # The core device on which this function is executing. + self.core = core + + # Local and global namespaces: + # original name -> MangledName or static object + self.local_namespace = dict() + self.global_namespace = global_namespace + + # (id(static object), attribute) -> AttributeInfo + self.attribute_namespace = attribute_namespace + + # All names currently in use, in the namespace of the combined + # function. + # When creating a name for a new object, check that it is not + # already in this set. + self.in_use_names = in_use_names + + # Name of the variable to store the return value to, or None + # to keep the return statement. + self.retval_name = retval_name + + # Host object mappers, for RPC and exception numbers + self.mappers = mappers + + self._insertion_point = None + + # This is ast.NodeVisitor/NodeTransformer from CPython, modified + # to add code_ prefix. + def code_visit(self, node): + method = "code_visit_" + node.__class__.__name__ + visitor = getattr(self, method, self.code_generic_visit) + return visitor(node) + + # This is ast.NodeTransformer.generic_visit from CPython, modified + # to update self._insertion_point. + def code_generic_visit(self, node): + for field, old_value in ast.iter_fields(node): + old_value = getattr(node, field, None) + if isinstance(old_value, list): + prev_insertion_point = self._insertion_point + new_values = [] + if field in ("body", "orelse", "finalbody"): + self._insertion_point = new_values + for value in old_value: + if isinstance(value, ast.AST): + value = self.code_visit(value) + if value is None: + continue + elif not isinstance(value, ast.AST): + new_values.extend(value) + continue + new_values.append(value) + old_value[:] = new_values + self._insertion_point = prev_insertion_point + elif isinstance(old_value, ast.AST): + new_node = self.code_visit(old_value) + if new_node is None: + delattr(node, field) + else: + setattr(node, field, new_node) + return node + + def code_visit_Name(self, node): + if isinstance(node.ctx, ast.Store): + if (node.id in self.local_namespace + and isinstance(self.local_namespace[node.id], + MangledName)): + new_name = self.local_namespace[node.id].s + else: + new_name = new_mangled_name(self.in_use_names, node.id) + self.local_namespace[node.id] = MangledName(new_name) + node.id = new_name + return node + else: + try: + obj = self.local_namespace[node.id] + except KeyError: + try: + obj = self.global_namespace[node.id] + except KeyError: + raise NameError("name '{}' is not defined".format(node.id)) + if isinstance(obj, MangledName): + node.id = obj.s + return node + else: + try: + return value_to_ast(obj) + except NotASTRepresentable: + raise NotImplementedError( + "Static object cannot be used here") + + def code_visit_Attribute(self, node): + # There are two cases of attributes: + # 1. static object attributes, e.g. self.foo + # 2. dynamic expression attributes, e.g. + # (Fraction(1, 2) + x).numerator + # Static object resolution has no side effects so we try it first. + try: + obj = self.static_visit(node.value) + except: + self.code_generic_visit(node) + return node + else: + key = (id(obj), node.attr) + try: + attr_info = self.attribute_namespace[key] + except KeyError: + new_name = new_mangled_name(self.in_use_names, node.attr) + attr_info = AttributeInfo(obj, new_name, False) + self.attribute_namespace[key] = attr_info + if isinstance(node.ctx, ast.Store): + attr_info.read_write = True + return ast.copy_location( + ast.Name(attr_info.mangled_name, node.ctx), + node) + + def code_visit_Call(self, node): + func = self.static_visit(node.func) + node.args = [self.code_visit(arg) for arg in node.args] + + if is_embeddable(func): + node.func = ast.copy_location( + ast.Name(func.__name__, ast.Load()), + node) + return node + elif is_inlinable(self.core, func): + retval_name = func.k_function_info.k_function.__name__ + "_return" + retval_name_m = new_mangled_name(self.in_use_names, retval_name) + args = [func.__self__] + node.args + inlined = get_inline(self.core, + self.attribute_namespace, self.in_use_names, + retval_name_m, self.mappers, + func.k_function_info.k_function, args) + seq = ast.copy_location( + ast.With( + items=[ast.withitem(context_expr=ast.Name(id="sequential", + ctx=ast.Load()), + optional_vars=None)], + body=inlined.body), + node) + self._insertion_point.append(seq) + return ast.copy_location(ast.Name(retval_name_m, ast.Load()), + node) + else: + arg1 = ast.copy_location(ast.Str("rpc"), node) + arg2 = ast.copy_location( + value_to_ast(self.mappers.rpc.encode(func)), node) + node.args[0:0] = [arg1, arg2] + node.func = ast.copy_location( + ast.Name("syscall", ast.Load()), node) + return node + + def code_visit_Return(self, node): + self.code_generic_visit(node) + if self.retval_name is None: + return node + else: + return ast.copy_location( + ast.Assign(targets=[ast.Name(self.retval_name, ast.Store())], + value=node.value), + node) + + def code_visit_Expr(self, node): + if isinstance(node.value, ast.Str): + # Strip docstrings. This also removes strings appearing in the + # middle of the code, but they are nops. + return None + self.code_generic_visit(node) + if isinstance(node.value, ast.Name): + # Remove Expr nodes that contain only a name, likely due to + # function call inlining. Such nodes that were originally in the + # code are also removed, but this does not affect the semantics of + # the code as they are nops. + return None + else: + return node + + def encode_exception(self, e): + exception_class = self.static_visit(e) + if not inspect.isclass(exception_class): + raise NotImplementedError("Exception type must be a class") + if issubclass(exception_class, core_language.RuntimeException): + exception_id = exception_class.eid + else: + exception_id = self.mappers.exception.encode(exception_class) + return ast.copy_location( + ast.Call(func=ast.Name("EncodedException", ast.Load()), + args=[value_to_ast(exception_id)], + keywords=[], starargs=None, kwargs=None), + e) + + def code_visit_Raise(self, node): + if node.cause is not None: + raise NotImplementedError("Exception causes are not supported") + if node.exc is not None: + node.exc = self.encode_exception(node.exc) + return node + + def code_visit_ExceptHandler(self, node): + if node.name is not None: + raise NotImplementedError("'as target' is not supported") + if node.type is not None: + if isinstance(node.type, ast.Tuple): + node.type.elts = [self.encode_exception(e) + for e in node.type.elts] + else: + node.type = self.encode_exception(node.type) + self.code_generic_visit(node) + return node + + def code_visit_FunctionDef(self, node): + node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[], + kw_defaults=[], kwarg=None, defaults=[]) + node.decorator_list = [] + self.code_generic_visit(node) + return node + + def static_visit(self, node): + method = "static_visit_" + node.__class__.__name__ + visitor = getattr(self, method) + return visitor(node) + + def static_visit_Name(self, node): + try: + obj = self.local_namespace[node.id] + except KeyError: + try: + obj = self.global_namespace[node.id] + except KeyError: + raise NameError("name '{}' is not defined".format(node.id)) + if isinstance(obj, MangledName): + raise NotImplementedError( + "Only a static object can be used here") + return obj + + def static_visit_Attribute(self, node): + value = self.static_visit(node.value) + return getattr(value, node.attr) + + +class HostObjectMapper: def __init__(self, first_encoding=0): self._next_encoding = first_encoding # id(object) -> (encoding, object) @@ -31,334 +371,37 @@ class _HostObjectMapper: return {encoding: obj for i, (encoding, obj) in self._d.items()} -_UserVariable = namedtuple("_UserVariable", "name") +def inline(core, k_function, k_args, k_kwargs): + if k_kwargs: + raise NotImplementedError( + "Keyword arguments are not supported for kernels") + # OrderedDict prevents non-determinism in parameter init + attribute_namespace = OrderedDict() + in_use_names = {func.__name__ for func in embeddable_funcs} + mappers = types.SimpleNamespace( + rpc=HostObjectMapper(), + exception=HostObjectMapper(core_language.first_user_eid) + ) + func_def = get_inline( + core=core, + attribute_namespace=attribute_namespace, + in_use_names=in_use_names, + retval_name=None, + mappers=mappers, + func=k_function, + args=k_args) -def _is_kernel_attr(value, attr): - return hasattr(value, "kernel_attr") and attr in value.kernel_attr.split() - - -class _ReferenceManager: - def __init__(self): - self.rpc_mapper = _HostObjectMapper() - self.exception_mapper = _HostObjectMapper(core_language.first_user_eid) - self.kernel_attr_init = [] - - # (id(obj), func_name, ref_name) or (id(obj), kernel_attr_name) - # -> _UserVariable(name) / complex object - self._to_inlined = dict() - # inlined_name -> use_count - self._use_count = dict() - # reserved names - for kg in core_language.kernel_globals: - self._use_count[kg] = 1 - for name in ("bool", "int", "round", "int64", "round64", "float", - "Fraction", "array", "Quantity", "EncodedException", - "range"): - self._use_count[name] = 1 - - # Complex objects in the namespace of functions can be used in two ways: - # 1. Calling a method on them (which gets inlined or RPCd) - # 2. Getting or setting attributes (which are turned into local variables) - # They are needed to implement "self", which is the only supported use - # case. - def register_complex_object(self, obj, func_name, ref_name, - complex_object): - assert(not isinstance(complex_object, ast.AST)) - self._to_inlined[(id(obj), func_name, ref_name)] = complex_object - - def new_name(self, base_name): - if base_name[-1].isdigit(): - base_name += "_" - if base_name in self._use_count: - r = base_name + str(self._use_count[base_name]) - self._use_count[base_name] += 1 - return r - else: - self._use_count[base_name] = 1 - return base_name - - def resolve_name(self, obj, func_name, ref_name, store): - key = (id(obj), func_name, ref_name) - try: - return self._to_inlined[key] - except KeyError: - if store: - ival = _UserVariable(self.new_name(ref_name)) - self._to_inlined[key] = ival - return ival - else: - try: - return inspect.getmodule(obj).__dict__[ref_name] - except KeyError: - return getattr(builtins, ref_name) - - def resolve_attr(self, value, attr): - if _is_kernel_attr(value, attr): - key = (id(value), attr) - try: - ival = self._to_inlined[key] - assert(isinstance(ival, _UserVariable)) - except KeyError: - iname = self.new_name(attr) - ival = _UserVariable(iname) - self._to_inlined[key] = ival - a = value_to_ast(getattr(value, attr)) - if a is None: - raise NotImplementedError( - "Cannot represent initial value" - " of kernel attribute") - self.kernel_attr_init.append(ast.Assign( - [ast.Name(iname, ast.Store())], a)) - return ival - else: - return getattr(value, attr) - - def resolve_constant(self, obj, func_name, node): - if isinstance(node, ast.Name): - c = self.resolve_name(obj, func_name, node.id, False) - if isinstance(c, _UserVariable): - raise ValueError("Not a constant") - return c - elif isinstance(node, ast.Attribute): - value = self.resolve_constant(obj, func_name, node.value) - if _is_kernel_attr(value, node.attr): - raise ValueError("Not a constant") - return getattr(value, node.attr) - else: - raise NotImplementedError - - -_embeddable_funcs = ( - core_language.delay, core_language.at, core_language.now, - core_language.time_to_cycles, core_language.cycles_to_time, - core_language.syscall, - range, bool, int, float, round, - core_language.int64, core_language.round64, core_language.array, - Fraction, units.Quantity, core_language.EncodedException -) - -def _is_embeddable(func): - for ef in _embeddable_funcs: - if func is ef: - return True - return False - - -def _is_inlinable(core, func): - if hasattr(func, "k_function_info"): - if func.k_function_info.core_name == "": - return True # portable function - if getattr(func.__self__, func.k_function_info.core_name) is core: - return True # kernel function for the same core device - return False - - -class _ReferenceReplacer(ast.NodeVisitor): - def __init__(self, core, rm, obj, func_name, retval_name): - self.core = core - self.rm = rm - self.obj = obj - self.func_name = func_name - self.retval_name = retval_name - self._insertion_point = None - - # This is ast.NodeTransformer.generic_visit from CPython, modified - # to update self._insertion_point. - def generic_visit(self, node): - for field, old_value in ast.iter_fields(node): - old_value = getattr(node, field, None) - if isinstance(old_value, list): - prev_insertion_point = self._insertion_point - new_values = [] - if field in ("body", "orelse", "finalbody"): - self._insertion_point = new_values - for value in old_value: - if isinstance(value, ast.AST): - value = self.visit(value) - if value is None: - continue - elif not isinstance(value, ast.AST): - new_values.extend(value) - continue - new_values.append(value) - old_value[:] = new_values - self._insertion_point = prev_insertion_point - elif isinstance(old_value, ast.AST): - new_node = self.visit(old_value) - if new_node is None: - delattr(node, field) - else: - setattr(node, field, new_node) - return node - - def visit_Name(self, node): - store = isinstance(node.ctx, ast.Store) - ival = self.rm.resolve_name(self.obj, self.func_name, node.id, store) - if isinstance(ival, _UserVariable): - newnode = ast.Name(ival.name, node.ctx) - else: - if store: - raise NotImplementedError("Cannot assign to this object") - newnode = value_to_ast(ival) - return ast.copy_location(newnode, node) - - def _resolve_attribute(self, node): - if isinstance(node, ast.Name): - ival = self.rm.resolve_name(self.obj, self.func_name, node.id, False) - if isinstance(ival, _UserVariable): - return ast.Name(ival.name, ast.Load()) - else: - return ival - elif isinstance(node, ast.Attribute): - value = self._resolve_attribute(node.value) - if isinstance(value, ast.AST): - node.value = value - return node - else: - return self.rm.resolve_attr(value, node.attr) - else: - return self.visit(node) - - def visit_Attribute(self, node): - ival = self._resolve_attribute(node) - if isinstance(ival, ast.AST): - newnode = deepcopy(ival) - elif isinstance(ival, _UserVariable): - newnode = ast.Name(ival.name, node.ctx) - else: - newnode = value_to_ast(ival) - return ast.copy_location(newnode, node) - - def visit_Call(self, node): - func = self.rm.resolve_constant(self.obj, self.func_name, node.func) - new_args = [self.visit(arg) for arg in node.args] - - if _is_embeddable(func): - new_func = ast.Name(func.__name__, ast.Load()) - return ast.copy_location( - ast.Call(func=new_func, args=new_args, - keywords=[], starargs=None, kwargs=None), - node) - elif _is_inlinable(self.core, func): - retval_name = self.rm.new_name( - func.k_function_info.k_function.__name__ + "_return") - args = [func.__self__] + new_args - inlined, _, _ = inline(self.core, func.k_function_info.k_function, - args, dict(), self.rm, retval_name) - self._insertion_point.append(ast.With( - items=[ast.withitem(context_expr=ast.Name(id="sequential", - ctx=ast.Load()), - optional_vars=None)], - body=inlined.body)) - return ast.copy_location(ast.Name(retval_name, ast.Load()), node) - else: - args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_mapper.encode(func))] - args += new_args - return ast.copy_location( - ast.Call(func=ast.Name("syscall", ast.Load()), - args=args, keywords=[], starargs=None, kwargs=None), - node) - - def visit_Return(self, node): - self.generic_visit(node) - return ast.copy_location( - ast.Assign(targets=[ast.Name(self.retval_name, ast.Store())], - value=node.value), - node) - - def visit_Expr(self, node): - if isinstance(node.value, ast.Str): - # Strip docstrings. This also removes strings appearing in the - # middle of the code, but they are nops. - return None - self.generic_visit(node) - if isinstance(node.value, ast.Name): - # Remove Expr nodes that contain only a name, likely due to - # function call inlining. Such nodes that were originally in the - # code are also removed, but this does not affect the semantics of - # the code as they are nops. - return None - else: - return node - - def visit_FunctionDef(self, node): - node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[], - kw_defaults=[], kwarg=None, defaults=[]) - node.decorator_list = [] - self.generic_visit(node) - return node - - def _encode_exception(self, e): - exception_class = self.rm.resolve_constant(self.obj, self.func_name, e) - if not inspect.isclass(exception_class): - raise NotImplementedError("Exception type must be a class") - if issubclass(exception_class, core_language.RuntimeException): - exception_id = exception_class.eid - else: - exception_id = self.rm.exception_mapper.encode(exception_class) - return ast.copy_location( - ast.Call(func=ast.Name("EncodedException", ast.Load()), - args=[value_to_ast(exception_id)], - keywords=[], starargs=None, kwargs=None), - e) - - def visit_Raise(self, node): - if node.cause is not None: - raise NotImplementedError("Exception causes are not supported") - if node.exc is not None: - node.exc = self._encode_exception(node.exc) - return node - - def visit_ExceptHandler(self, node): - if node.name is not None: - raise NotImplementedError("'as target' is not supported") - if node.type is not None: - if isinstance(node.type, ast.Tuple): - node.type.elts = [self._encode_exception(e) for e in node.type.elts] - else: - node.type = self._encode_exception(node.type) - self.generic_visit(node) - return node - - -def _initialize_function_params(func_def, k_args, k_kwargs, rm): - obj = k_args[0] - func_name = func_def.name param_init = [] - for arg_ast, arg_value in zip(func_def.args.args, k_args): - if isinstance(arg_value, ast.AST): - value = arg_value - else: - try: - value = value_to_ast(arg_value) - except NotASTRepresentable: - value = None - if value is None: - rm.register_complex_object(obj, func_name, arg_ast.arg, arg_value) - else: - uservar = rm.resolve_name(obj, func_name, arg_ast.arg, True) - target = ast.Name(uservar.name, ast.Store()) - param_init.append(ast.Assign(targets=[target], value=value)) - return param_init - - -def inline(core, k_function, k_args, k_kwargs, rm=None, retval_name=None): - init_kernel_attr = rm is None - if rm is None: - rm = _ReferenceManager() - - func_def = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0] - - param_init = _initialize_function_params(func_def, k_args, k_kwargs, rm) - - obj = k_args[0] - func_name = func_def.name - rr = _ReferenceReplacer(core, rm, obj, func_name, retval_name) - rr.visit(func_def) - + for (_, attr), attr_info in attribute_namespace.items(): + value = getattr(attr_info.obj, attr) + value = ast.copy_location(value_to_ast(value), func_def) + target = ast.copy_location(ast.Name(attr_info.mangled_name, + ast.Store()), + func_def) + assign = ast.copy_location(ast.Assign([target], value), + func_def) + param_init.append(assign) func_def.body[0:0] = param_init - if init_kernel_attr: - func_def.body[0:0] = rm.kernel_attr_init - return func_def, rm.rpc_mapper.get_map(), rm.exception_mapper.get_map() + return func_def, mappers.rpc.get_map(), mappers.exception.get_map()