forked from M-Labs/artiq
1
0
Fork 0

transforms/inline: keyword argument and default value support

This commit is contained in:
Sebastien Bourdeauducq 2014-11-18 13:40:15 -08:00
parent abae5c6728
commit 8b552134a0
1 changed files with 86 additions and 17 deletions

View File

@ -6,6 +6,7 @@ import builtins
from fractions import Fraction from fractions import Fraction
from collections import OrderedDict from collections import OrderedDict
from functools import partial from functools import partial
from itertools import zip_longest, chain
from artiq.language import core as core_language from artiq.language import core as core_language
from artiq.language import units from artiq.language import units
@ -71,17 +72,79 @@ class GlobalNamespace:
return getattr(builtins, item) return getattr(builtins, item)
class UndefinedArg:
pass
def get_function_args(func_args, func_tr, args, kwargs):
# OrderedDict prevents non-determinism in argument init
r = OrderedDict()
# Process positional arguments. Any missing positional argument values
# are set to UndefinedArg.
for arg, arg_value in zip_longest(func_args.args, args,
fillvalue=UndefinedArg):
if arg is UndefinedArg:
raise TypeError("Got too many positional arguments")
if arg.arg in r:
raise SyntaxError("Duplicate argument '{}' in function definition"
.format(arg.arg))
r[arg.arg] = arg_value
# Process keyword arguments. Any missing keyword-only argument values
# are set to UndefinedArg.
valid_arg_names = {arg.arg for arg in
chain(func_args.args, func_args.kwonlyargs)}
for arg in func_args.kwonlyargs:
if arg.arg in r:
raise SyntaxError("Duplicate argument '{}' in function definition"
.format(arg.arg))
r[arg.arg] = UndefinedArg
for arg_name, arg_value in kwargs.items():
if arg_name not in valid_arg_names:
raise TypeError("Got unexpected keyword argument '{}'"
.format(arg_name))
if r[arg_name] is not UndefinedArg:
raise TypeError("Got multiple values for argument '{}'"
.format(arg_name))
r[arg_name] = arg_value
# Replace any UndefinedArg positional arguments with the default value,
# when provided.
for arg, default in zip(func_args.args[len(func_args.defaults):],
func_args.defaults):
if r[arg.arg] is UndefinedArg:
r[arg.arg] = func_tr.code_visit(default)
# Same with keyword-only arguments.
for arg, default in zip(func_args.kwonlyargs, func_args.kw_defaults):
if default is not None and r[arg.arg] is UndefinedArg:
r[arg.arg] = func_tr.code_visit(default)
# Check that no argument was left undefined.
missing_arguments = ["'"+arg+"'" for arg, value in r.items()
if value is UndefinedArg]
if missing_arguments:
raise TypeError("Missing argument(s): " + " ".join(missing_arguments))
return r
# args/kwargs can contain values or AST nodes
def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers, def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers,
func, args): func, args, kwargs):
global_namespace = GlobalNamespace(func) global_namespace = GlobalNamespace(func)
func_tr = Function(core, func_tr = Function(core,
global_namespace, attribute_namespace, in_use_names, global_namespace, attribute_namespace, in_use_names,
retval_name, mappers) retval_name, mappers)
func_def = ast.parse(textwrap.dedent(inspect.getsource(func))).body[0] func_def = ast.parse(textwrap.dedent(inspect.getsource(func))).body[0]
param_init = [] # Initialize arguments.
# initialize arguments # The local namespace is empty so code_visit will always resolve
for arg_ast, arg_value in zip(func_def.args.args, args): # using the global namespace.
arg_init = []
arg_name_map = []
arg_dict = get_function_args(func_def.args, func_tr, args, kwargs)
for arg_name, arg_value in arg_dict.items():
if isinstance(arg_value, ast.AST): if isinstance(arg_value, ast.AST):
value = arg_value value = arg_value
else: else:
@ -91,20 +154,25 @@ def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers,
value = None value = None
if value is None: if value is None:
# static object # static object
func_tr.local_namespace[arg_ast.arg] = arg_value func_tr.local_namespace[arg_name] = arg_value
else: else:
# set parameter value with "name = value" # set parameter value with "name = value"
# assignment at beginning of function # assignment at beginning of function
new_name = new_mangled_name(in_use_names, arg_ast.arg) new_name = new_mangled_name(in_use_names, arg_name)
func_tr.local_namespace[arg_ast.arg] = MangledName(new_name) arg_name_map.append((arg_name, new_name))
target = ast.copy_location(ast.Name(new_name, ast.Store()), target = ast.copy_location(ast.Name(new_name, ast.Store()),
func_def) func_def)
assign = ast.copy_location(ast.Assign([target], value), assign = ast.copy_location(ast.Assign([target], value),
func_def) func_def)
param_init.append(assign) arg_init.append(assign)
# Commit arguments to the local namespace at the end to handle cases
# such as f(x, y=x) (for the default value of y, x must be resolved
# using the global namespace).
for arg_name, mangled_name in arg_name_map:
func_tr.local_namespace[arg_name] = MangledName(mangled_name)
func_def = func_tr.code_visit(func_def) func_def = func_tr.code_visit(func_def)
func_def.body[0:0] = param_init func_def.body[0:0] = arg_init
return func_def return func_def
@ -231,6 +299,8 @@ class Function:
def code_visit_Call(self, node): def code_visit_Call(self, node):
func = self.static_visit(node.func) func = self.static_visit(node.func)
node.args = [self.code_visit(arg) for arg in node.args] node.args = [self.code_visit(arg) for arg in node.args]
for kw in node.keywords:
kw.value = self.code_visit(kw.value)
if is_embeddable(func): if is_embeddable(func):
node.func = ast.copy_location( node.func = ast.copy_location(
@ -241,10 +311,12 @@ class Function:
retval_name = func.k_function_info.k_function.__name__ + "_return" retval_name = func.k_function_info.k_function.__name__ + "_return"
retval_name_m = new_mangled_name(self.in_use_names, retval_name) retval_name_m = new_mangled_name(self.in_use_names, retval_name)
args = [func.__self__] + node.args args = [func.__self__] + node.args
kwargs = {kw.arg: kw.value for kw in node.keywords}
inlined = get_inline(self.core, inlined = get_inline(self.core,
self.attribute_namespace, self.in_use_names, self.attribute_namespace, self.in_use_names,
retval_name_m, self.mappers, retval_name_m, self.mappers,
func.k_function_info.k_function, args) func.k_function_info.k_function,
args, kwargs)
seq = ast.copy_location( seq = ast.copy_location(
ast.With( ast.With(
items=[ast.withitem(context_expr=ast.Name(id="sequential", items=[ast.withitem(context_expr=ast.Name(id="sequential",
@ -419,11 +491,7 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
def inline(core, k_function, k_args, k_kwargs): def inline(core, k_function, k_args, k_kwargs):
if k_kwargs: # OrderedDict prevents non-determinism in attribute init
raise NotImplementedError(
"Keyword arguments are not supported for kernels")
# OrderedDict prevents non-determinism in parameter init
attribute_namespace = OrderedDict() attribute_namespace = OrderedDict()
in_use_names = {func.__name__ for func in embeddable_funcs} in_use_names = {func.__name__ for func in embeddable_funcs}
mappers = types.SimpleNamespace( mappers = types.SimpleNamespace(
@ -437,7 +505,8 @@ def inline(core, k_function, k_args, k_kwargs):
retval_name=None, retval_name=None,
mappers=mappers, mappers=mappers,
func=k_function, func=k_function,
args=k_args) args=k_args,
kwargs=k_kwargs)
func_def.body[0:0] = get_attr_init(attribute_namespace, func_def) func_def.body[0:0] = get_attr_init(attribute_namespace, func_def)
func_def.body += get_attr_writeback(attribute_namespace, mappers.rpc, func_def.body += get_attr_writeback(attribute_namespace, mappers.rpc,