2
0
mirror of https://github.com/m-labs/artiq.git synced 2025-01-27 02:48:12 +08:00

transforms/inline: rewrite

This commit is contained in:
Sebastien Bourdeauducq 2014-10-31 23:43:36 +08:00
parent 7806ca3373
commit cf7848c698

View File

@ -1,17 +1,357 @@
from collections import namedtuple
from fractions import Fraction
import inspect
import textwrap
import ast
import types
import builtins
from copy import deepcopy
from fractions import Fraction
from collections import OrderedDict
from artiq.transforms.tools import eval_ast, value_to_ast, NotASTRepresentable
from artiq.language import core as core_language
from artiq.language import units
from artiq.transforms.tools import value_to_ast, NotASTRepresentable
class _HostObjectMapper:
def new_mangled_name(in_use_names, name):
mangled_name = name
i = 2
while mangled_name in in_use_names:
mangled_name = name + str(i)
i += 1
in_use_names.add(mangled_name)
return mangled_name
class MangledName:
def __init__(self, s):
self.s = s
class AttributeInfo:
def __init__(self, obj, mangled_name, read_write):
self.obj = obj
self.mangled_name = mangled_name
self.read_write = read_write
embeddable_funcs = (
core_language.delay, core_language.at, core_language.now,
core_language.time_to_cycles, core_language.cycles_to_time,
core_language.syscall,
range, bool, int, float, round,
core_language.int64, core_language.round64, core_language.array,
Fraction, units.Quantity, core_language.EncodedException
)
def is_embeddable(func):
for ef in embeddable_funcs:
if func is ef:
return True
return False
def is_inlinable(core, func):
if hasattr(func, "k_function_info"):
if func.k_function_info.core_name == "":
return True # portable function
if getattr(func.__self__, func.k_function_info.core_name) is core:
return True # kernel function for the same core device
return False
class GlobalNamespace:
def __init__(self, func):
self.func_gd = inspect.getmodule(func).__dict__
def __getitem__(self, item):
try:
return self.func_gd[item]
except KeyError:
return getattr(builtins, item)
def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers,
func, args):
global_namespace = GlobalNamespace(func)
func_tr = Function(core,
global_namespace, attribute_namespace, in_use_names,
retval_name, mappers)
func_def = ast.parse(textwrap.dedent(inspect.getsource(func))).body[0]
param_init = []
# initialize arguments
for arg_ast, arg_value in zip(func_def.args.args, args):
if isinstance(arg_value, ast.AST):
value = arg_value
else:
try:
value = ast.copy_location(value_to_ast(arg_value), func_def)
except NotASTRepresentable:
value = None
if value is None:
# static object
func_tr.local_namespace[arg_ast.arg] = arg_value
else:
# set parameter value with "name = value"
# assignment at beginning of function
new_name = new_mangled_name(in_use_names, arg_ast.arg)
func_tr.local_namespace[arg_ast.arg] = MangledName(new_name)
target = ast.copy_location(ast.Name(new_name, ast.Store()),
func_def)
assign = ast.copy_location(ast.Assign([target], value),
func_def)
param_init.append(assign)
func_def = func_tr.code_visit(func_def)
func_def.body[0:0] = param_init
return func_def
class Function:
def __init__(self, core,
global_namespace, attribute_namespace, in_use_names,
retval_name, mappers):
# The core device on which this function is executing.
self.core = core
# Local and global namespaces:
# original name -> MangledName or static object
self.local_namespace = dict()
self.global_namespace = global_namespace
# (id(static object), attribute) -> AttributeInfo
self.attribute_namespace = attribute_namespace
# All names currently in use, in the namespace of the combined
# function.
# When creating a name for a new object, check that it is not
# already in this set.
self.in_use_names = in_use_names
# Name of the variable to store the return value to, or None
# to keep the return statement.
self.retval_name = retval_name
# Host object mappers, for RPC and exception numbers
self.mappers = mappers
self._insertion_point = None
# This is ast.NodeVisitor/NodeTransformer from CPython, modified
# to add code_ prefix.
def code_visit(self, node):
method = "code_visit_" + node.__class__.__name__
visitor = getattr(self, method, self.code_generic_visit)
return visitor(node)
# This is ast.NodeTransformer.generic_visit from CPython, modified
# to update self._insertion_point.
def code_generic_visit(self, node):
for field, old_value in ast.iter_fields(node):
old_value = getattr(node, field, None)
if isinstance(old_value, list):
prev_insertion_point = self._insertion_point
new_values = []
if field in ("body", "orelse", "finalbody"):
self._insertion_point = new_values
for value in old_value:
if isinstance(value, ast.AST):
value = self.code_visit(value)
if value is None:
continue
elif not isinstance(value, ast.AST):
new_values.extend(value)
continue
new_values.append(value)
old_value[:] = new_values
self._insertion_point = prev_insertion_point
elif isinstance(old_value, ast.AST):
new_node = self.code_visit(old_value)
if new_node is None:
delattr(node, field)
else:
setattr(node, field, new_node)
return node
def code_visit_Name(self, node):
if isinstance(node.ctx, ast.Store):
if (node.id in self.local_namespace
and isinstance(self.local_namespace[node.id],
MangledName)):
new_name = self.local_namespace[node.id].s
else:
new_name = new_mangled_name(self.in_use_names, node.id)
self.local_namespace[node.id] = MangledName(new_name)
node.id = new_name
return node
else:
try:
obj = self.local_namespace[node.id]
except KeyError:
try:
obj = self.global_namespace[node.id]
except KeyError:
raise NameError("name '{}' is not defined".format(node.id))
if isinstance(obj, MangledName):
node.id = obj.s
return node
else:
try:
return value_to_ast(obj)
except NotASTRepresentable:
raise NotImplementedError(
"Static object cannot be used here")
def code_visit_Attribute(self, node):
# There are two cases of attributes:
# 1. static object attributes, e.g. self.foo
# 2. dynamic expression attributes, e.g.
# (Fraction(1, 2) + x).numerator
# Static object resolution has no side effects so we try it first.
try:
obj = self.static_visit(node.value)
except:
self.code_generic_visit(node)
return node
else:
key = (id(obj), node.attr)
try:
attr_info = self.attribute_namespace[key]
except KeyError:
new_name = new_mangled_name(self.in_use_names, node.attr)
attr_info = AttributeInfo(obj, new_name, False)
self.attribute_namespace[key] = attr_info
if isinstance(node.ctx, ast.Store):
attr_info.read_write = True
return ast.copy_location(
ast.Name(attr_info.mangled_name, node.ctx),
node)
def code_visit_Call(self, node):
func = self.static_visit(node.func)
node.args = [self.code_visit(arg) for arg in node.args]
if is_embeddable(func):
node.func = ast.copy_location(
ast.Name(func.__name__, ast.Load()),
node)
return node
elif is_inlinable(self.core, func):
retval_name = func.k_function_info.k_function.__name__ + "_return"
retval_name_m = new_mangled_name(self.in_use_names, retval_name)
args = [func.__self__] + node.args
inlined = get_inline(self.core,
self.attribute_namespace, self.in_use_names,
retval_name_m, self.mappers,
func.k_function_info.k_function, args)
seq = ast.copy_location(
ast.With(
items=[ast.withitem(context_expr=ast.Name(id="sequential",
ctx=ast.Load()),
optional_vars=None)],
body=inlined.body),
node)
self._insertion_point.append(seq)
return ast.copy_location(ast.Name(retval_name_m, ast.Load()),
node)
else:
arg1 = ast.copy_location(ast.Str("rpc"), node)
arg2 = ast.copy_location(
value_to_ast(self.mappers.rpc.encode(func)), node)
node.args[0:0] = [arg1, arg2]
node.func = ast.copy_location(
ast.Name("syscall", ast.Load()), node)
return node
def code_visit_Return(self, node):
self.code_generic_visit(node)
if self.retval_name is None:
return node
else:
return ast.copy_location(
ast.Assign(targets=[ast.Name(self.retval_name, ast.Store())],
value=node.value),
node)
def code_visit_Expr(self, node):
if isinstance(node.value, ast.Str):
# Strip docstrings. This also removes strings appearing in the
# middle of the code, but they are nops.
return None
self.code_generic_visit(node)
if isinstance(node.value, ast.Name):
# Remove Expr nodes that contain only a name, likely due to
# function call inlining. Such nodes that were originally in the
# code are also removed, but this does not affect the semantics of
# the code as they are nops.
return None
else:
return node
def encode_exception(self, e):
exception_class = self.static_visit(e)
if not inspect.isclass(exception_class):
raise NotImplementedError("Exception type must be a class")
if issubclass(exception_class, core_language.RuntimeException):
exception_id = exception_class.eid
else:
exception_id = self.mappers.exception.encode(exception_class)
return ast.copy_location(
ast.Call(func=ast.Name("EncodedException", ast.Load()),
args=[value_to_ast(exception_id)],
keywords=[], starargs=None, kwargs=None),
e)
def code_visit_Raise(self, node):
if node.cause is not None:
raise NotImplementedError("Exception causes are not supported")
if node.exc is not None:
node.exc = self.encode_exception(node.exc)
return node
def code_visit_ExceptHandler(self, node):
if node.name is not None:
raise NotImplementedError("'as target' is not supported")
if node.type is not None:
if isinstance(node.type, ast.Tuple):
node.type.elts = [self.encode_exception(e)
for e in node.type.elts]
else:
node.type = self.encode_exception(node.type)
self.code_generic_visit(node)
return node
def code_visit_FunctionDef(self, node):
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[],
kw_defaults=[], kwarg=None, defaults=[])
node.decorator_list = []
self.code_generic_visit(node)
return node
def static_visit(self, node):
method = "static_visit_" + node.__class__.__name__
visitor = getattr(self, method)
return visitor(node)
def static_visit_Name(self, node):
try:
obj = self.local_namespace[node.id]
except KeyError:
try:
obj = self.global_namespace[node.id]
except KeyError:
raise NameError("name '{}' is not defined".format(node.id))
if isinstance(obj, MangledName):
raise NotImplementedError(
"Only a static object can be used here")
return obj
def static_visit_Attribute(self, node):
value = self.static_visit(node.value)
return getattr(value, node.attr)
class HostObjectMapper:
def __init__(self, first_encoding=0):
self._next_encoding = first_encoding
# id(object) -> (encoding, object)
@ -31,334 +371,37 @@ class _HostObjectMapper:
return {encoding: obj for i, (encoding, obj) in self._d.items()}
_UserVariable = namedtuple("_UserVariable", "name")
def inline(core, k_function, k_args, k_kwargs):
if k_kwargs:
raise NotImplementedError(
"Keyword arguments are not supported for kernels")
# OrderedDict prevents non-determinism in parameter init
attribute_namespace = OrderedDict()
in_use_names = {func.__name__ for func in embeddable_funcs}
mappers = types.SimpleNamespace(
rpc=HostObjectMapper(),
exception=HostObjectMapper(core_language.first_user_eid)
)
func_def = get_inline(
core=core,
attribute_namespace=attribute_namespace,
in_use_names=in_use_names,
retval_name=None,
mappers=mappers,
func=k_function,
args=k_args)
def _is_kernel_attr(value, attr):
return hasattr(value, "kernel_attr") and attr in value.kernel_attr.split()
class _ReferenceManager:
def __init__(self):
self.rpc_mapper = _HostObjectMapper()
self.exception_mapper = _HostObjectMapper(core_language.first_user_eid)
self.kernel_attr_init = []
# (id(obj), func_name, ref_name) or (id(obj), kernel_attr_name)
# -> _UserVariable(name) / complex object
self._to_inlined = dict()
# inlined_name -> use_count
self._use_count = dict()
# reserved names
for kg in core_language.kernel_globals:
self._use_count[kg] = 1
for name in ("bool", "int", "round", "int64", "round64", "float",
"Fraction", "array", "Quantity", "EncodedException",
"range"):
self._use_count[name] = 1
# Complex objects in the namespace of functions can be used in two ways:
# 1. Calling a method on them (which gets inlined or RPCd)
# 2. Getting or setting attributes (which are turned into local variables)
# They are needed to implement "self", which is the only supported use
# case.
def register_complex_object(self, obj, func_name, ref_name,
complex_object):
assert(not isinstance(complex_object, ast.AST))
self._to_inlined[(id(obj), func_name, ref_name)] = complex_object
def new_name(self, base_name):
if base_name[-1].isdigit():
base_name += "_"
if base_name in self._use_count:
r = base_name + str(self._use_count[base_name])
self._use_count[base_name] += 1
return r
else:
self._use_count[base_name] = 1
return base_name
def resolve_name(self, obj, func_name, ref_name, store):
key = (id(obj), func_name, ref_name)
try:
return self._to_inlined[key]
except KeyError:
if store:
ival = _UserVariable(self.new_name(ref_name))
self._to_inlined[key] = ival
return ival
else:
try:
return inspect.getmodule(obj).__dict__[ref_name]
except KeyError:
return getattr(builtins, ref_name)
def resolve_attr(self, value, attr):
if _is_kernel_attr(value, attr):
key = (id(value), attr)
try:
ival = self._to_inlined[key]
assert(isinstance(ival, _UserVariable))
except KeyError:
iname = self.new_name(attr)
ival = _UserVariable(iname)
self._to_inlined[key] = ival
a = value_to_ast(getattr(value, attr))
if a is None:
raise NotImplementedError(
"Cannot represent initial value"
" of kernel attribute")
self.kernel_attr_init.append(ast.Assign(
[ast.Name(iname, ast.Store())], a))
return ival
else:
return getattr(value, attr)
def resolve_constant(self, obj, func_name, node):
if isinstance(node, ast.Name):
c = self.resolve_name(obj, func_name, node.id, False)
if isinstance(c, _UserVariable):
raise ValueError("Not a constant")
return c
elif isinstance(node, ast.Attribute):
value = self.resolve_constant(obj, func_name, node.value)
if _is_kernel_attr(value, node.attr):
raise ValueError("Not a constant")
return getattr(value, node.attr)
else:
raise NotImplementedError
_embeddable_funcs = (
core_language.delay, core_language.at, core_language.now,
core_language.time_to_cycles, core_language.cycles_to_time,
core_language.syscall,
range, bool, int, float, round,
core_language.int64, core_language.round64, core_language.array,
Fraction, units.Quantity, core_language.EncodedException
)
def _is_embeddable(func):
for ef in _embeddable_funcs:
if func is ef:
return True
return False
def _is_inlinable(core, func):
if hasattr(func, "k_function_info"):
if func.k_function_info.core_name == "":
return True # portable function
if getattr(func.__self__, func.k_function_info.core_name) is core:
return True # kernel function for the same core device
return False
class _ReferenceReplacer(ast.NodeVisitor):
def __init__(self, core, rm, obj, func_name, retval_name):
self.core = core
self.rm = rm
self.obj = obj
self.func_name = func_name
self.retval_name = retval_name
self._insertion_point = None
# This is ast.NodeTransformer.generic_visit from CPython, modified
# to update self._insertion_point.
def generic_visit(self, node):
for field, old_value in ast.iter_fields(node):
old_value = getattr(node, field, None)
if isinstance(old_value, list):
prev_insertion_point = self._insertion_point
new_values = []
if field in ("body", "orelse", "finalbody"):
self._insertion_point = new_values
for value in old_value:
if isinstance(value, ast.AST):
value = self.visit(value)
if value is None:
continue
elif not isinstance(value, ast.AST):
new_values.extend(value)
continue
new_values.append(value)
old_value[:] = new_values
self._insertion_point = prev_insertion_point
elif isinstance(old_value, ast.AST):
new_node = self.visit(old_value)
if new_node is None:
delattr(node, field)
else:
setattr(node, field, new_node)
return node
def visit_Name(self, node):
store = isinstance(node.ctx, ast.Store)
ival = self.rm.resolve_name(self.obj, self.func_name, node.id, store)
if isinstance(ival, _UserVariable):
newnode = ast.Name(ival.name, node.ctx)
else:
if store:
raise NotImplementedError("Cannot assign to this object")
newnode = value_to_ast(ival)
return ast.copy_location(newnode, node)
def _resolve_attribute(self, node):
if isinstance(node, ast.Name):
ival = self.rm.resolve_name(self.obj, self.func_name, node.id, False)
if isinstance(ival, _UserVariable):
return ast.Name(ival.name, ast.Load())
else:
return ival
elif isinstance(node, ast.Attribute):
value = self._resolve_attribute(node.value)
if isinstance(value, ast.AST):
node.value = value
return node
else:
return self.rm.resolve_attr(value, node.attr)
else:
return self.visit(node)
def visit_Attribute(self, node):
ival = self._resolve_attribute(node)
if isinstance(ival, ast.AST):
newnode = deepcopy(ival)
elif isinstance(ival, _UserVariable):
newnode = ast.Name(ival.name, node.ctx)
else:
newnode = value_to_ast(ival)
return ast.copy_location(newnode, node)
def visit_Call(self, node):
func = self.rm.resolve_constant(self.obj, self.func_name, node.func)
new_args = [self.visit(arg) for arg in node.args]
if _is_embeddable(func):
new_func = ast.Name(func.__name__, ast.Load())
return ast.copy_location(
ast.Call(func=new_func, args=new_args,
keywords=[], starargs=None, kwargs=None),
node)
elif _is_inlinable(self.core, func):
retval_name = self.rm.new_name(
func.k_function_info.k_function.__name__ + "_return")
args = [func.__self__] + new_args
inlined, _, _ = inline(self.core, func.k_function_info.k_function,
args, dict(), self.rm, retval_name)
self._insertion_point.append(ast.With(
items=[ast.withitem(context_expr=ast.Name(id="sequential",
ctx=ast.Load()),
optional_vars=None)],
body=inlined.body))
return ast.copy_location(ast.Name(retval_name, ast.Load()), node)
else:
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_mapper.encode(func))]
args += new_args
return ast.copy_location(
ast.Call(func=ast.Name("syscall", ast.Load()),
args=args, keywords=[], starargs=None, kwargs=None),
node)
def visit_Return(self, node):
self.generic_visit(node)
return ast.copy_location(
ast.Assign(targets=[ast.Name(self.retval_name, ast.Store())],
value=node.value),
node)
def visit_Expr(self, node):
if isinstance(node.value, ast.Str):
# Strip docstrings. This also removes strings appearing in the
# middle of the code, but they are nops.
return None
self.generic_visit(node)
if isinstance(node.value, ast.Name):
# Remove Expr nodes that contain only a name, likely due to
# function call inlining. Such nodes that were originally in the
# code are also removed, but this does not affect the semantics of
# the code as they are nops.
return None
else:
return node
def visit_FunctionDef(self, node):
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[],
kw_defaults=[], kwarg=None, defaults=[])
node.decorator_list = []
self.generic_visit(node)
return node
def _encode_exception(self, e):
exception_class = self.rm.resolve_constant(self.obj, self.func_name, e)
if not inspect.isclass(exception_class):
raise NotImplementedError("Exception type must be a class")
if issubclass(exception_class, core_language.RuntimeException):
exception_id = exception_class.eid
else:
exception_id = self.rm.exception_mapper.encode(exception_class)
return ast.copy_location(
ast.Call(func=ast.Name("EncodedException", ast.Load()),
args=[value_to_ast(exception_id)],
keywords=[], starargs=None, kwargs=None),
e)
def visit_Raise(self, node):
if node.cause is not None:
raise NotImplementedError("Exception causes are not supported")
if node.exc is not None:
node.exc = self._encode_exception(node.exc)
return node
def visit_ExceptHandler(self, node):
if node.name is not None:
raise NotImplementedError("'as target' is not supported")
if node.type is not None:
if isinstance(node.type, ast.Tuple):
node.type.elts = [self._encode_exception(e) for e in node.type.elts]
else:
node.type = self._encode_exception(node.type)
self.generic_visit(node)
return node
def _initialize_function_params(func_def, k_args, k_kwargs, rm):
obj = k_args[0]
func_name = func_def.name
param_init = []
for arg_ast, arg_value in zip(func_def.args.args, k_args):
if isinstance(arg_value, ast.AST):
value = arg_value
else:
try:
value = value_to_ast(arg_value)
except NotASTRepresentable:
value = None
if value is None:
rm.register_complex_object(obj, func_name, arg_ast.arg, arg_value)
else:
uservar = rm.resolve_name(obj, func_name, arg_ast.arg, True)
target = ast.Name(uservar.name, ast.Store())
param_init.append(ast.Assign(targets=[target], value=value))
return param_init
def inline(core, k_function, k_args, k_kwargs, rm=None, retval_name=None):
init_kernel_attr = rm is None
if rm is None:
rm = _ReferenceManager()
func_def = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0]
param_init = _initialize_function_params(func_def, k_args, k_kwargs, rm)
obj = k_args[0]
func_name = func_def.name
rr = _ReferenceReplacer(core, rm, obj, func_name, retval_name)
rr.visit(func_def)
for (_, attr), attr_info in attribute_namespace.items():
value = getattr(attr_info.obj, attr)
value = ast.copy_location(value_to_ast(value), func_def)
target = ast.copy_location(ast.Name(attr_info.mangled_name,
ast.Store()),
func_def)
assign = ast.copy_location(ast.Assign([target], value),
func_def)
param_init.append(assign)
func_def.body[0:0] = param_init
if init_kernel_attr:
func_def.body[0:0] = rm.kernel_attr_init
return func_def, rm.rpc_mapper.get_map(), rm.exception_mapper.get_map()
return func_def, mappers.rpc.get_map(), mappers.exception.get_map()