Initial invocation of a @kernel function can now return a value (fixes #197).

This commit is contained in:
whitequark 2015-12-19 05:26:18 +08:00
parent e9afe5a93b
commit 4fb1de33c9
4 changed files with 40 additions and 15 deletions

View File

@ -139,11 +139,15 @@ class ASTSynthesizer:
return asttyped.QuoteT(value=value, type=instance_type,
loc=loc)
def call(self, function_node, args, kwargs):
def call(self, function_node, args, kwargs, callback=None):
"""
Construct an AST fragment calling a function specified by
an AST node `function_node`, with given arguments.
"""
if callback is not None:
callback_node = self.quote(callback)
cb_begin_loc = self._add("(")
arg_nodes = []
kwarg_nodes = []
kwarg_locs = []
@ -165,7 +169,10 @@ class ASTSynthesizer:
self._add(", ")
end_loc = self._add(")")
return asttyped.CallT(
if callback is not None:
cb_end_loc = self._add(")")
node = asttyped.CallT(
func=asttyped.NameT(id=function_node.name, ctx=None,
type=function_node.signature_type,
loc=name_loc),
@ -180,6 +187,16 @@ class ASTSynthesizer:
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
loc=name_loc.join(end_loc))
if callback is not None:
node = asttyped.CallT(
func=callback_node,
args=[node], keywords=[], starargs=None, kwargs=None,
type=builtins.TNone(), iodelay=None,
begin_loc=cb_begin_loc, end_loc=cb_end_loc, star_loc=None, dstar_loc=None,
loc=callback_node.loc.join(cb_end_loc))
return node
def assign_local(self, var_name, value):
name_loc = self._add(var_name)
_ = self._add(" ")
@ -426,14 +443,14 @@ class Stitcher:
self.type_map = {}
self.value_map = defaultdict(lambda: [])
def stitch_call(self, function, args, kwargs):
def stitch_call(self, function, args, kwargs, callback=None):
function_node = self._quote_embedded_function(function)
self.typedtree.append(function_node)
# We synthesize source code for the initial call so that
# diagnostics would have something meaningful to display to the user.
synthesizer = self._synthesizer()
call_node = synthesizer.call(function_node, args, kwargs)
call_node = synthesizer.call(function_node, args, kwargs, callback)
synthesizer.finalize()
self.typedtree.append(call_node)

View File

@ -907,9 +907,12 @@ class LLVMIRGenerator:
llargs = []
for arg in args:
llarg = self.map(arg)
llargslot = self.llbuilder.alloca(llarg.type)
self.llbuilder.store(llarg, llargslot)
if builtins.is_none(arg.type):
llargslot = self.llbuilder.alloca(ll.LiteralStructType([]))
else:
llarg = self.map(arg)
llargslot = self.llbuilder.alloca(llarg.type)
self.llbuilder.store(llarg, llargslot)
llargs.append(llargslot)
self.llbuilder.call(self.llbuiltin("send_rpc"),

View File

@ -56,12 +56,12 @@ class Core:
self.core = self
self.comm.core = self
def compile(self, function, args, kwargs, with_attr_writeback=True):
def compile(self, function, args, kwargs, set_result, with_attr_writeback=True):
try:
engine = diagnostic.Engine(all_errors_are_fatal=True)
stitcher = Stitcher(engine=engine)
stitcher.stitch_call(function, args, kwargs)
stitcher.stitch_call(function, args, kwargs, set_result)
stitcher.finalize()
module = Module(stitcher, ref_period=self.ref_period)
@ -76,7 +76,12 @@ class Core:
raise CompileError(error.diagnostic) from error
def run(self, function, args, kwargs):
object_map, kernel_library, symbolizer = self.compile(function, args, kwargs)
result = None
def set_result(new_result):
nonlocal result
result = new_result
object_map, kernel_library, symbolizer = self.compile(function, args, kwargs, set_result)
if self.first_run:
self.comm.check_ident()
@ -87,6 +92,8 @@ class Core:
self.comm.run()
self.comm.serve(object_map, symbolizer)
return result
@kernel
def get_rtio_counter_mu(self):
return rtio_get_counter()

View File

@ -50,12 +50,10 @@ class DefaultArg(EnvExperiment):
return foo
@kernel
def run(self, callback):
callback(self.test())
def run(self):
return self.test()
class DefaultArgTest(ExperimentCase):
def test_default_arg(self):
exp = self.create(DefaultArg)
def callback(value):
self.assertEqual(value, 42)
exp.run(callback)
self.assertEqual(exp.run(), 42)