From 8b552134a0299c9bc465c62eee125f7fde9e35b8 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Tue, 18 Nov 2014 13:40:15 -0800 Subject: [PATCH] transforms/inline: keyword argument and default value support --- artiq/transforms/inline.py | 103 +++++++++++++++++++++++++++++++------ 1 file changed, 86 insertions(+), 17 deletions(-) diff --git a/artiq/transforms/inline.py b/artiq/transforms/inline.py index 105e895e0..dd67b2239 100644 --- a/artiq/transforms/inline.py +++ b/artiq/transforms/inline.py @@ -6,6 +6,7 @@ import builtins from fractions import Fraction from collections import OrderedDict from functools import partial +from itertools import zip_longest, chain from artiq.language import core as core_language from artiq.language import units @@ -71,17 +72,79 @@ class GlobalNamespace: return getattr(builtins, item) +class UndefinedArg: + pass + + +def get_function_args(func_args, func_tr, args, kwargs): + # OrderedDict prevents non-determinism in argument init + r = OrderedDict() + + # Process positional arguments. Any missing positional argument values + # are set to UndefinedArg. + for arg, arg_value in zip_longest(func_args.args, args, + fillvalue=UndefinedArg): + if arg is UndefinedArg: + raise TypeError("Got too many positional arguments") + if arg.arg in r: + raise SyntaxError("Duplicate argument '{}' in function definition" + .format(arg.arg)) + r[arg.arg] = arg_value + + # Process keyword arguments. Any missing keyword-only argument values + # are set to UndefinedArg. + valid_arg_names = {arg.arg for arg in + chain(func_args.args, func_args.kwonlyargs)} + for arg in func_args.kwonlyargs: + if arg.arg in r: + raise SyntaxError("Duplicate argument '{}' in function definition" + .format(arg.arg)) + r[arg.arg] = UndefinedArg + for arg_name, arg_value in kwargs.items(): + if arg_name not in valid_arg_names: + raise TypeError("Got unexpected keyword argument '{}'" + .format(arg_name)) + if r[arg_name] is not UndefinedArg: + raise TypeError("Got multiple values for argument '{}'" + .format(arg_name)) + r[arg_name] = arg_value + + # Replace any UndefinedArg positional arguments with the default value, + # when provided. + for arg, default in zip(func_args.args[len(func_args.defaults):], + func_args.defaults): + if r[arg.arg] is UndefinedArg: + r[arg.arg] = func_tr.code_visit(default) + # Same with keyword-only arguments. + for arg, default in zip(func_args.kwonlyargs, func_args.kw_defaults): + if default is not None and r[arg.arg] is UndefinedArg: + r[arg.arg] = func_tr.code_visit(default) + + # Check that no argument was left undefined. + missing_arguments = ["'"+arg+"'" for arg, value in r.items() + if value is UndefinedArg] + if missing_arguments: + raise TypeError("Missing argument(s): " + " ".join(missing_arguments)) + + return r + + +# args/kwargs can contain values or AST nodes def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers, - func, args): + func, args, kwargs): global_namespace = GlobalNamespace(func) func_tr = Function(core, global_namespace, attribute_namespace, in_use_names, retval_name, mappers) func_def = ast.parse(textwrap.dedent(inspect.getsource(func))).body[0] - param_init = [] - # initialize arguments - for arg_ast, arg_value in zip(func_def.args.args, args): + # Initialize arguments. + # The local namespace is empty so code_visit will always resolve + # using the global namespace. + arg_init = [] + arg_name_map = [] + arg_dict = get_function_args(func_def.args, func_tr, args, kwargs) + for arg_name, arg_value in arg_dict.items(): if isinstance(arg_value, ast.AST): value = arg_value else: @@ -91,20 +154,25 @@ def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers, value = None if value is None: # static object - func_tr.local_namespace[arg_ast.arg] = arg_value + func_tr.local_namespace[arg_name] = arg_value else: # set parameter value with "name = value" # assignment at beginning of function - new_name = new_mangled_name(in_use_names, arg_ast.arg) - func_tr.local_namespace[arg_ast.arg] = MangledName(new_name) + new_name = new_mangled_name(in_use_names, arg_name) + arg_name_map.append((arg_name, new_name)) target = ast.copy_location(ast.Name(new_name, ast.Store()), func_def) assign = ast.copy_location(ast.Assign([target], value), func_def) - param_init.append(assign) + arg_init.append(assign) + # Commit arguments to the local namespace at the end to handle cases + # such as f(x, y=x) (for the default value of y, x must be resolved + # using the global namespace). + for arg_name, mangled_name in arg_name_map: + func_tr.local_namespace[arg_name] = MangledName(mangled_name) func_def = func_tr.code_visit(func_def) - func_def.body[0:0] = param_init + func_def.body[0:0] = arg_init return func_def @@ -231,6 +299,8 @@ class Function: def code_visit_Call(self, node): func = self.static_visit(node.func) node.args = [self.code_visit(arg) for arg in node.args] + for kw in node.keywords: + kw.value = self.code_visit(kw.value) if is_embeddable(func): node.func = ast.copy_location( @@ -241,10 +311,12 @@ class Function: retval_name = func.k_function_info.k_function.__name__ + "_return" retval_name_m = new_mangled_name(self.in_use_names, retval_name) args = [func.__self__] + node.args + kwargs = {kw.arg: kw.value for kw in node.keywords} inlined = get_inline(self.core, self.attribute_namespace, self.in_use_names, retval_name_m, self.mappers, - func.k_function_info.k_function, args) + func.k_function_info.k_function, + args, kwargs) seq = ast.copy_location( ast.With( items=[ast.withitem(context_expr=ast.Name(id="sequential", @@ -419,11 +491,7 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node): def inline(core, k_function, k_args, k_kwargs): - if k_kwargs: - raise NotImplementedError( - "Keyword arguments are not supported for kernels") - - # OrderedDict prevents non-determinism in parameter init + # OrderedDict prevents non-determinism in attribute init attribute_namespace = OrderedDict() in_use_names = {func.__name__ for func in embeddable_funcs} mappers = types.SimpleNamespace( @@ -437,10 +505,11 @@ def inline(core, k_function, k_args, k_kwargs): retval_name=None, mappers=mappers, func=k_function, - args=k_args) + args=k_args, + kwargs=k_kwargs) func_def.body[0:0] = get_attr_init(attribute_namespace, func_def) func_def.body += get_attr_writeback(attribute_namespace, mappers.rpc, - func_def) + func_def) return func_def, mappers.rpc.get_map(), mappers.exception.get_map()