diff --git a/artiq/transforms/inline.py b/artiq/transforms/inline.py index 4c3b661e5..891a4aeb4 100644 --- a/artiq/transforms/inline.py +++ b/artiq/transforms/inline.py @@ -122,7 +122,7 @@ class _ReferenceManager: raise NotImplementedError -_embeddable_calls = ( +_embeddable_funcs = ( core_language.delay, core_language.at, core_language.now, core_language.time_to_cycles, core_language.cycles_to_time, core_language.syscall, @@ -131,13 +131,22 @@ _embeddable_calls = ( Fraction, units.Quantity, core_language.EncodedException ) -def _is_embeddable(call): - for ec in _embeddable_calls: - if call is ec: +def _is_embeddable(func): + for ef in _embeddable_funcs: + if func is ef: return True return False +def _is_inlinable(core, func): + if hasattr(func, "k_function_info"): + if func.k_function_info.core_name == "": + return True # portable function + if getattr(func.__self__, func.k_function_info.core_name) is core: + return True # kernel function for the same core device + return False + + class _ReferenceReplacer(ast.NodeVisitor): def __init__(self, core, rm, obj, func_name, retval_name): self.core = core @@ -231,9 +240,7 @@ class _ReferenceReplacer(ast.NodeVisitor): ast.Call(func=new_func, args=new_args, keywords=[], starargs=None, kwargs=None), node) - elif (hasattr(func, "k_function_info") - and getattr(func.__self__, func.k_function_info.core_name) - is self.core): + elif _is_inlinable(self.core, func): retval_name = self.rm.new_name( func.k_function_info.k_function.__name__ + "_return") args = [func.__self__] + new_args