forked from M-Labs/artiq
transforms/inline: keyword argument and default value support
This commit is contained in:
parent
abae5c6728
commit
8b552134a0
|
@ -6,6 +6,7 @@ import builtins
|
|||
from fractions import Fraction
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from itertools import zip_longest, chain
|
||||
|
||||
from artiq.language import core as core_language
|
||||
from artiq.language import units
|
||||
|
@ -71,17 +72,79 @@ class GlobalNamespace:
|
|||
return getattr(builtins, item)
|
||||
|
||||
|
||||
class UndefinedArg:
|
||||
pass
|
||||
|
||||
|
||||
def get_function_args(func_args, func_tr, args, kwargs):
|
||||
# OrderedDict prevents non-determinism in argument init
|
||||
r = OrderedDict()
|
||||
|
||||
# Process positional arguments. Any missing positional argument values
|
||||
# are set to UndefinedArg.
|
||||
for arg, arg_value in zip_longest(func_args.args, args,
|
||||
fillvalue=UndefinedArg):
|
||||
if arg is UndefinedArg:
|
||||
raise TypeError("Got too many positional arguments")
|
||||
if arg.arg in r:
|
||||
raise SyntaxError("Duplicate argument '{}' in function definition"
|
||||
.format(arg.arg))
|
||||
r[arg.arg] = arg_value
|
||||
|
||||
# Process keyword arguments. Any missing keyword-only argument values
|
||||
# are set to UndefinedArg.
|
||||
valid_arg_names = {arg.arg for arg in
|
||||
chain(func_args.args, func_args.kwonlyargs)}
|
||||
for arg in func_args.kwonlyargs:
|
||||
if arg.arg in r:
|
||||
raise SyntaxError("Duplicate argument '{}' in function definition"
|
||||
.format(arg.arg))
|
||||
r[arg.arg] = UndefinedArg
|
||||
for arg_name, arg_value in kwargs.items():
|
||||
if arg_name not in valid_arg_names:
|
||||
raise TypeError("Got unexpected keyword argument '{}'"
|
||||
.format(arg_name))
|
||||
if r[arg_name] is not UndefinedArg:
|
||||
raise TypeError("Got multiple values for argument '{}'"
|
||||
.format(arg_name))
|
||||
r[arg_name] = arg_value
|
||||
|
||||
# Replace any UndefinedArg positional arguments with the default value,
|
||||
# when provided.
|
||||
for arg, default in zip(func_args.args[len(func_args.defaults):],
|
||||
func_args.defaults):
|
||||
if r[arg.arg] is UndefinedArg:
|
||||
r[arg.arg] = func_tr.code_visit(default)
|
||||
# Same with keyword-only arguments.
|
||||
for arg, default in zip(func_args.kwonlyargs, func_args.kw_defaults):
|
||||
if default is not None and r[arg.arg] is UndefinedArg:
|
||||
r[arg.arg] = func_tr.code_visit(default)
|
||||
|
||||
# Check that no argument was left undefined.
|
||||
missing_arguments = ["'"+arg+"'" for arg, value in r.items()
|
||||
if value is UndefinedArg]
|
||||
if missing_arguments:
|
||||
raise TypeError("Missing argument(s): " + " ".join(missing_arguments))
|
||||
|
||||
return r
|
||||
|
||||
|
||||
# args/kwargs can contain values or AST nodes
|
||||
def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers,
|
||||
func, args):
|
||||
func, args, kwargs):
|
||||
global_namespace = GlobalNamespace(func)
|
||||
func_tr = Function(core,
|
||||
global_namespace, attribute_namespace, in_use_names,
|
||||
retval_name, mappers)
|
||||
func_def = ast.parse(textwrap.dedent(inspect.getsource(func))).body[0]
|
||||
|
||||
param_init = []
|
||||
# initialize arguments
|
||||
for arg_ast, arg_value in zip(func_def.args.args, args):
|
||||
# Initialize arguments.
|
||||
# The local namespace is empty so code_visit will always resolve
|
||||
# using the global namespace.
|
||||
arg_init = []
|
||||
arg_name_map = []
|
||||
arg_dict = get_function_args(func_def.args, func_tr, args, kwargs)
|
||||
for arg_name, arg_value in arg_dict.items():
|
||||
if isinstance(arg_value, ast.AST):
|
||||
value = arg_value
|
||||
else:
|
||||
|
@ -91,20 +154,25 @@ def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers,
|
|||
value = None
|
||||
if value is None:
|
||||
# static object
|
||||
func_tr.local_namespace[arg_ast.arg] = arg_value
|
||||
func_tr.local_namespace[arg_name] = arg_value
|
||||
else:
|
||||
# set parameter value with "name = value"
|
||||
# assignment at beginning of function
|
||||
new_name = new_mangled_name(in_use_names, arg_ast.arg)
|
||||
func_tr.local_namespace[arg_ast.arg] = MangledName(new_name)
|
||||
new_name = new_mangled_name(in_use_names, arg_name)
|
||||
arg_name_map.append((arg_name, new_name))
|
||||
target = ast.copy_location(ast.Name(new_name, ast.Store()),
|
||||
func_def)
|
||||
assign = ast.copy_location(ast.Assign([target], value),
|
||||
func_def)
|
||||
param_init.append(assign)
|
||||
arg_init.append(assign)
|
||||
# Commit arguments to the local namespace at the end to handle cases
|
||||
# such as f(x, y=x) (for the default value of y, x must be resolved
|
||||
# using the global namespace).
|
||||
for arg_name, mangled_name in arg_name_map:
|
||||
func_tr.local_namespace[arg_name] = MangledName(mangled_name)
|
||||
|
||||
func_def = func_tr.code_visit(func_def)
|
||||
func_def.body[0:0] = param_init
|
||||
func_def.body[0:0] = arg_init
|
||||
return func_def
|
||||
|
||||
|
||||
|
@ -231,6 +299,8 @@ class Function:
|
|||
def code_visit_Call(self, node):
|
||||
func = self.static_visit(node.func)
|
||||
node.args = [self.code_visit(arg) for arg in node.args]
|
||||
for kw in node.keywords:
|
||||
kw.value = self.code_visit(kw.value)
|
||||
|
||||
if is_embeddable(func):
|
||||
node.func = ast.copy_location(
|
||||
|
@ -241,10 +311,12 @@ class Function:
|
|||
retval_name = func.k_function_info.k_function.__name__ + "_return"
|
||||
retval_name_m = new_mangled_name(self.in_use_names, retval_name)
|
||||
args = [func.__self__] + node.args
|
||||
kwargs = {kw.arg: kw.value for kw in node.keywords}
|
||||
inlined = get_inline(self.core,
|
||||
self.attribute_namespace, self.in_use_names,
|
||||
retval_name_m, self.mappers,
|
||||
func.k_function_info.k_function, args)
|
||||
func.k_function_info.k_function,
|
||||
args, kwargs)
|
||||
seq = ast.copy_location(
|
||||
ast.With(
|
||||
items=[ast.withitem(context_expr=ast.Name(id="sequential",
|
||||
|
@ -419,11 +491,7 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
|
|||
|
||||
|
||||
def inline(core, k_function, k_args, k_kwargs):
|
||||
if k_kwargs:
|
||||
raise NotImplementedError(
|
||||
"Keyword arguments are not supported for kernels")
|
||||
|
||||
# OrderedDict prevents non-determinism in parameter init
|
||||
# OrderedDict prevents non-determinism in attribute init
|
||||
attribute_namespace = OrderedDict()
|
||||
in_use_names = {func.__name__ for func in embeddable_funcs}
|
||||
mappers = types.SimpleNamespace(
|
||||
|
@ -437,10 +505,11 @@ def inline(core, k_function, k_args, k_kwargs):
|
|||
retval_name=None,
|
||||
mappers=mappers,
|
||||
func=k_function,
|
||||
args=k_args)
|
||||
args=k_args,
|
||||
kwargs=k_kwargs)
|
||||
|
||||
func_def.body[0:0] = get_attr_init(attribute_namespace, func_def)
|
||||
func_def.body += get_attr_writeback(attribute_namespace, mappers.rpc,
|
||||
func_def)
|
||||
func_def)
|
||||
|
||||
return func_def, mappers.rpc.get_map(), mappers.exception.get_map()
|
||||
|
|
Loading…
Reference in New Issue