forked from M-Labs/artiq
transforms/inline: keyword argument and default value support
This commit is contained in:
parent
abae5c6728
commit
8b552134a0
|
@ -6,6 +6,7 @@ import builtins
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from itertools import zip_longest, chain
|
||||||
|
|
||||||
from artiq.language import core as core_language
|
from artiq.language import core as core_language
|
||||||
from artiq.language import units
|
from artiq.language import units
|
||||||
|
@ -71,17 +72,79 @@ class GlobalNamespace:
|
||||||
return getattr(builtins, item)
|
return getattr(builtins, item)
|
||||||
|
|
||||||
|
|
||||||
|
class UndefinedArg:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_function_args(func_args, func_tr, args, kwargs):
|
||||||
|
# OrderedDict prevents non-determinism in argument init
|
||||||
|
r = OrderedDict()
|
||||||
|
|
||||||
|
# Process positional arguments. Any missing positional argument values
|
||||||
|
# are set to UndefinedArg.
|
||||||
|
for arg, arg_value in zip_longest(func_args.args, args,
|
||||||
|
fillvalue=UndefinedArg):
|
||||||
|
if arg is UndefinedArg:
|
||||||
|
raise TypeError("Got too many positional arguments")
|
||||||
|
if arg.arg in r:
|
||||||
|
raise SyntaxError("Duplicate argument '{}' in function definition"
|
||||||
|
.format(arg.arg))
|
||||||
|
r[arg.arg] = arg_value
|
||||||
|
|
||||||
|
# Process keyword arguments. Any missing keyword-only argument values
|
||||||
|
# are set to UndefinedArg.
|
||||||
|
valid_arg_names = {arg.arg for arg in
|
||||||
|
chain(func_args.args, func_args.kwonlyargs)}
|
||||||
|
for arg in func_args.kwonlyargs:
|
||||||
|
if arg.arg in r:
|
||||||
|
raise SyntaxError("Duplicate argument '{}' in function definition"
|
||||||
|
.format(arg.arg))
|
||||||
|
r[arg.arg] = UndefinedArg
|
||||||
|
for arg_name, arg_value in kwargs.items():
|
||||||
|
if arg_name not in valid_arg_names:
|
||||||
|
raise TypeError("Got unexpected keyword argument '{}'"
|
||||||
|
.format(arg_name))
|
||||||
|
if r[arg_name] is not UndefinedArg:
|
||||||
|
raise TypeError("Got multiple values for argument '{}'"
|
||||||
|
.format(arg_name))
|
||||||
|
r[arg_name] = arg_value
|
||||||
|
|
||||||
|
# Replace any UndefinedArg positional arguments with the default value,
|
||||||
|
# when provided.
|
||||||
|
for arg, default in zip(func_args.args[len(func_args.defaults):],
|
||||||
|
func_args.defaults):
|
||||||
|
if r[arg.arg] is UndefinedArg:
|
||||||
|
r[arg.arg] = func_tr.code_visit(default)
|
||||||
|
# Same with keyword-only arguments.
|
||||||
|
for arg, default in zip(func_args.kwonlyargs, func_args.kw_defaults):
|
||||||
|
if default is not None and r[arg.arg] is UndefinedArg:
|
||||||
|
r[arg.arg] = func_tr.code_visit(default)
|
||||||
|
|
||||||
|
# Check that no argument was left undefined.
|
||||||
|
missing_arguments = ["'"+arg+"'" for arg, value in r.items()
|
||||||
|
if value is UndefinedArg]
|
||||||
|
if missing_arguments:
|
||||||
|
raise TypeError("Missing argument(s): " + " ".join(missing_arguments))
|
||||||
|
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
# args/kwargs can contain values or AST nodes
|
||||||
def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers,
|
def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers,
|
||||||
func, args):
|
func, args, kwargs):
|
||||||
global_namespace = GlobalNamespace(func)
|
global_namespace = GlobalNamespace(func)
|
||||||
func_tr = Function(core,
|
func_tr = Function(core,
|
||||||
global_namespace, attribute_namespace, in_use_names,
|
global_namespace, attribute_namespace, in_use_names,
|
||||||
retval_name, mappers)
|
retval_name, mappers)
|
||||||
func_def = ast.parse(textwrap.dedent(inspect.getsource(func))).body[0]
|
func_def = ast.parse(textwrap.dedent(inspect.getsource(func))).body[0]
|
||||||
|
|
||||||
param_init = []
|
# Initialize arguments.
|
||||||
# initialize arguments
|
# The local namespace is empty so code_visit will always resolve
|
||||||
for arg_ast, arg_value in zip(func_def.args.args, args):
|
# using the global namespace.
|
||||||
|
arg_init = []
|
||||||
|
arg_name_map = []
|
||||||
|
arg_dict = get_function_args(func_def.args, func_tr, args, kwargs)
|
||||||
|
for arg_name, arg_value in arg_dict.items():
|
||||||
if isinstance(arg_value, ast.AST):
|
if isinstance(arg_value, ast.AST):
|
||||||
value = arg_value
|
value = arg_value
|
||||||
else:
|
else:
|
||||||
|
@ -91,20 +154,25 @@ def get_inline(core, attribute_namespace, in_use_names, retval_name, mappers,
|
||||||
value = None
|
value = None
|
||||||
if value is None:
|
if value is None:
|
||||||
# static object
|
# static object
|
||||||
func_tr.local_namespace[arg_ast.arg] = arg_value
|
func_tr.local_namespace[arg_name] = arg_value
|
||||||
else:
|
else:
|
||||||
# set parameter value with "name = value"
|
# set parameter value with "name = value"
|
||||||
# assignment at beginning of function
|
# assignment at beginning of function
|
||||||
new_name = new_mangled_name(in_use_names, arg_ast.arg)
|
new_name = new_mangled_name(in_use_names, arg_name)
|
||||||
func_tr.local_namespace[arg_ast.arg] = MangledName(new_name)
|
arg_name_map.append((arg_name, new_name))
|
||||||
target = ast.copy_location(ast.Name(new_name, ast.Store()),
|
target = ast.copy_location(ast.Name(new_name, ast.Store()),
|
||||||
func_def)
|
func_def)
|
||||||
assign = ast.copy_location(ast.Assign([target], value),
|
assign = ast.copy_location(ast.Assign([target], value),
|
||||||
func_def)
|
func_def)
|
||||||
param_init.append(assign)
|
arg_init.append(assign)
|
||||||
|
# Commit arguments to the local namespace at the end to handle cases
|
||||||
|
# such as f(x, y=x) (for the default value of y, x must be resolved
|
||||||
|
# using the global namespace).
|
||||||
|
for arg_name, mangled_name in arg_name_map:
|
||||||
|
func_tr.local_namespace[arg_name] = MangledName(mangled_name)
|
||||||
|
|
||||||
func_def = func_tr.code_visit(func_def)
|
func_def = func_tr.code_visit(func_def)
|
||||||
func_def.body[0:0] = param_init
|
func_def.body[0:0] = arg_init
|
||||||
return func_def
|
return func_def
|
||||||
|
|
||||||
|
|
||||||
|
@ -231,6 +299,8 @@ class Function:
|
||||||
def code_visit_Call(self, node):
|
def code_visit_Call(self, node):
|
||||||
func = self.static_visit(node.func)
|
func = self.static_visit(node.func)
|
||||||
node.args = [self.code_visit(arg) for arg in node.args]
|
node.args = [self.code_visit(arg) for arg in node.args]
|
||||||
|
for kw in node.keywords:
|
||||||
|
kw.value = self.code_visit(kw.value)
|
||||||
|
|
||||||
if is_embeddable(func):
|
if is_embeddable(func):
|
||||||
node.func = ast.copy_location(
|
node.func = ast.copy_location(
|
||||||
|
@ -241,10 +311,12 @@ class Function:
|
||||||
retval_name = func.k_function_info.k_function.__name__ + "_return"
|
retval_name = func.k_function_info.k_function.__name__ + "_return"
|
||||||
retval_name_m = new_mangled_name(self.in_use_names, retval_name)
|
retval_name_m = new_mangled_name(self.in_use_names, retval_name)
|
||||||
args = [func.__self__] + node.args
|
args = [func.__self__] + node.args
|
||||||
|
kwargs = {kw.arg: kw.value for kw in node.keywords}
|
||||||
inlined = get_inline(self.core,
|
inlined = get_inline(self.core,
|
||||||
self.attribute_namespace, self.in_use_names,
|
self.attribute_namespace, self.in_use_names,
|
||||||
retval_name_m, self.mappers,
|
retval_name_m, self.mappers,
|
||||||
func.k_function_info.k_function, args)
|
func.k_function_info.k_function,
|
||||||
|
args, kwargs)
|
||||||
seq = ast.copy_location(
|
seq = ast.copy_location(
|
||||||
ast.With(
|
ast.With(
|
||||||
items=[ast.withitem(context_expr=ast.Name(id="sequential",
|
items=[ast.withitem(context_expr=ast.Name(id="sequential",
|
||||||
|
@ -419,11 +491,7 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
|
||||||
|
|
||||||
|
|
||||||
def inline(core, k_function, k_args, k_kwargs):
|
def inline(core, k_function, k_args, k_kwargs):
|
||||||
if k_kwargs:
|
# OrderedDict prevents non-determinism in attribute init
|
||||||
raise NotImplementedError(
|
|
||||||
"Keyword arguments are not supported for kernels")
|
|
||||||
|
|
||||||
# OrderedDict prevents non-determinism in parameter init
|
|
||||||
attribute_namespace = OrderedDict()
|
attribute_namespace = OrderedDict()
|
||||||
in_use_names = {func.__name__ for func in embeddable_funcs}
|
in_use_names = {func.__name__ for func in embeddable_funcs}
|
||||||
mappers = types.SimpleNamespace(
|
mappers = types.SimpleNamespace(
|
||||||
|
@ -437,10 +505,11 @@ def inline(core, k_function, k_args, k_kwargs):
|
||||||
retval_name=None,
|
retval_name=None,
|
||||||
mappers=mappers,
|
mappers=mappers,
|
||||||
func=k_function,
|
func=k_function,
|
||||||
args=k_args)
|
args=k_args,
|
||||||
|
kwargs=k_kwargs)
|
||||||
|
|
||||||
func_def.body[0:0] = get_attr_init(attribute_namespace, func_def)
|
func_def.body[0:0] = get_attr_init(attribute_namespace, func_def)
|
||||||
func_def.body += get_attr_writeback(attribute_namespace, mappers.rpc,
|
func_def.body += get_attr_writeback(attribute_namespace, mappers.rpc,
|
||||||
func_def)
|
func_def)
|
||||||
|
|
||||||
return func_def, mappers.rpc.get_map(), mappers.exception.get_map()
|
return func_def, mappers.rpc.get_map(), mappers.exception.get_map()
|
||||||
|
|
Loading…
Reference in New Issue