Initial invocation of a @kernel function can now return a value (fixes #197).

This commit is contained in:
whitequark 2015-12-19 05:26:18 +08:00
parent e9afe5a93b
commit 4fb1de33c9
4 changed files with 40 additions and 15 deletions

View File

@ -139,11 +139,15 @@ class ASTSynthesizer:
return asttyped.QuoteT(value=value, type=instance_type, return asttyped.QuoteT(value=value, type=instance_type,
loc=loc) loc=loc)
def call(self, function_node, args, kwargs): def call(self, function_node, args, kwargs, callback=None):
""" """
Construct an AST fragment calling a function specified by Construct an AST fragment calling a function specified by
an AST node `function_node`, with given arguments. an AST node `function_node`, with given arguments.
""" """
if callback is not None:
callback_node = self.quote(callback)
cb_begin_loc = self._add("(")
arg_nodes = [] arg_nodes = []
kwarg_nodes = [] kwarg_nodes = []
kwarg_locs = [] kwarg_locs = []
@ -165,7 +169,10 @@ class ASTSynthesizer:
self._add(", ") self._add(", ")
end_loc = self._add(")") end_loc = self._add(")")
return asttyped.CallT( if callback is not None:
cb_end_loc = self._add(")")
node = asttyped.CallT(
func=asttyped.NameT(id=function_node.name, ctx=None, func=asttyped.NameT(id=function_node.name, ctx=None,
type=function_node.signature_type, type=function_node.signature_type,
loc=name_loc), loc=name_loc),
@ -180,6 +187,16 @@ class ASTSynthesizer:
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None, begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
loc=name_loc.join(end_loc)) loc=name_loc.join(end_loc))
if callback is not None:
node = asttyped.CallT(
func=callback_node,
args=[node], keywords=[], starargs=None, kwargs=None,
type=builtins.TNone(), iodelay=None,
begin_loc=cb_begin_loc, end_loc=cb_end_loc, star_loc=None, dstar_loc=None,
loc=callback_node.loc.join(cb_end_loc))
return node
def assign_local(self, var_name, value): def assign_local(self, var_name, value):
name_loc = self._add(var_name) name_loc = self._add(var_name)
_ = self._add(" ") _ = self._add(" ")
@ -426,14 +443,14 @@ class Stitcher:
self.type_map = {} self.type_map = {}
self.value_map = defaultdict(lambda: []) self.value_map = defaultdict(lambda: [])
def stitch_call(self, function, args, kwargs): def stitch_call(self, function, args, kwargs, callback=None):
function_node = self._quote_embedded_function(function) function_node = self._quote_embedded_function(function)
self.typedtree.append(function_node) self.typedtree.append(function_node)
# We synthesize source code for the initial call so that # We synthesize source code for the initial call so that
# diagnostics would have something meaningful to display to the user. # diagnostics would have something meaningful to display to the user.
synthesizer = self._synthesizer() synthesizer = self._synthesizer()
call_node = synthesizer.call(function_node, args, kwargs) call_node = synthesizer.call(function_node, args, kwargs, callback)
synthesizer.finalize() synthesizer.finalize()
self.typedtree.append(call_node) self.typedtree.append(call_node)

View File

@ -907,9 +907,12 @@ class LLVMIRGenerator:
llargs = [] llargs = []
for arg in args: for arg in args:
llarg = self.map(arg) if builtins.is_none(arg.type):
llargslot = self.llbuilder.alloca(llarg.type) llargslot = self.llbuilder.alloca(ll.LiteralStructType([]))
self.llbuilder.store(llarg, llargslot) else:
llarg = self.map(arg)
llargslot = self.llbuilder.alloca(llarg.type)
self.llbuilder.store(llarg, llargslot)
llargs.append(llargslot) llargs.append(llargslot)
self.llbuilder.call(self.llbuiltin("send_rpc"), self.llbuilder.call(self.llbuiltin("send_rpc"),

View File

@ -56,12 +56,12 @@ class Core:
self.core = self self.core = self
self.comm.core = self self.comm.core = self
def compile(self, function, args, kwargs, with_attr_writeback=True): def compile(self, function, args, kwargs, set_result, with_attr_writeback=True):
try: try:
engine = diagnostic.Engine(all_errors_are_fatal=True) engine = diagnostic.Engine(all_errors_are_fatal=True)
stitcher = Stitcher(engine=engine) stitcher = Stitcher(engine=engine)
stitcher.stitch_call(function, args, kwargs) stitcher.stitch_call(function, args, kwargs, set_result)
stitcher.finalize() stitcher.finalize()
module = Module(stitcher, ref_period=self.ref_period) module = Module(stitcher, ref_period=self.ref_period)
@ -76,7 +76,12 @@ class Core:
raise CompileError(error.diagnostic) from error raise CompileError(error.diagnostic) from error
def run(self, function, args, kwargs): def run(self, function, args, kwargs):
object_map, kernel_library, symbolizer = self.compile(function, args, kwargs) result = None
def set_result(new_result):
nonlocal result
result = new_result
object_map, kernel_library, symbolizer = self.compile(function, args, kwargs, set_result)
if self.first_run: if self.first_run:
self.comm.check_ident() self.comm.check_ident()
@ -87,6 +92,8 @@ class Core:
self.comm.run() self.comm.run()
self.comm.serve(object_map, symbolizer) self.comm.serve(object_map, symbolizer)
return result
@kernel @kernel
def get_rtio_counter_mu(self): def get_rtio_counter_mu(self):
return rtio_get_counter() return rtio_get_counter()

View File

@ -50,12 +50,10 @@ class DefaultArg(EnvExperiment):
return foo return foo
@kernel @kernel
def run(self, callback): def run(self):
callback(self.test()) return self.test()
class DefaultArgTest(ExperimentCase): class DefaultArgTest(ExperimentCase):
def test_default_arg(self): def test_default_arg(self):
exp = self.create(DefaultArg) exp = self.create(DefaultArg)
def callback(value): self.assertEqual(exp.run(), 42)
self.assertEqual(value, 42)
exp.run(callback)