forked from M-Labs/artiq
1
0
Fork 0

transforms/inline: support attributes on user variables/calls

This commit is contained in:
Sebastien Bourdeauducq 2014-10-08 18:01:15 +08:00
parent e22301ea05
commit 2920ac85d2
1 changed files with 97 additions and 53 deletions

View File

@ -33,62 +33,70 @@ class _HostObjectMapper:
_UserVariable = namedtuple("_UserVariable", "name") _UserVariable = namedtuple("_UserVariable", "name")
def _is_kernel_attr(value, attr):
return hasattr(value, "kernel_attr") and attr in value.kernel_attr.split()
class _ReferenceManager: class _ReferenceManager:
def __init__(self): def __init__(self):
# (id(obj), func_name, local_name) or (id(obj), kernel_attr_name)
# -> _UserVariable(name) / ast / constant_object
self.to_inlined = dict()
# inlined_name -> use_count
self.use_count = dict()
self.rpc_mapper = _HostObjectMapper() self.rpc_mapper = _HostObjectMapper()
self.exception_mapper = _HostObjectMapper(core_language.first_user_eid) self.exception_mapper = _HostObjectMapper(core_language.first_user_eid)
self.kernel_attr_init = [] self.kernel_attr_init = []
# (id(obj), func_name, ref_name) or (id(obj), kernel_attr_name)
# -> _UserVariable(name) / ast / constant_object
self._to_inlined = dict()
# inlined_name -> use_count
self._use_count = dict()
# reserved names # reserved names
for kg in core_language.kernel_globals: for kg in core_language.kernel_globals:
self.use_count[kg] = 1 self._use_count[kg] = 1
for name in ("int", "round", "int64", "round64", "float", "array", for name in ("int", "round", "int64", "round64", "float", "array",
"range", "Fraction", "Quantity", "EncodedException"): "range", "Fraction", "Quantity", "EncodedException"):
self.use_count[name] = 1 self._use_count[name] = 1
# node_or_value can be a AST node, used to inline function parameter values
# that can be simplified later through constant folding.
def register_replace(self, obj, func_name, ref_name, node_or_value):
self._to_inlined[(id(obj), func_name, ref_name)] = node_or_value
def new_name(self, base_name): def new_name(self, base_name):
if base_name[-1].isdigit(): if base_name[-1].isdigit():
base_name += "_" base_name += "_"
if base_name in self.use_count: if base_name in self._use_count:
r = base_name + str(self.use_count[base_name]) r = base_name + str(self._use_count[base_name])
self.use_count[base_name] += 1 self._use_count[base_name] += 1
return r return r
else: else:
self.use_count[base_name] = 1 self._use_count[base_name] = 1
return base_name return base_name
def get(self, obj, func_name, ref): def resolve_name(self, obj, func_name, ref_name, store):
if isinstance(ref, ast.Name): key = (id(obj), func_name, ref_name)
key = (id(obj), func_name, ref.id)
try: try:
return self.to_inlined[key] return self._to_inlined[key]
except KeyError: except KeyError:
if isinstance(ref.ctx, ast.Store): if store:
ival = _UserVariable(self.new_name(ref.id)) ival = _UserVariable(self.new_name(ref_name))
self.to_inlined[key] = ival self._to_inlined[key] = ival
return ival return ival
else: else:
try: try:
return inspect.getmodule(obj).__dict__[ref.id] return inspect.getmodule(obj).__dict__[ref_name]
except KeyError: except KeyError:
return getattr(builtins, ref.id) return getattr(builtins, ref_name)
elif isinstance(ref, ast.Attribute):
target = self.get(obj, func_name, ref.value) def resolve_attr(self, value, attr):
if hasattr(target, "kernel_attr") and ref.attr in target.kernel_attr.split(): if _is_kernel_attr(value, attr):
key = (id(target), ref.attr) key = (id(value), attr)
try: try:
ival = self.to_inlined[key] ival = self._to_inlined[key]
assert(isinstance(ival, _UserVariable)) assert(isinstance(ival, _UserVariable))
except KeyError: except KeyError:
iname = self.new_name(ref.attr) iname = self.new_name(attr)
ival = _UserVariable(iname) ival = _UserVariable(iname)
self.to_inlined[key] = ival self._to_inlined[key] = ival
a = value_to_ast(getattr(target, ref.attr)) a = value_to_ast(getattr(value, attr))
if a is None: if a is None:
raise NotImplementedError( raise NotImplementedError(
"Cannot represent initial value" "Cannot represent initial value"
@ -97,7 +105,19 @@ class _ReferenceManager:
[ast.Name(iname, ast.Store())], a)) [ast.Name(iname, ast.Store())], a))
return ival return ival
else: else:
return getattr(target, ref.attr) return getattr(value, attr)
def resolve_constant(self, obj, func_name, node):
if isinstance(node, ast.Name):
c = self.resolve_name(obj, func_name, node.id, False)
if isinstance(c, (_UserVariable, ast.AST)):
raise ValueError("Not a constant")
return c
elif isinstance(node, ast.Attribute):
value = self.resolve_constant(obj, func_name, node.value)
if _is_kernel_attr(value, node.attr):
raise ValueError("Not a constant")
return getattr(value, node.attr)
else: else:
raise NotImplementedError raise NotImplementedError
@ -156,9 +176,9 @@ class _ReferenceReplacer(ast.NodeVisitor):
setattr(node, field, new_node) setattr(node, field, new_node)
return node return node
def visit_ref(self, node): def visit_Name(self, node):
store = isinstance(node.ctx, ast.Store) store = isinstance(node.ctx, ast.Store)
ival = self.rm.get(self.obj, self.func_name, node) ival = self.rm.resolve_name(self.obj, self.func_name, node.id, store)
if isinstance(ival, _UserVariable): if isinstance(ival, _UserVariable):
newnode = ast.Name(ival.name, node.ctx) newnode = ast.Name(ival.name, node.ctx)
elif isinstance(ival, ast.AST): elif isinstance(ival, ast.AST):
@ -175,11 +195,34 @@ class _ReferenceReplacer(ast.NodeVisitor):
"Cannot represent inlined value") "Cannot represent inlined value")
return ast.copy_location(newnode, node) return ast.copy_location(newnode, node)
visit_Name = visit_ref def _resolve_attribute(self, node):
visit_Attribute = visit_ref if isinstance(node, ast.Name):
ival = self.rm.resolve_name(self.obj, self.func_name, node.id, False)
if isinstance(ival, _UserVariable):
return ast.copy_location(ast.Name(ival.name, ast.Load()), node)
else:
return ival
elif isinstance(node, ast.Attribute):
value = self._resolve_attribute(node.value)
if isinstance(value, ast.AST):
node.value = value
return node
else:
return self.rm.resolve_attr(value, node.attr)
else:
return self.visit(node)
def visit_Attribute(self, node):
ival = self._resolve_attribute(node)
if isinstance(ival, ast.AST):
return ival
elif isinstance(ival, _UserVariable):
return ast.copy_location(ast.Name(ival.name, ast.Load()), node)
else:
return value_to_ast(ival)
def visit_Call(self, node): def visit_Call(self, node):
func = self.rm.get(self.obj, self.func_name, node.func) func = self.rm.resolve_constant(self.obj, self.func_name, node.func)
new_args = [self.visit(arg) for arg in node.args] new_args = [self.visit(arg) for arg in node.args]
if _is_embeddable(func): if _is_embeddable(func):
@ -240,7 +283,7 @@ class _ReferenceReplacer(ast.NodeVisitor):
return node return node
def _encode_exception(self, e): def _encode_exception(self, e):
exception_class = self.rm.get(self.obj, self.func_name, e) exception_class = self.rm.resolve_constant(self.obj, self.func_name, e)
if not inspect.isclass(exception_class): if not inspect.isclass(exception_class):
raise NotImplementedError("Exception type must be a class") raise NotImplementedError("Exception type must be a class")
if issubclass(exception_class, core_language.RuntimeException): if issubclass(exception_class, core_language.RuntimeException):
@ -301,9 +344,10 @@ def _initialize_function_params(func_def, k_args, k_kwargs, rm):
for arg_ast, arg_value in zip(func_def.args.args, k_args): for arg_ast, arg_value in zip(func_def.args.args, k_args):
arg_name = arg_ast.arg arg_name = arg_ast.arg
if arg_name in rop: if arg_name in rop:
rm.to_inlined[(id(obj), func_name, arg_name)] = arg_value rm.register_replace(obj, func_name, arg_name, arg_value)
else: else:
target = rm.get(obj, func_name, ast.Name(arg_name, ast.Store())) uservar = rm.resolve_name(obj, func_name, arg_name, True)
target = ast.Name(uservar.name, ast.Store())
value = value_to_ast(arg_value) value = value_to_ast(arg_value)
param_init.append(ast.Assign(targets=[target], value=value)) param_init.append(ast.Assign(targets=[target], value=value))
return param_init return param_init