artiq/artiq/transforms/inline.py

494 lines
19 KiB
Python

import inspect
import textwrap
import ast
import types
import builtins
from copy import copy
from fractions import Fraction
from collections import OrderedDict
from functools import partial
from itertools import zip_longest, chain
from artiq.language import core as core_language
from artiq.language import units
from artiq.transforms.tools import *
def new_mangled_name(in_use_names, name):
    """Reserve and return a name derived from *name* that is not yet used.

    *name* itself is tried first, then name2, name3, ... until a free
    candidate is found. The chosen name is added to *in_use_names*.
    """
    candidate, suffix = name, 2
    while candidate in in_use_names:
        candidate = name + str(suffix)
        suffix += 1
    in_use_names.add(candidate)
    return candidate
class MangledName:
    """Marks a namespace entry as a renamed variable rather than a static
    host object; *s* is the identifier used in the generated code."""
    def __init__(self, s):
        # Mangled identifier string.
        self.s = s
class AttributeInfo:
    """Bookkeeping for one attribute of a static host object used in a kernel.

    obj: the host object owning the attribute.
    mangled_name: the kernel-side variable standing in for the attribute.
    read_write: True once the attribute has been assigned to in the kernel.
    """
    def __init__(self, obj, mangled_name, read_write):
        self.obj, self.mangled_name, self.read_write = (
            obj, mangled_name, read_write)
def is_inlinable(core, func):
    """Return whether *func* can be inlined into a kernel running on *core*.

    Only objects carrying a ``k_function_info`` attribute qualify. An empty
    ``core_name`` marks a portable function (always inlinable). Otherwise
    the attribute named ``core_name`` on the bound instance
    (``func.__self__``) must be the very same *core* object.
    """
    if not hasattr(func, "k_function_info"):
        return False
    core_name = func.k_function_info.core_name
    if core_name == "":
        return True  # portable function
    # kernel function: inlinable only for the same core device
    return getattr(func.__self__, core_name) is core
class GlobalNamespace:
    """Read-only mapping view of a function's module globals plus builtins.

    Lookups try the globals of the module defining *func* first, then the
    ``builtins`` module. A name found in neither raises ``KeyError``, so
    callers can use the ordinary mapping protocol (``except KeyError``) to
    detect undefined names.
    """
    def __init__(self, func):
        self.func_gd = inspect.getmodule(func).__dict__

    def __getitem__(self, item):
        try:
            return self.func_gd[item]
        except KeyError:
            pass
        try:
            return getattr(builtins, item)
        except AttributeError:
            # getattr signals a miss with AttributeError; translate it so
            # callers that catch KeyError (and re-raise NameError) see it.
            raise KeyError(item) from None
class UndefinedArg:
    """Sentinel class marking a parameter that has not received a value yet
    (used as-is, never instantiated, by get_function_args)."""
def get_function_args(func_args, func_tr, args, kwargs):
    """Bind call arguments to the parameters of a parsed function definition.

    func_args: the ``ast.arguments`` node of the function definition.
    func_tr: the Function transformer. Its ``code_visit`` is applied to each
        default-value expression that ends up being used.
    args, kwargs: positional and keyword argument values (plain values or
        AST nodes) supplied by the caller.

    Returns an OrderedDict mapping parameter name -> value. Positional
    parameters come first, then keyword-only ones, in definition order.
    Variadic ``*args``/``**kwargs`` parameters are not considered.

    Raises TypeError for too many positional arguments, an unknown or
    repeated keyword argument, or a missing argument. Raises SyntaxError
    when a parameter name appears twice in the definition.
    """
    # OrderedDict prevents non-determinism in argument init
    bound = OrderedDict()

    def _declare(name, value):
        if name in bound:
            raise SyntaxError("Duplicate argument '{}' in function definition"
                              .format(name))
        bound[name] = value

    # Positional parameters; those without a supplied value start out as
    # UndefinedArg. A surplus value shows up paired with the fill value.
    for param, value in zip_longest(func_args.args, args,
                                    fillvalue=UndefinedArg):
        if param is UndefinedArg:
            raise TypeError("Got too many positional arguments")
        _declare(param.arg, value)

    accepted = {param.arg
                for param in chain(func_args.args, func_args.kwonlyargs)}

    # Keyword-only parameters always start out undefined.
    for param in func_args.kwonlyargs:
        _declare(param.arg, UndefinedArg)

    for name, value in kwargs.items():
        if name not in accepted:
            raise TypeError("Got unexpected keyword argument '{}'"
                            .format(name))
        if bound[name] is not UndefinedArg:
            raise TypeError("Got multiple values for argument '{}'"
                            .format(name))
        bound[name] = value

    # Fill remaining holes from defaults: positional defaults align with
    # the trailing parameters; a kw_defaults entry of None means "no default".
    defaulted = chain(
        zip(func_args.args[-len(func_args.defaults):], func_args.defaults),
        ((param, default)
         for param, default in zip(func_args.kwonlyargs,
                                   func_args.kw_defaults)
         if default is not None))
    for param, default in defaulted:
        if bound[param.arg] is UndefinedArg:
            bound[param.arg] = func_tr.code_visit(default)

    missing = ["'{}'".format(name)
               for name, value in bound.items() if value is UndefinedArg]
    if missing:
        raise TypeError("Missing argument(s): " + " ".join(missing))
    return bound
