forked from M-Labs/artiq
transforms/inline: rewrite
This commit is contained in:
parent
7806ca3373
commit
cf7848c698
|
@ -1,17 +1,357 @@
|
||||||
from collections import namedtuple
|
|
||||||
from fractions import Fraction
|
|
||||||
import inspect
|
import inspect
|
||||||
import textwrap
|
import textwrap
|
||||||
import ast
|
import ast
|
||||||
|
import types
|
||||||
import builtins
|
import builtins
|
||||||
from copy import deepcopy
|
from fractions import Fraction
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
from artiq.transforms.tools import eval_ast, value_to_ast, NotASTRepresentable
|
|
||||||
from artiq.language import core as core_language
|
from artiq.language import core as core_language
|
||||||
from artiq.language import units
|
from artiq.language import units
|
||||||
|
from artiq.transforms.tools import value_to_ast, NotASTRepresentable
|
||||||
|
|
||||||
|
|
||||||
class _HostObjectMapper:
|
def new_mangled_name(in_use_names, name):
|
||||||
|
mangled_name = name
|
||||||
|
i = 2
|
||||||
|
while mangled_name in in_use_names:
|
||||||
|
mangled_name = name + str(i)
|
||||||
|
i += 1
|
||||||
|
in_use_names.add(mangled_name)
|
||||||
|
return mangled_name
|
||||||
|
|
||||||
|
|
||||||
|
class MangledName:
|
||||||
|
def __init__(self, s):
|
||||||
|
self.s = s
|
||||||
|
|
||||||
|
|
||||||
|
class AttributeInfo:
|
||||||
|
def __init__(self, obj, mangled_name, read_write):
|
||||||
|
self.obj = obj
|
||||||
|
self.mangled_name = mangled_name
|
||||||
|
self.read_write = read_write
|
||||||
|
|
||||||
|
|
||||||
|
embeddable_funcs = (
|
||||||
|
core_language.delay, core_language.at, core_language.now,
|
||||||
|
core_language.time_to_cycles, core_language.cycles_to_time,
|
||||||
|
core_language.syscall,
|
||||||
|
range, bool, int, float, round,
|
||||||
|
core_language.int64, core_language.round64, core_language.array,
|
||||||
|
Fraction, units.Quantity, core_language.EncodedException
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_embeddable(func):
|
||||||
|
for ef in embeddable_funcs:
|
||||||
|
if func is ef:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_inlinable(core, func):
|
||||||
|
if hasattr(func, "k_function_info"):
|
||||||
|
if func.k_function_info.core_name == "":
|
||||||
|
return True # portable function
|
||||||
|
if getattr(func.__self__, func.k_function_info.core_name) is core:
|
||||||
|
return True # kernel function for the same core device
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalNamespace:
|
||||||
|
def __init__(self, func):
|
||||||
|
self.func_gd = inspect.getmodule(func).__dict__
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
try:
|
||||||
|
return self.func_gd[item]
|
||||||
|
except KeyError:
|
||||||
|
return getattr(builtins, item)
|
||||||
|
|
||||||
|
|
||||||
|
def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers,
|
||||||
|
func, args):
|
||||||
|
global_namespace = GlobalNamespace(func)
|
||||||
|
func_tr = Function(core,
|
||||||
|
global_namespace, attribute_namespace, in_use_names,
|
||||||
|
retval_name, mappers)
|
||||||
|
func_def = ast.parse(textwrap.dedent(inspect.getsource(func))).body[0]
|
||||||
|
|
||||||
|
param_init = []
|
||||||
|
# initialize arguments
|
||||||
|
for arg_ast, arg_value in zip(func_def.args.args, args):
|
||||||
|
if isinstance(arg_value, ast.AST):
|
||||||
|
value = arg_value
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
value = ast.copy_location(value_to_ast(arg_value), func_def)
|
||||||
|
except NotASTRepresentable:
|
||||||
|
value = None
|
||||||
|
if value is None:
|
||||||
|
# static object
|
||||||
|
func_tr.local_namespace[arg_ast.arg] = arg_value
|
||||||
|
else:
|
||||||
|
# set parameter value with "name = value"
|
||||||
|
# assignment at beginning of function
|
||||||
|
new_name = new_mangled_name(in_use_names, arg_ast.arg)
|
||||||
|
func_tr.local_namespace[arg_ast.arg] = MangledName(new_name)
|
||||||
|
target = ast.copy_location(ast.Name(new_name, ast.Store()),
|
||||||
|
func_def)
|
||||||
|
assign = ast.copy_location(ast.Assign([target], value),
|
||||||
|
func_def)
|
||||||
|
param_init.append(assign)
|
||||||
|
|
||||||
|
func_def = func_tr.code_visit(func_def)
|
||||||
|
func_def.body[0:0] = param_init
|
||||||
|
return func_def
|
||||||
|
|
||||||
|
|
||||||
|
class Function:
|
||||||
|
def __init__(self, core,
|
||||||
|
global_namespace, attribute_namespace, in_use_names,
|
||||||
|
retval_name, mappers):
|
||||||
|
# The core device on which this function is executing.
|
||||||
|
self.core = core
|
||||||
|
|
||||||
|
# Local and global namespaces:
|
||||||
|
# original name -> MangledName or static object
|
||||||
|
self.local_namespace = dict()
|
||||||
|
self.global_namespace = global_namespace
|
||||||
|
|
||||||
|
# (id(static object), attribute) -> AttributeInfo
|
||||||
|
self.attribute_namespace = attribute_namespace
|
||||||
|
|
||||||
|
# All names currently in use, in the namespace of the combined
|
||||||
|
# function.
|
||||||
|
# When creating a name for a new object, check that it is not
|
||||||
|
# already in this set.
|
||||||
|
self.in_use_names = in_use_names
|
||||||
|
|
||||||
|
# Name of the variable to store the return value to, or None
|
||||||
|
# to keep the return statement.
|
||||||
|
self.retval_name = retval_name
|
||||||
|
|
||||||
|
# Host object mappers, for RPC and exception numbers
|
||||||
|
self.mappers = mappers
|
||||||
|
|
||||||
|
self._insertion_point = None
|
||||||
|
|
||||||
|
# This is ast.NodeVisitor/NodeTransformer from CPython, modified
|
||||||
|
# to add code_ prefix.
|
||||||
|
def code_visit(self, node):
|
||||||
|
method = "code_visit_" + node.__class__.__name__
|
||||||
|
visitor = getattr(self, method, self.code_generic_visit)
|
||||||
|
return visitor(node)
|
||||||
|
|
||||||
|
# This is ast.NodeTransformer.generic_visit from CPython, modified
|
||||||
|
# to update self._insertion_point.
|
||||||
|
def code_generic_visit(self, node):
|
||||||
|
for field, old_value in ast.iter_fields(node):
|
||||||
|
old_value = getattr(node, field, None)
|
||||||
|
if isinstance(old_value, list):
|
||||||
|
prev_insertion_point = self._insertion_point
|
||||||
|
new_values = []
|
||||||
|
if field in ("body", "orelse", "finalbody"):
|
||||||
|
self._insertion_point = new_values
|
||||||
|
for value in old_value:
|
||||||
|
if isinstance(value, ast.AST):
|
||||||
|
value = self.code_visit(value)
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
elif not isinstance(value, ast.AST):
|
||||||
|
new_values.extend(value)
|
||||||
|
continue
|
||||||
|
new_values.append(value)
|
||||||
|
old_value[:] = new_values
|
||||||
|
self._insertion_point = prev_insertion_point
|
||||||
|
elif isinstance(old_value, ast.AST):
|
||||||
|
new_node = self.code_visit(old_value)
|
||||||
|
if new_node is None:
|
||||||
|
delattr(node, field)
|
||||||
|
else:
|
||||||
|
setattr(node, field, new_node)
|
||||||
|
return node
|
||||||
|
|
||||||
|
def code_visit_Name(self, node):
|
||||||
|
if isinstance(node.ctx, ast.Store):
|
||||||
|
if (node.id in self.local_namespace
|
||||||
|
and isinstance(self.local_namespace[node.id],
|
||||||
|
MangledName)):
|
||||||
|
new_name = self.local_namespace[node.id].s
|
||||||
|
else:
|
||||||
|
new_name = new_mangled_name(self.in_use_names, node.id)
|
||||||
|
self.local_namespace[node.id] = MangledName(new_name)
|
||||||
|
node.id = new_name
|
||||||
|
return node
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
obj = self.local_namespace[node.id]
|
||||||
|
except KeyError:
|
||||||
|
try:
|
||||||
|
obj = self.global_namespace[node.id]
|
||||||
|
except KeyError:
|
||||||
|
raise NameError("name '{}' is not defined".format(node.id))
|
||||||
|
if isinstance(obj, MangledName):
|
||||||
|
node.id = obj.s
|
||||||
|
return node
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
return value_to_ast(obj)
|
||||||
|
except NotASTRepresentable:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Static object cannot be used here")
|
||||||
|
|
||||||
|
def code_visit_Attribute(self, node):
|
||||||
|
# There are two cases of attributes:
|
||||||
|
# 1. static object attributes, e.g. self.foo
|
||||||
|
# 2. dynamic expression attributes, e.g.
|
||||||
|
# (Fraction(1, 2) + x).numerator
|
||||||
|
# Static object resolution has no side effects so we try it first.
|
||||||
|
try:
|
||||||
|
obj = self.static_visit(node.value)
|
||||||
|
except:
|
||||||
|
self.code_generic_visit(node)
|
||||||
|
return node
|
||||||
|
else:
|
||||||
|
key = (id(obj), node.attr)
|
||||||
|
try:
|
||||||
|
attr_info = self.attribute_namespace[key]
|
||||||
|
except KeyError:
|
||||||
|
new_name = new_mangled_name(self.in_use_names, node.attr)
|
||||||
|
attr_info = AttributeInfo(obj, new_name, False)
|
||||||
|
self.attribute_namespace[key] = attr_info
|
||||||
|
if isinstance(node.ctx, ast.Store):
|
||||||
|
attr_info.read_write = True
|
||||||
|
return ast.copy_location(
|
||||||
|
ast.Name(attr_info.mangled_name, node.ctx),
|
||||||
|
node)
|
||||||
|
|
||||||
|
def code_visit_Call(self, node):
|
||||||
|
func = self.static_visit(node.func)
|
||||||
|
node.args = [self.code_visit(arg) for arg in node.args]
|
||||||
|
|
||||||
|
if is_embeddable(func):
|
||||||
|
node.func = ast.copy_location(
|
||||||
|
ast.Name(func.__name__, ast.Load()),
|
||||||
|
node)
|
||||||
|
return node
|
||||||
|
elif is_inlinable(self.core, func):
|
||||||
|
retval_name = func.k_function_info.k_function.__name__ + "_return"
|
||||||
|
retval_name_m = new_mangled_name(self.in_use_names, retval_name)
|
||||||
|
args = [func.__self__] + node.args
|
||||||
|
inlined = get_inline(self.core,
|
||||||
|
self.attribute_namespace, self.in_use_names,
|
||||||
|
retval_name_m, self.mappers,
|
||||||
|
func.k_function_info.k_function, args)
|
||||||
|
seq = ast.copy_location(
|
||||||
|
ast.With(
|
||||||
|
items=[ast.withitem(context_expr=ast.Name(id="sequential",
|
||||||
|
ctx=ast.Load()),
|
||||||
|
optional_vars=None)],
|
||||||
|
body=inlined.body),
|
||||||
|
node)
|
||||||
|
self._insertion_point.append(seq)
|
||||||
|
return ast.copy_location(ast.Name(retval_name_m, ast.Load()),
|
||||||
|
node)
|
||||||
|
else:
|
||||||
|
arg1 = ast.copy_location(ast.Str("rpc"), node)
|
||||||
|
arg2 = ast.copy_location(
|
||||||
|
value_to_ast(self.mappers.rpc.encode(func)), node)
|
||||||
|
node.args[0:0] = [arg1, arg2]
|
||||||
|
node.func = ast.copy_location(
|
||||||
|
ast.Name("syscall", ast.Load()), node)
|
||||||
|
return node
|
||||||
|
|
||||||
|
def code_visit_Return(self, node):
|
||||||
|
self.code_generic_visit(node)
|
||||||
|
if self.retval_name is None:
|
||||||
|
return node
|
||||||
|
else:
|
||||||
|
return ast.copy_location(
|
||||||
|
ast.Assign(targets=[ast.Name(self.retval_name, ast.Store())],
|
||||||
|
value=node.value),
|
||||||
|
node)
|
||||||
|
|
||||||
|
def code_visit_Expr(self, node):
|
||||||
|
if isinstance(node.value, ast.Str):
|
||||||
|
# Strip docstrings. This also removes strings appearing in the
|
||||||
|
# middle of the code, but they are nops.
|
||||||
|
return None
|
||||||
|
self.code_generic_visit(node)
|
||||||
|
if isinstance(node.value, ast.Name):
|
||||||
|
# Remove Expr nodes that contain only a name, likely due to
|
||||||
|
# function call inlining. Such nodes that were originally in the
|
||||||
|
# code are also removed, but this does not affect the semantics of
|
||||||
|
# the code as they are nops.
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return node
|
||||||
|
|
||||||
|
def encode_exception(self, e):
|
||||||
|
exception_class = self.static_visit(e)
|
||||||
|
if not inspect.isclass(exception_class):
|
||||||
|
raise NotImplementedError("Exception type must be a class")
|
||||||
|
if issubclass(exception_class, core_language.RuntimeException):
|
||||||
|
exception_id = exception_class.eid
|
||||||
|
else:
|
||||||
|
exception_id = self.mappers.exception.encode(exception_class)
|
||||||
|
return ast.copy_location(
|
||||||
|
ast.Call(func=ast.Name("EncodedException", ast.Load()),
|
||||||
|
args=[value_to_ast(exception_id)],
|
||||||
|
keywords=[], starargs=None, kwargs=None),
|
||||||
|
e)
|
||||||
|
|
||||||
|
def code_visit_Raise(self, node):
|
||||||
|
if node.cause is not None:
|
||||||
|
raise NotImplementedError("Exception causes are not supported")
|
||||||
|
if node.exc is not None:
|
||||||
|
node.exc = self.encode_exception(node.exc)
|
||||||
|
return node
|
||||||
|
|
||||||
|
def code_visit_ExceptHandler(self, node):
|
||||||
|
if node.name is not None:
|
||||||
|
raise NotImplementedError("'as target' is not supported")
|
||||||
|
if node.type is not None:
|
||||||
|
if isinstance(node.type, ast.Tuple):
|
||||||
|
node.type.elts = [self.encode_exception(e)
|
||||||
|
for e in node.type.elts]
|
||||||
|
else:
|
||||||
|
node.type = self.encode_exception(node.type)
|
||||||
|
self.code_generic_visit(node)
|
||||||
|
return node
|
||||||
|
|
||||||
|
def code_visit_FunctionDef(self, node):
|
||||||
|
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[],
|
||||||
|
kw_defaults=[], kwarg=None, defaults=[])
|
||||||
|
node.decorator_list = []
|
||||||
|
self.code_generic_visit(node)
|
||||||
|
return node
|
||||||
|
|
||||||
|
def static_visit(self, node):
|
||||||
|
method = "static_visit_" + node.__class__.__name__
|
||||||
|
visitor = getattr(self, method)
|
||||||
|
return visitor(node)
|
||||||
|
|
||||||
|
def static_visit_Name(self, node):
|
||||||
|
try:
|
||||||
|
obj = self.local_namespace[node.id]
|
||||||
|
except KeyError:
|
||||||
|
try:
|
||||||
|
obj = self.global_namespace[node.id]
|
||||||
|
except KeyError:
|
||||||
|
raise NameError("name '{}' is not defined".format(node.id))
|
||||||
|
if isinstance(obj, MangledName):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Only a static object can be used here")
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def static_visit_Attribute(self, node):
|
||||||
|
value = self.static_visit(node.value)
|
||||||
|
return getattr(value, node.attr)
|
||||||
|
|
||||||
|
|
||||||
|
class HostObjectMapper:
|
||||||
def __init__(self, first_encoding=0):
|
def __init__(self, first_encoding=0):
|
||||||
self._next_encoding = first_encoding
|
self._next_encoding = first_encoding
|
||||||
# id(object) -> (encoding, object)
|
# id(object) -> (encoding, object)
|
||||||
|
@ -31,334 +371,37 @@ class _HostObjectMapper:
|
||||||
return {encoding: obj for i, (encoding, obj) in self._d.items()}
|
return {encoding: obj for i, (encoding, obj) in self._d.items()}
|
||||||
|
|
||||||
|
|
||||||
_UserVariable = namedtuple("_UserVariable", "name")
|
def inline(core, k_function, k_args, k_kwargs):
|
||||||
|
if k_kwargs:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Keyword arguments are not supported for kernels")
|
||||||
|
|
||||||
|
# OrderedDict prevents non-determinism in parameter init
|
||||||
|
attribute_namespace = OrderedDict()
|
||||||
|
in_use_names = {func.__name__ for func in embeddable_funcs}
|
||||||
|
mappers = types.SimpleNamespace(
|
||||||
|
rpc=HostObjectMapper(),
|
||||||
|
exception=HostObjectMapper(core_language.first_user_eid)
|
||||||
|
)
|
||||||
|
func_def = get_inline(
|
||||||
|
core=core,
|
||||||
|
attribute_namespace=attribute_namespace,
|
||||||
|
in_use_names=in_use_names,
|
||||||
|
retval_name=None,
|
||||||
|
mappers=mappers,
|
||||||
|
func=k_function,
|
||||||
|
args=k_args)
|
||||||
|
|
||||||
def _is_kernel_attr(value, attr):
|
|
||||||
return hasattr(value, "kernel_attr") and attr in value.kernel_attr.split()
|
|
||||||
|
|
||||||
|
|
||||||
class _ReferenceManager:
|
|
||||||
def __init__(self):
|
|
||||||
self.rpc_mapper = _HostObjectMapper()
|
|
||||||
self.exception_mapper = _HostObjectMapper(core_language.first_user_eid)
|
|
||||||
self.kernel_attr_init = []
|
|
||||||
|
|
||||||
# (id(obj), func_name, ref_name) or (id(obj), kernel_attr_name)
|
|
||||||
# -> _UserVariable(name) / complex object
|
|
||||||
self._to_inlined = dict()
|
|
||||||
# inlined_name -> use_count
|
|
||||||
self._use_count = dict()
|
|
||||||
# reserved names
|
|
||||||
for kg in core_language.kernel_globals:
|
|
||||||
self._use_count[kg] = 1
|
|
||||||
for name in ("bool", "int", "round", "int64", "round64", "float",
|
|
||||||
"Fraction", "array", "Quantity", "EncodedException",
|
|
||||||
"range"):
|
|
||||||
self._use_count[name] = 1
|
|
||||||
|
|
||||||
# Complex objects in the namespace of functions can be used in two ways:
|
|
||||||
# 1. Calling a method on them (which gets inlined or RPCd)
|
|
||||||
# 2. Getting or setting attributes (which are turned into local variables)
|
|
||||||
# They are needed to implement "self", which is the only supported use
|
|
||||||
# case.
|
|
||||||
def register_complex_object(self, obj, func_name, ref_name,
|
|
||||||
complex_object):
|
|
||||||
assert(not isinstance(complex_object, ast.AST))
|
|
||||||
self._to_inlined[(id(obj), func_name, ref_name)] = complex_object
|
|
||||||
|
|
||||||
def new_name(self, base_name):
|
|
||||||
if base_name[-1].isdigit():
|
|
||||||
base_name += "_"
|
|
||||||
if base_name in self._use_count:
|
|
||||||
r = base_name + str(self._use_count[base_name])
|
|
||||||
self._use_count[base_name] += 1
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
self._use_count[base_name] = 1
|
|
||||||
return base_name
|
|
||||||
|
|
||||||
def resolve_name(self, obj, func_name, ref_name, store):
|
|
||||||
key = (id(obj), func_name, ref_name)
|
|
||||||
try:
|
|
||||||
return self._to_inlined[key]
|
|
||||||
except KeyError:
|
|
||||||
if store:
|
|
||||||
ival = _UserVariable(self.new_name(ref_name))
|
|
||||||
self._to_inlined[key] = ival
|
|
||||||
return ival
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
return inspect.getmodule(obj).__dict__[ref_name]
|
|
||||||
except KeyError:
|
|
||||||
return getattr(builtins, ref_name)
|
|
||||||
|
|
||||||
def resolve_attr(self, value, attr):
|
|
||||||
if _is_kernel_attr(value, attr):
|
|
||||||
key = (id(value), attr)
|
|
||||||
try:
|
|
||||||
ival = self._to_inlined[key]
|
|
||||||
assert(isinstance(ival, _UserVariable))
|
|
||||||
except KeyError:
|
|
||||||
iname = self.new_name(attr)
|
|
||||||
ival = _UserVariable(iname)
|
|
||||||
self._to_inlined[key] = ival
|
|
||||||
a = value_to_ast(getattr(value, attr))
|
|
||||||
if a is None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Cannot represent initial value"
|
|
||||||
" of kernel attribute")
|
|
||||||
self.kernel_attr_init.append(ast.Assign(
|
|
||||||
[ast.Name(iname, ast.Store())], a))
|
|
||||||
return ival
|
|
||||||
else:
|
|
||||||
return getattr(value, attr)
|
|
||||||
|
|
||||||
def resolve_constant(self, obj, func_name, node):
|
|
||||||
if isinstance(node, ast.Name):
|
|
||||||
c = self.resolve_name(obj, func_name, node.id, False)
|
|
||||||
if isinstance(c, _UserVariable):
|
|
||||||
raise ValueError("Not a constant")
|
|
||||||
return c
|
|
||||||
elif isinstance(node, ast.Attribute):
|
|
||||||
value = self.resolve_constant(obj, func_name, node.value)
|
|
||||||
if _is_kernel_attr(value, node.attr):
|
|
||||||
raise ValueError("Not a constant")
|
|
||||||
return getattr(value, node.attr)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
_embeddable_funcs = (
|
|
||||||
core_language.delay, core_language.at, core_language.now,
|
|
||||||
core_language.time_to_cycles, core_language.cycles_to_time,
|
|
||||||
core_language.syscall,
|
|
||||||
range, bool, int, float, round,
|
|
||||||
core_language.int64, core_language.round64, core_language.array,
|
|
||||||
Fraction, units.Quantity, core_language.EncodedException
|
|
||||||
)
|
|
||||||
|
|
||||||
def _is_embeddable(func):
|
|
||||||
for ef in _embeddable_funcs:
|
|
||||||
if func is ef:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_inlinable(core, func):
|
|
||||||
if hasattr(func, "k_function_info"):
|
|
||||||
if func.k_function_info.core_name == "":
|
|
||||||
return True # portable function
|
|
||||||
if getattr(func.__self__, func.k_function_info.core_name) is core:
|
|
||||||
return True # kernel function for the same core device
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class _ReferenceReplacer(ast.NodeVisitor):
|
|
||||||
def __init__(self, core, rm, obj, func_name, retval_name):
|
|
||||||
self.core = core
|
|
||||||
self.rm = rm
|
|
||||||
self.obj = obj
|
|
||||||
self.func_name = func_name
|
|
||||||
self.retval_name = retval_name
|
|
||||||
self._insertion_point = None
|
|
||||||
|
|
||||||
# This is ast.NodeTransformer.generic_visit from CPython, modified
|
|
||||||
# to update self._insertion_point.
|
|
||||||
def generic_visit(self, node):
|
|
||||||
for field, old_value in ast.iter_fields(node):
|
|
||||||
old_value = getattr(node, field, None)
|
|
||||||
if isinstance(old_value, list):
|
|
||||||
prev_insertion_point = self._insertion_point
|
|
||||||
new_values = []
|
|
||||||
if field in ("body", "orelse", "finalbody"):
|
|
||||||
self._insertion_point = new_values
|
|
||||||
for value in old_value:
|
|
||||||
if isinstance(value, ast.AST):
|
|
||||||
value = self.visit(value)
|
|
||||||
if value is None:
|
|
||||||
continue
|
|
||||||
elif not isinstance(value, ast.AST):
|
|
||||||
new_values.extend(value)
|
|
||||||
continue
|
|
||||||
new_values.append(value)
|
|
||||||
old_value[:] = new_values
|
|
||||||
self._insertion_point = prev_insertion_point
|
|
||||||
elif isinstance(old_value, ast.AST):
|
|
||||||
new_node = self.visit(old_value)
|
|
||||||
if new_node is None:
|
|
||||||
delattr(node, field)
|
|
||||||
else:
|
|
||||||
setattr(node, field, new_node)
|
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_Name(self, node):
|
|
||||||
store = isinstance(node.ctx, ast.Store)
|
|
||||||
ival = self.rm.resolve_name(self.obj, self.func_name, node.id, store)
|
|
||||||
if isinstance(ival, _UserVariable):
|
|
||||||
newnode = ast.Name(ival.name, node.ctx)
|
|
||||||
else:
|
|
||||||
if store:
|
|
||||||
raise NotImplementedError("Cannot assign to this object")
|
|
||||||
newnode = value_to_ast(ival)
|
|
||||||
return ast.copy_location(newnode, node)
|
|
||||||
|
|
||||||
def _resolve_attribute(self, node):
|
|
||||||
if isinstance(node, ast.Name):
|
|
||||||
ival = self.rm.resolve_name(self.obj, self.func_name, node.id, False)
|
|
||||||
if isinstance(ival, _UserVariable):
|
|
||||||
return ast.Name(ival.name, ast.Load())
|
|
||||||
else:
|
|
||||||
return ival
|
|
||||||
elif isinstance(node, ast.Attribute):
|
|
||||||
value = self._resolve_attribute(node.value)
|
|
||||||
if isinstance(value, ast.AST):
|
|
||||||
node.value = value
|
|
||||||
return node
|
|
||||||
else:
|
|
||||||
return self.rm.resolve_attr(value, node.attr)
|
|
||||||
else:
|
|
||||||
return self.visit(node)
|
|
||||||
|
|
||||||
def visit_Attribute(self, node):
|
|
||||||
ival = self._resolve_attribute(node)
|
|
||||||
if isinstance(ival, ast.AST):
|
|
||||||
newnode = deepcopy(ival)
|
|
||||||
elif isinstance(ival, _UserVariable):
|
|
||||||
newnode = ast.Name(ival.name, node.ctx)
|
|
||||||
else:
|
|
||||||
newnode = value_to_ast(ival)
|
|
||||||
return ast.copy_location(newnode, node)
|
|
||||||
|
|
||||||
def visit_Call(self, node):
|
|
||||||
func = self.rm.resolve_constant(self.obj, self.func_name, node.func)
|
|
||||||
new_args = [self.visit(arg) for arg in node.args]
|
|
||||||
|
|
||||||
if _is_embeddable(func):
|
|
||||||
new_func = ast.Name(func.__name__, ast.Load())
|
|
||||||
return ast.copy_location(
|
|
||||||
ast.Call(func=new_func, args=new_args,
|
|
||||||
keywords=[], starargs=None, kwargs=None),
|
|
||||||
node)
|
|
||||||
elif _is_inlinable(self.core, func):
|
|
||||||
retval_name = self.rm.new_name(
|
|
||||||
func.k_function_info.k_function.__name__ + "_return")
|
|
||||||
args = [func.__self__] + new_args
|
|
||||||
inlined, _, _ = inline(self.core, func.k_function_info.k_function,
|
|
||||||
args, dict(), self.rm, retval_name)
|
|
||||||
self._insertion_point.append(ast.With(
|
|
||||||
items=[ast.withitem(context_expr=ast.Name(id="sequential",
|
|
||||||
ctx=ast.Load()),
|
|
||||||
optional_vars=None)],
|
|
||||||
body=inlined.body))
|
|
||||||
return ast.copy_location(ast.Name(retval_name, ast.Load()), node)
|
|
||||||
else:
|
|
||||||
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_mapper.encode(func))]
|
|
||||||
args += new_args
|
|
||||||
return ast.copy_location(
|
|
||||||
ast.Call(func=ast.Name("syscall", ast.Load()),
|
|
||||||
args=args, keywords=[], starargs=None, kwargs=None),
|
|
||||||
node)
|
|
||||||
|
|
||||||
def visit_Return(self, node):
|
|
||||||
self.generic_visit(node)
|
|
||||||
return ast.copy_location(
|
|
||||||
ast.Assign(targets=[ast.Name(self.retval_name, ast.Store())],
|
|
||||||
value=node.value),
|
|
||||||
node)
|
|
||||||
|
|
||||||
def visit_Expr(self, node):
|
|
||||||
if isinstance(node.value, ast.Str):
|
|
||||||
# Strip docstrings. This also removes strings appearing in the
|
|
||||||
# middle of the code, but they are nops.
|
|
||||||
return None
|
|
||||||
self.generic_visit(node)
|
|
||||||
if isinstance(node.value, ast.Name):
|
|
||||||
# Remove Expr nodes that contain only a name, likely due to
|
|
||||||
# function call inlining. Such nodes that were originally in the
|
|
||||||
# code are also removed, but this does not affect the semantics of
|
|
||||||
# the code as they are nops.
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_FunctionDef(self, node):
|
|
||||||
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[],
|
|
||||||
kw_defaults=[], kwarg=None, defaults=[])
|
|
||||||
node.decorator_list = []
|
|
||||||
self.generic_visit(node)
|
|
||||||
return node
|
|
||||||
|
|
||||||
def _encode_exception(self, e):
|
|
||||||
exception_class = self.rm.resolve_constant(self.obj, self.func_name, e)
|
|
||||||
if not inspect.isclass(exception_class):
|
|
||||||
raise NotImplementedError("Exception type must be a class")
|
|
||||||
if issubclass(exception_class, core_language.RuntimeException):
|
|
||||||
exception_id = exception_class.eid
|
|
||||||
else:
|
|
||||||
exception_id = self.rm.exception_mapper.encode(exception_class)
|
|
||||||
return ast.copy_location(
|
|
||||||
ast.Call(func=ast.Name("EncodedException", ast.Load()),
|
|
||||||
args=[value_to_ast(exception_id)],
|
|
||||||
keywords=[], starargs=None, kwargs=None),
|
|
||||||
e)
|
|
||||||
|
|
||||||
def visit_Raise(self, node):
|
|
||||||
if node.cause is not None:
|
|
||||||
raise NotImplementedError("Exception causes are not supported")
|
|
||||||
if node.exc is not None:
|
|
||||||
node.exc = self._encode_exception(node.exc)
|
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_ExceptHandler(self, node):
|
|
||||||
if node.name is not None:
|
|
||||||
raise NotImplementedError("'as target' is not supported")
|
|
||||||
if node.type is not None:
|
|
||||||
if isinstance(node.type, ast.Tuple):
|
|
||||||
node.type.elts = [self._encode_exception(e) for e in node.type.elts]
|
|
||||||
else:
|
|
||||||
node.type = self._encode_exception(node.type)
|
|
||||||
self.generic_visit(node)
|
|
||||||
return node
|
|
||||||
|
|
||||||
|
|
||||||
def _initialize_function_params(func_def, k_args, k_kwargs, rm):
|
|
||||||
obj = k_args[0]
|
|
||||||
func_name = func_def.name
|
|
||||||
param_init = []
|
param_init = []
|
||||||
for arg_ast, arg_value in zip(func_def.args.args, k_args):
|
for (_, attr), attr_info in attribute_namespace.items():
|
||||||
if isinstance(arg_value, ast.AST):
|
value = getattr(attr_info.obj, attr)
|
||||||
value = arg_value
|
value = ast.copy_location(value_to_ast(value), func_def)
|
||||||
else:
|
target = ast.copy_location(ast.Name(attr_info.mangled_name,
|
||||||
try:
|
ast.Store()),
|
||||||
value = value_to_ast(arg_value)
|
func_def)
|
||||||
except NotASTRepresentable:
|
assign = ast.copy_location(ast.Assign([target], value),
|
||||||
value = None
|
func_def)
|
||||||
if value is None:
|
param_init.append(assign)
|
||||||
rm.register_complex_object(obj, func_name, arg_ast.arg, arg_value)
|
|
||||||
else:
|
|
||||||
uservar = rm.resolve_name(obj, func_name, arg_ast.arg, True)
|
|
||||||
target = ast.Name(uservar.name, ast.Store())
|
|
||||||
param_init.append(ast.Assign(targets=[target], value=value))
|
|
||||||
return param_init
|
|
||||||
|
|
||||||
|
|
||||||
def inline(core, k_function, k_args, k_kwargs, rm=None, retval_name=None):
|
|
||||||
init_kernel_attr = rm is None
|
|
||||||
if rm is None:
|
|
||||||
rm = _ReferenceManager()
|
|
||||||
|
|
||||||
func_def = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0]
|
|
||||||
|
|
||||||
param_init = _initialize_function_params(func_def, k_args, k_kwargs, rm)
|
|
||||||
|
|
||||||
obj = k_args[0]
|
|
||||||
func_name = func_def.name
|
|
||||||
rr = _ReferenceReplacer(core, rm, obj, func_name, retval_name)
|
|
||||||
rr.visit(func_def)
|
|
||||||
|
|
||||||
func_def.body[0:0] = param_init
|
func_def.body[0:0] = param_init
|
||||||
if init_kernel_attr:
|
|
||||||
func_def.body[0:0] = rm.kernel_attr_init
|
|
||||||
|
|
||||||
return func_def, rm.rpc_mapper.get_map(), rm.exception_mapper.get_map()
|
return func_def, mappers.rpc.get_map(), mappers.exception.get_map()
|
||||||
|
|
Loading…
Reference in New Issue