forked from M-Labs/artiq
Initial invocation of a @kernel function can now return a value (fixes #197).
This commit is contained in:
parent
e9afe5a93b
commit
4fb1de33c9
@ -139,11 +139,15 @@ class ASTSynthesizer:
|
|||||||
return asttyped.QuoteT(value=value, type=instance_type,
|
return asttyped.QuoteT(value=value, type=instance_type,
|
||||||
loc=loc)
|
loc=loc)
|
||||||
|
|
||||||
def call(self, function_node, args, kwargs):
|
def call(self, function_node, args, kwargs, callback=None):
|
||||||
"""
|
"""
|
||||||
Construct an AST fragment calling a function specified by
|
Construct an AST fragment calling a function specified by
|
||||||
an AST node `function_node`, with given arguments.
|
an AST node `function_node`, with given arguments.
|
||||||
"""
|
"""
|
||||||
|
if callback is not None:
|
||||||
|
callback_node = self.quote(callback)
|
||||||
|
cb_begin_loc = self._add("(")
|
||||||
|
|
||||||
arg_nodes = []
|
arg_nodes = []
|
||||||
kwarg_nodes = []
|
kwarg_nodes = []
|
||||||
kwarg_locs = []
|
kwarg_locs = []
|
||||||
@ -165,7 +169,10 @@ class ASTSynthesizer:
|
|||||||
self._add(", ")
|
self._add(", ")
|
||||||
end_loc = self._add(")")
|
end_loc = self._add(")")
|
||||||
|
|
||||||
return asttyped.CallT(
|
if callback is not None:
|
||||||
|
cb_end_loc = self._add(")")
|
||||||
|
|
||||||
|
node = asttyped.CallT(
|
||||||
func=asttyped.NameT(id=function_node.name, ctx=None,
|
func=asttyped.NameT(id=function_node.name, ctx=None,
|
||||||
type=function_node.signature_type,
|
type=function_node.signature_type,
|
||||||
loc=name_loc),
|
loc=name_loc),
|
||||||
@ -180,6 +187,16 @@ class ASTSynthesizer:
|
|||||||
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
|
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
|
||||||
loc=name_loc.join(end_loc))
|
loc=name_loc.join(end_loc))
|
||||||
|
|
||||||
|
if callback is not None:
|
||||||
|
node = asttyped.CallT(
|
||||||
|
func=callback_node,
|
||||||
|
args=[node], keywords=[], starargs=None, kwargs=None,
|
||||||
|
type=builtins.TNone(), iodelay=None,
|
||||||
|
begin_loc=cb_begin_loc, end_loc=cb_end_loc, star_loc=None, dstar_loc=None,
|
||||||
|
loc=callback_node.loc.join(cb_end_loc))
|
||||||
|
|
||||||
|
return node
|
||||||
|
|
||||||
def assign_local(self, var_name, value):
|
def assign_local(self, var_name, value):
|
||||||
name_loc = self._add(var_name)
|
name_loc = self._add(var_name)
|
||||||
_ = self._add(" ")
|
_ = self._add(" ")
|
||||||
@ -426,14 +443,14 @@ class Stitcher:
|
|||||||
self.type_map = {}
|
self.type_map = {}
|
||||||
self.value_map = defaultdict(lambda: [])
|
self.value_map = defaultdict(lambda: [])
|
||||||
|
|
||||||
def stitch_call(self, function, args, kwargs):
|
def stitch_call(self, function, args, kwargs, callback=None):
|
||||||
function_node = self._quote_embedded_function(function)
|
function_node = self._quote_embedded_function(function)
|
||||||
self.typedtree.append(function_node)
|
self.typedtree.append(function_node)
|
||||||
|
|
||||||
# We synthesize source code for the initial call so that
|
# We synthesize source code for the initial call so that
|
||||||
# diagnostics would have something meaningful to display to the user.
|
# diagnostics would have something meaningful to display to the user.
|
||||||
synthesizer = self._synthesizer()
|
synthesizer = self._synthesizer()
|
||||||
call_node = synthesizer.call(function_node, args, kwargs)
|
call_node = synthesizer.call(function_node, args, kwargs, callback)
|
||||||
synthesizer.finalize()
|
synthesizer.finalize()
|
||||||
self.typedtree.append(call_node)
|
self.typedtree.append(call_node)
|
||||||
|
|
||||||
|
@ -907,9 +907,12 @@ class LLVMIRGenerator:
|
|||||||
|
|
||||||
llargs = []
|
llargs = []
|
||||||
for arg in args:
|
for arg in args:
|
||||||
llarg = self.map(arg)
|
if builtins.is_none(arg.type):
|
||||||
llargslot = self.llbuilder.alloca(llarg.type)
|
llargslot = self.llbuilder.alloca(ll.LiteralStructType([]))
|
||||||
self.llbuilder.store(llarg, llargslot)
|
else:
|
||||||
|
llarg = self.map(arg)
|
||||||
|
llargslot = self.llbuilder.alloca(llarg.type)
|
||||||
|
self.llbuilder.store(llarg, llargslot)
|
||||||
llargs.append(llargslot)
|
llargs.append(llargslot)
|
||||||
|
|
||||||
self.llbuilder.call(self.llbuiltin("send_rpc"),
|
self.llbuilder.call(self.llbuiltin("send_rpc"),
|
||||||
|
@ -56,12 +56,12 @@ class Core:
|
|||||||
self.core = self
|
self.core = self
|
||||||
self.comm.core = self
|
self.comm.core = self
|
||||||
|
|
||||||
def compile(self, function, args, kwargs, with_attr_writeback=True):
|
def compile(self, function, args, kwargs, set_result, with_attr_writeback=True):
|
||||||
try:
|
try:
|
||||||
engine = diagnostic.Engine(all_errors_are_fatal=True)
|
engine = diagnostic.Engine(all_errors_are_fatal=True)
|
||||||
|
|
||||||
stitcher = Stitcher(engine=engine)
|
stitcher = Stitcher(engine=engine)
|
||||||
stitcher.stitch_call(function, args, kwargs)
|
stitcher.stitch_call(function, args, kwargs, set_result)
|
||||||
stitcher.finalize()
|
stitcher.finalize()
|
||||||
|
|
||||||
module = Module(stitcher, ref_period=self.ref_period)
|
module = Module(stitcher, ref_period=self.ref_period)
|
||||||
@ -76,7 +76,12 @@ class Core:
|
|||||||
raise CompileError(error.diagnostic) from error
|
raise CompileError(error.diagnostic) from error
|
||||||
|
|
||||||
def run(self, function, args, kwargs):
|
def run(self, function, args, kwargs):
|
||||||
object_map, kernel_library, symbolizer = self.compile(function, args, kwargs)
|
result = None
|
||||||
|
def set_result(new_result):
|
||||||
|
nonlocal result
|
||||||
|
result = new_result
|
||||||
|
|
||||||
|
object_map, kernel_library, symbolizer = self.compile(function, args, kwargs, set_result)
|
||||||
|
|
||||||
if self.first_run:
|
if self.first_run:
|
||||||
self.comm.check_ident()
|
self.comm.check_ident()
|
||||||
@ -87,6 +92,8 @@ class Core:
|
|||||||
self.comm.run()
|
self.comm.run()
|
||||||
self.comm.serve(object_map, symbolizer)
|
self.comm.serve(object_map, symbolizer)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def get_rtio_counter_mu(self):
|
def get_rtio_counter_mu(self):
|
||||||
return rtio_get_counter()
|
return rtio_get_counter()
|
||||||
|
@ -50,12 +50,10 @@ class DefaultArg(EnvExperiment):
|
|||||||
return foo
|
return foo
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def run(self, callback):
|
def run(self):
|
||||||
callback(self.test())
|
return self.test()
|
||||||
|
|
||||||
class DefaultArgTest(ExperimentCase):
|
class DefaultArgTest(ExperimentCase):
|
||||||
def test_default_arg(self):
|
def test_default_arg(self):
|
||||||
exp = self.create(DefaultArg)
|
exp = self.create(DefaultArg)
|
||||||
def callback(value):
|
self.assertEqual(exp.run(), 42)
|
||||||
self.assertEqual(value, 42)
|
|
||||||
exp.run(callback)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user