def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers,
               func, args, kwargs):
    """Parse *func*'s source and return its transformed ``FunctionDef``.

    *args*/*kwargs* may contain plain host values or AST nodes. Each bound
    argument that is an AST node, or whose value ``value_to_ast`` can
    represent, becomes a ``<mangled> = <value>`` assignment prepended to the
    function body. Other values are treated as static objects and placed
    directly in the local namespace of the transformer.
    """
    func_tr = Function(core,
                       GlobalNamespace(func), attribute_namespace,
                       in_use_names, retval_name, mappers)
    func_def = ast.parse(textwrap.dedent(inspect.getsource(func))).body[0]

    # The local namespace is still empty here, so any default-value
    # expression evaluated while binding resolves via the global namespace.
    bound_args = get_function_args(func_def.args, func_tr, args, kwargs)

    arg_init = []
    deferred = []
    for arg_name, arg_value in bound_args.items():
        if isinstance(arg_value, ast.AST):
            value = arg_value
        else:
            try:
                value = ast.copy_location(value_to_ast(arg_value), func_def)
            except NotASTRepresentable:
                value = None

        if value is None:
            # static object
            func_tr.local_namespace[arg_name] = arg_value
            continue

        # Emit "name = value" at the beginning of the function.
        new_name = new_mangled_name(in_use_names, arg_name)
        deferred.append((arg_name, new_name))
        target = ast.copy_location(ast.Name(new_name, ast.Store()), func_def)
        arg_init.append(
            ast.copy_location(ast.Assign([target], value), func_def))

    # Mangled names are committed only after every argument was processed,
    # for cases such as f(x, y=x) where x in y's default must resolve using
    # the global namespace.
    for arg_name, mangled_name in deferred:
        func_tr.local_namespace[arg_name] = MangledName(mangled_name)

    func_def = func_tr.code_visit(func_def)
    func_def.body[0:0] = arg_init
    return func_def
