diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 806a0f856..97e3fb43d 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -139,11 +139,15 @@ class ASTSynthesizer: return asttyped.QuoteT(value=value, type=instance_type, loc=loc) - def call(self, function_node, args, kwargs): + def call(self, function_node, args, kwargs, callback=None): """ Construct an AST fragment calling a function specified by an AST node `function_node`, with given arguments. """ + if callback is not None: + callback_node = self.quote(callback) + cb_begin_loc = self._add("(") + arg_nodes = [] kwarg_nodes = [] kwarg_locs = [] @@ -165,7 +169,10 @@ class ASTSynthesizer: self._add(", ") end_loc = self._add(")") - return asttyped.CallT( + if callback is not None: + cb_end_loc = self._add(")") + + node = asttyped.CallT( func=asttyped.NameT(id=function_node.name, ctx=None, type=function_node.signature_type, loc=name_loc), @@ -180,6 +187,16 @@ class ASTSynthesizer: begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None, loc=name_loc.join(end_loc)) + if callback is not None: + node = asttyped.CallT( + func=callback_node, + args=[node], keywords=[], starargs=None, kwargs=None, + type=builtins.TNone(), iodelay=None, + begin_loc=cb_begin_loc, end_loc=cb_end_loc, star_loc=None, dstar_loc=None, + loc=callback_node.loc.join(cb_end_loc)) + + return node + def assign_local(self, var_name, value): name_loc = self._add(var_name) _ = self._add(" ") @@ -426,14 +443,14 @@ class Stitcher: self.type_map = {} self.value_map = defaultdict(lambda: []) - def stitch_call(self, function, args, kwargs): + def stitch_call(self, function, args, kwargs, callback=None): function_node = self._quote_embedded_function(function) self.typedtree.append(function_node) # We synthesize source code for the initial call so that # diagnostics would have something meaningful to display to the user. synthesizer = self._synthesizer() - call_node = synthesizer.call(function_node, args, kwargs) + call_node = synthesizer.call(function_node, args, kwargs, callback) synthesizer.finalize() self.typedtree.append(call_node) diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 127fe32c0..b563476e1 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -907,9 +907,12 @@ class LLVMIRGenerator: llargs = [] for arg in args: - llarg = self.map(arg) - llargslot = self.llbuilder.alloca(llarg.type) - self.llbuilder.store(llarg, llargslot) + if builtins.is_none(arg.type): + llargslot = self.llbuilder.alloca(ll.LiteralStructType([])) + else: + llarg = self.map(arg) + llargslot = self.llbuilder.alloca(llarg.type) + self.llbuilder.store(llarg, llargslot) llargs.append(llargslot) self.llbuilder.call(self.llbuiltin("send_rpc"), diff --git a/artiq/coredevice/core.py b/artiq/coredevice/core.py index 217d1f3f3..c444577c5 100644 --- a/artiq/coredevice/core.py +++ b/artiq/coredevice/core.py @@ -56,12 +56,12 @@ class Core: self.core = self self.comm.core = self - def compile(self, function, args, kwargs, with_attr_writeback=True): + def compile(self, function, args, kwargs, set_result, with_attr_writeback=True): try: engine = diagnostic.Engine(all_errors_are_fatal=True) stitcher = Stitcher(engine=engine) - stitcher.stitch_call(function, args, kwargs) + stitcher.stitch_call(function, args, kwargs, set_result) stitcher.finalize() module = Module(stitcher, ref_period=self.ref_period) @@ -76,7 +76,12 @@ class Core: raise CompileError(error.diagnostic) from error def run(self, function, args, kwargs): - object_map, kernel_library, symbolizer = self.compile(function, args, kwargs) + result = None + def set_result(new_result): + nonlocal result + result = new_result + + object_map, kernel_library, symbolizer = self.compile(function, args, kwargs, set_result) if self.first_run: self.comm.check_ident() @@ -87,6 +92,8 @@ class Core: self.comm.run() self.comm.serve(object_map, symbolizer) + return result + @kernel def get_rtio_counter_mu(self): return rtio_get_counter() diff --git a/artiq/test/coredevice/embedding.py b/artiq/test/coredevice/embedding.py index a5a7095ad..04a136164 100644 --- a/artiq/test/coredevice/embedding.py +++ b/artiq/test/coredevice/embedding.py @@ -50,12 +50,10 @@ class DefaultArg(EnvExperiment): return foo @kernel - def run(self, callback): - callback(self.test()) + def run(self): + return self.test() class DefaultArgTest(ExperimentCase): def test_default_arg(self): exp = self.create(DefaultArg) - def callback(value): - self.assertEqual(value, 42) - exp.run(callback) + self.assertEqual(exp.run(), 42)