forked from M-Labs/artiq
Initial invocation of a @kernel function can now return a value (fixes #197).
This commit is contained in:
parent
e9afe5a93b
commit
4fb1de33c9
@ -139,11 +139,15 @@ class ASTSynthesizer:
|
||||
return asttyped.QuoteT(value=value, type=instance_type,
|
||||
loc=loc)
|
||||
|
||||
def call(self, function_node, args, kwargs):
|
||||
def call(self, function_node, args, kwargs, callback=None):
|
||||
"""
|
||||
Construct an AST fragment calling a function specified by
|
||||
an AST node `function_node`, with given arguments.
|
||||
"""
|
||||
if callback is not None:
|
||||
callback_node = self.quote(callback)
|
||||
cb_begin_loc = self._add("(")
|
||||
|
||||
arg_nodes = []
|
||||
kwarg_nodes = []
|
||||
kwarg_locs = []
|
||||
@ -165,7 +169,10 @@ class ASTSynthesizer:
|
||||
self._add(", ")
|
||||
end_loc = self._add(")")
|
||||
|
||||
return asttyped.CallT(
|
||||
if callback is not None:
|
||||
cb_end_loc = self._add(")")
|
||||
|
||||
node = asttyped.CallT(
|
||||
func=asttyped.NameT(id=function_node.name, ctx=None,
|
||||
type=function_node.signature_type,
|
||||
loc=name_loc),
|
||||
@ -180,6 +187,16 @@ class ASTSynthesizer:
|
||||
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
|
||||
loc=name_loc.join(end_loc))
|
||||
|
||||
if callback is not None:
|
||||
node = asttyped.CallT(
|
||||
func=callback_node,
|
||||
args=[node], keywords=[], starargs=None, kwargs=None,
|
||||
type=builtins.TNone(), iodelay=None,
|
||||
begin_loc=cb_begin_loc, end_loc=cb_end_loc, star_loc=None, dstar_loc=None,
|
||||
loc=callback_node.loc.join(cb_end_loc))
|
||||
|
||||
return node
|
||||
|
||||
def assign_local(self, var_name, value):
|
||||
name_loc = self._add(var_name)
|
||||
_ = self._add(" ")
|
||||
@ -426,14 +443,14 @@ class Stitcher:
|
||||
self.type_map = {}
|
||||
self.value_map = defaultdict(lambda: [])
|
||||
|
||||
def stitch_call(self, function, args, kwargs):
|
||||
def stitch_call(self, function, args, kwargs, callback=None):
|
||||
function_node = self._quote_embedded_function(function)
|
||||
self.typedtree.append(function_node)
|
||||
|
||||
# We synthesize source code for the initial call so that
|
||||
# diagnostics would have something meaningful to display to the user.
|
||||
synthesizer = self._synthesizer()
|
||||
call_node = synthesizer.call(function_node, args, kwargs)
|
||||
call_node = synthesizer.call(function_node, args, kwargs, callback)
|
||||
synthesizer.finalize()
|
||||
self.typedtree.append(call_node)
|
||||
|
||||
|
@ -907,6 +907,9 @@ class LLVMIRGenerator:
|
||||
|
||||
llargs = []
|
||||
for arg in args:
|
||||
if builtins.is_none(arg.type):
|
||||
llargslot = self.llbuilder.alloca(ll.LiteralStructType([]))
|
||||
else:
|
||||
llarg = self.map(arg)
|
||||
llargslot = self.llbuilder.alloca(llarg.type)
|
||||
self.llbuilder.store(llarg, llargslot)
|
||||
|
@ -56,12 +56,12 @@ class Core:
|
||||
self.core = self
|
||||
self.comm.core = self
|
||||
|
||||
def compile(self, function, args, kwargs, with_attr_writeback=True):
|
||||
def compile(self, function, args, kwargs, set_result, with_attr_writeback=True):
|
||||
try:
|
||||
engine = diagnostic.Engine(all_errors_are_fatal=True)
|
||||
|
||||
stitcher = Stitcher(engine=engine)
|
||||
stitcher.stitch_call(function, args, kwargs)
|
||||
stitcher.stitch_call(function, args, kwargs, set_result)
|
||||
stitcher.finalize()
|
||||
|
||||
module = Module(stitcher, ref_period=self.ref_period)
|
||||
@ -76,7 +76,12 @@ class Core:
|
||||
raise CompileError(error.diagnostic) from error
|
||||
|
||||
def run(self, function, args, kwargs):
|
||||
object_map, kernel_library, symbolizer = self.compile(function, args, kwargs)
|
||||
result = None
|
||||
def set_result(new_result):
|
||||
nonlocal result
|
||||
result = new_result
|
||||
|
||||
object_map, kernel_library, symbolizer = self.compile(function, args, kwargs, set_result)
|
||||
|
||||
if self.first_run:
|
||||
self.comm.check_ident()
|
||||
@ -87,6 +92,8 @@ class Core:
|
||||
self.comm.run()
|
||||
self.comm.serve(object_map, symbolizer)
|
||||
|
||||
return result
|
||||
|
||||
@kernel
|
||||
def get_rtio_counter_mu(self):
|
||||
return rtio_get_counter()
|
||||
|
@ -50,12 +50,10 @@ class DefaultArg(EnvExperiment):
|
||||
return foo
|
||||
|
||||
@kernel
|
||||
def run(self, callback):
|
||||
callback(self.test())
|
||||
def run(self):
|
||||
return self.test()
|
||||
|
||||
class DefaultArgTest(ExperimentCase):
|
||||
def test_default_arg(self):
|
||||
exp = self.create(DefaultArg)
|
||||
def callback(value):
|
||||
self.assertEqual(value, 42)
|
||||
exp.run(callback)
|
||||
self.assertEqual(exp.run(), 42)
|
||||
|
Loading…
Reference in New Issue
Block a user