class Function:
    """Rewrite the AST of one kernel function so it can be inlined.

    ``code_visit*`` methods transform nodes in place. They rename locals to
    mangled names, replace static host objects with AST literals or mangled
    attribute variables, inline calls to inlinable kernel functions, and
    turn other calls into RPC syscalls. A visitor may return the
    replacement node, ``None`` to delete the node, or a non-AST iterable
    of nodes to splice into the enclosing statement list.

    ``static_visit*`` methods evaluate an expression against the host
    namespaces and return the resulting Python object.
    """
    def __init__(self, core,
                 global_namespace, attribute_namespace, in_use_names,
                 retval_name, mappers):
        # The core device on which this function is executing.
        self.core = core

        # Local and global namespaces:
        # original name -> MangledName or static object
        self.local_namespace = dict()
        self.global_namespace = global_namespace

        # (id(static object), attribute) -> AttributeInfo
        self.attribute_namespace = attribute_namespace

        # All names currently in use, in the namespace of the combined
        # function.
        # When creating a name for a new object, check that it is not
        # already in this set.
        self.in_use_names = in_use_names

        # Name of the variable to store the return value to, or None
        # to keep the return statement.
        self.retval_name = retval_name

        # Host object mappers (``.rpc`` and ``.exception``), for RPC and
        # exception numbers.
        self.mappers = mappers

        # Statement list currently being rebuilt by code_generic_visit;
        # code_visit_Call appends inlined "with sequential" blocks to it so
        # they land just before the statement containing the call.
        self._insertion_point = None

    # This is ast.NodeVisitor/NodeTransformer from CPython, modified
    # to add code_ prefix.
    def code_visit(self, node):
        method = "code_visit_" + node.__class__.__name__
        visitor = getattr(self, method, self.code_generic_visit)
        return visitor(node)

    # This is ast.NodeTransformer.generic_visit from CPython, modified
    # to update self._insertion_point.
    def code_generic_visit(self, node):
        for field, old_value in ast.iter_fields(node):
            if isinstance(old_value, list):
                prev_insertion_point = self._insertion_point
                new_values = []
                if field in ("body", "orelse", "finalbody"):
                    self._insertion_point = new_values
                for value in old_value:
                    if isinstance(value, ast.AST):
                        value = self.code_visit(value)
                        if value is None:
                            continue
                        elif not isinstance(value, ast.AST):
                            new_values.extend(value)
                            continue
                    new_values.append(value)
                old_value[:] = new_values
                self._insertion_point = prev_insertion_point
            elif isinstance(old_value, ast.AST):
                new_node = self.code_visit(old_value)
                if new_node is None:
                    delattr(node, field)
                else:
                    setattr(node, field, new_node)
        return node

    def code_visit_Name(self, node):
        """Rename stored names; resolve loaded names.

        A store reuses the existing mangled name or allocates a new one. A
        load of a mangled variable is renamed; a load of a static object is
        replaced by its AST literal (NotImplementedError if it has none).
        """
        if isinstance(node.ctx, ast.Store):
            if (node.id in self.local_namespace
                    and isinstance(self.local_namespace[node.id],
                                   MangledName)):
                new_name = self.local_namespace[node.id].s
            else:
                new_name = new_mangled_name(self.in_use_names, node.id)
                self.local_namespace[node.id] = MangledName(new_name)
            node.id = new_name
            return node
        else:
            try:
                obj = self.local_namespace[node.id]
            except KeyError:
                try:
                    obj = self.global_namespace[node.id]
                except KeyError:
                    raise NameError("name '{}' is not defined".format(node.id))
            if isinstance(obj, MangledName):
                node.id = obj.s
                return node
            else:
                try:
                    return value_to_ast(obj)
                except NotASTRepresentable:
                    raise NotImplementedError(
                        "Static object cannot be used here")

    def code_visit_Attribute(self, node):
        # There are two cases of attributes:
        # 1. static object attributes, e.g. self.foo
        # 2. dynamic expression attributes, e.g.
        #    (Fraction(1, 2) + x).numerator
        # Static object resolution has no side effects so we try it first.
        # Any ordinary failure (NameError, AttributeError,
        # NotImplementedError, ...) means "not static"; interpreter-level
        # exceptions such as KeyboardInterrupt must still propagate.
        try:
            obj = self.static_visit(node.value)
        except Exception:
            self.code_generic_visit(node)
            return node
        else:
            key = (id(obj), node.attr)
            try:
                attr_info = self.attribute_namespace[key]
            except KeyError:
                new_name = new_mangled_name(self.in_use_names, node.attr)
                attr_info = AttributeInfo(obj, new_name, False)
                self.attribute_namespace[key] = attr_info
            if isinstance(node.ctx, ast.Store):
                attr_info.read_write = True
            return ast.copy_location(
                ast.Name(attr_info.mangled_name, node.ctx),
                node)

    def code_visit_Call(self, node):
        """Handle a call whose callee resolves statically.

        Embeddable callees are kept as a plain name call. Inlinable kernel
        functions are expanded into a ``with sequential:`` block inserted
        before the current statement, and the call becomes a read of the
        return-value variable. Anything else becomes
        ``syscall("rpc", <rpc id>, *args)``.
        """
        func = self.static_visit(node.func)
        node.args = [self.code_visit(arg) for arg in node.args]
        for kw in node.keywords:
            kw.value = self.code_visit(kw.value)

        if is_embeddable(func):
            node.func = ast.copy_location(
                ast.Name(func.__name__, ast.Load()),
                node)
            return node
        elif is_inlinable(self.core, func):
            retval_name = func.k_function_info.k_function.__name__ + "_return"
            retval_name_m = new_mangled_name(self.in_use_names, retval_name)
            args = [func.__self__] + node.args
            kwargs = {kw.arg: kw.value for kw in node.keywords}
            inlined = get_inline(self.core,
                                 self.attribute_namespace, self.in_use_names,
                                 retval_name_m, self.mappers,
                                 func.k_function_info.k_function,
                                 args, kwargs)
            seq = ast.copy_location(
                ast.With(
                    items=[ast.withitem(context_expr=ast.Name(id="sequential",
                                                              ctx=ast.Load()),
                                        optional_vars=None)],
                    body=inlined.body),
                node)
            self._insertion_point.append(seq)
            return ast.copy_location(ast.Name(retval_name_m, ast.Load()),
                                     node)
        else:
            arg1 = ast.copy_location(ast.Str("rpc"), node)
            arg2 = ast.copy_location(
                value_to_ast(self.mappers.rpc.encode(func)), node)
            node.args[0:0] = [arg1, arg2]
            node.func = ast.copy_location(
                ast.Name("syscall", ast.Load()), node)
            return node

    def code_visit_Return(self, node):
        """Keep the return, or turn it into an assignment to retval_name."""
        self.code_generic_visit(node)
        if self.retval_name is None:
            return node
        if node.value is None:
            # A bare "return" has no value to store; building
            # Assign(value=None) would produce an invalid AST.
            return ast.copy_location(ast.Pass(), node)
        return ast.copy_location(
            ast.Assign(targets=[ast.Name(self.retval_name, ast.Store())],
                       value=node.value),
            node)

    def code_visit_Expr(self, node):
        if isinstance(node.value, ast.Str):
            # Strip docstrings. This also removes strings appearing in the
            # middle of the code, but they are nops.
            return None

        self.code_generic_visit(node)
        if isinstance(node.value, ast.Name):
            # Remove Expr nodes that contain only a name, likely due to
            # function call inlining. Such nodes that were originally in the
            # code are also removed, but this does not affect the semantics of
            # the code as they are nops.
            return None
        else:
            return node

    def encode_exception(self, e):
        """Return an ``EncodedException(<id>)`` call for exception class *e*.

        Runtime exceptions use their ``eid``; other classes are numbered
        through the exception mapper. Raises NotImplementedError if *e*
        does not statically resolve to a class.
        """
        exception_class = self.static_visit(e)
        if not inspect.isclass(exception_class):
            raise NotImplementedError("Exception type must be a class")
        if issubclass(exception_class, core_language.RuntimeException):
            exception_id = exception_class.eid
        else:
            exception_id = self.mappers.exception.encode(exception_class)
        return ast.copy_location(
            ast.Call(func=ast.Name("EncodedException", ast.Load()),
                     args=[value_to_ast(exception_id)],
                     keywords=[], starargs=None, kwargs=None),
            e)

    def code_visit_Raise(self, node):
        if node.cause is not None:
            raise NotImplementedError("Exception causes are not supported")
        if node.exc is not None:
            node.exc = self.encode_exception(node.exc)
        return node

    def code_visit_ExceptHandler(self, node):
        if node.name is not None:
            raise NotImplementedError("'as target' is not supported")
        if node.type is not None:
            if isinstance(node.type, ast.Tuple):
                node.type.elts = [self.encode_exception(e)
                                  for e in node.type.elts]
            else:
                node.type = self.encode_exception(node.type)
        self.code_generic_visit(node)
        return node

    def code_visit_FunctionDef(self, node):
        # Arguments were already bound by get_inline; drop the signature
        # and decorators.
        node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[],
                                  kw_defaults=[], kwarg=None, defaults=[])
        node.decorator_list = []
        self.code_generic_visit(node)
        return node

    def static_visit(self, node):
        """Evaluate *node* on the host; raises AttributeError if the node
        type has no static_visit_* handler."""
        method = "static_visit_" + node.__class__.__name__
        visitor = getattr(self, method)
        return visitor(node)

    def static_visit_Name(self, node):
        try:
            obj = self.local_namespace[node.id]
        except KeyError:
            try:
                obj = self.global_namespace[node.id]
            except KeyError:
                raise NameError("name '{}' is not defined".format(node.id))
        if isinstance(obj, MangledName):
            raise NotImplementedError(
                "Only a static object can be used here")
        return obj

    def static_visit_Attribute(self, node):
        value = self.static_visit(node.value)
        return getattr(value, node.attr)
class HostObjectMapper:
    """Assign consecutive integer encodings to host objects, by identity.

    Encodings start at *first_encoding*. Encoding the same object again
    returns its existing number.
    """
    def __init__(self, first_encoding=0):
        self._next_encoding = first_encoding
        # id(object) -> (encoding, object)
        # Keyed by id() so non-hashable host objects are supported; storing
        # the object too keeps a reference to it.
        self._d = dict()

    def encode(self, obj):
        key = id(obj)
        entry = self._d.get(key)
        if entry is not None:
            return entry[0]
        encoding = self._next_encoding
        self._d[key] = (encoding, obj)
        self._next_encoding = encoding + 1
        return encoding

    def get_map(self):
        """Return a dict mapping each encoding back to its object."""
        return {encoding: obj for encoding, obj in self._d.values()}
def get_attr_init(attribute_namespace, loc_node):
    """Build ``<mangled> = <value>`` statements initialising attribute copies.

    One assignment is produced for each recorded attribute that currently
    exists on its host object. Values whose ``kernel_attr_init`` attribute
    is present and falsy are skipped. All nodes take their location from
    *loc_node*.
    """
    def located(new_node):
        return ast.copy_location(new_node, loc_node)

    statements = []
    for (_, attr), attr_info in attribute_namespace.items():
        if not hasattr(attr_info.obj, attr):
            continue
        host_value = getattr(attr_info.obj, attr)
        if (hasattr(host_value, "kernel_attr_init")
                and not host_value.kernel_attr_init):
            continue
        value_node = located(value_to_ast(host_value))
        target = located(ast.Name(attr_info.mangled_name, ast.Store()))
        statements.append(located(ast.Assign([target], value_node)))
    return statements
def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
    """Build RPC statements copying written attributes back to the host.

    For every attribute marked ``read_write``, a ``setattr`` partial for the
    host object is registered with *rpc_mapper*. An expression statement
    ``syscall("rpc", <id>, <mangled>)`` is emitted for it, located at
    *loc_node*.
    """
    def located(new_node):
        return ast.copy_location(new_node, loc_node)

    statements = []
    for (_, attr), attr_info in attribute_namespace.items():
        if not attr_info.read_write:
            continue
        setter = partial(setattr, attr_info.obj, attr)
        callee = located(ast.Name("syscall", ast.Load()))
        rpc_kind = located(ast.Str("rpc"))
        rpc_id = located(value_to_ast(rpc_mapper.encode(setter)))
        new_value = located(ast.Name(attr_info.mangled_name, ast.Load()))
        call = located(ast.Call(func=callee,
                                args=[rpc_kind, rpc_id, new_value],
                                keywords=[], starargs=None, kwargs=None))
        statements.append(located(ast.Expr(call)))
    return statements
def inline(core, k_function, k_args, k_kwargs, with_attr_writeback):
    """Inline *k_function* (and the kernels it calls) into one FunctionDef.

    Attribute initialisation statements are prepended to the body. When
    *with_attr_writeback* is true, RPC write-back statements are appended.
    Returns ``(func_def, rpc_map, exception_map)``, where the maps are the
    encoding -> object dicts of the RPC and exception mappers.
    """
    # OrderedDict prevents non-determinism in attribute init
    attribute_namespace = OrderedDict()
    in_use_names = copy(embeddable_func_names)
    rpc_mapper = HostObjectMapper()
    exception_mapper = HostObjectMapper(core_language.first_user_eid)
    mappers = types.SimpleNamespace(rpc=rpc_mapper,
                                    exception=exception_mapper)

    func_def = get_inline(core=core,
                          attribute_namespace=attribute_namespace,
                          in_use_names=in_use_names,
                          retval_name=None,
                          mappers=mappers,
                          func=k_function,
                          args=k_args,
                          kwargs=k_kwargs)

    func_def.body[:0] = get_attr_init(attribute_namespace, func_def)
    if with_attr_writeback:
        func_def.body.extend(
            get_attr_writeback(attribute_namespace, rpc_mapper, func_def))

    return func_def, rpc_mapper.get_map(), exception_mapper.get_map()