compiler: add support for async RPCs.

This commit is contained in:
whitequark 2016-10-30 00:40:26 +00:00
parent 2ac85cd40f
commit cd68577dbc
9 changed files with 112 additions and 17 deletions

View File

@ -994,8 +994,24 @@ class Stitcher:
else: else:
assert False assert False
is_async = False
if hasattr(host_function, "artiq_embedded") and \
"async" in host_function.artiq_embedded.flags:
is_async = True
if not builtins.is_none(ret_type) and is_async:
note = diagnostic.Diagnostic("note",
"function called here", {},
loc)
diag = diagnostic.Diagnostic("fatal",
"functions that return a value cannot be defined as async RPCs", {},
self._function_loc(host_function.artiq_embedded.function),
notes=[note])
self.engine.process(diag)
function_type = types.TRPC(ret_type, function_type = types.TRPC(ret_type,
service=self.embedding_map.store_object(host_function)) service=self.embedding_map.store_object(host_function),
async=is_async)
self.functions[function] = function_type self.functions[function] = function_type
return function_type return function_type
@ -1007,7 +1023,11 @@ class Stitcher:
if function in self.functions: if function in self.functions:
pass pass
elif not hasattr(host_function, "artiq_embedded"): elif not hasattr(host_function, "artiq_embedded") or \
(host_function.artiq_embedded.core_name is None and
host_function.artiq_embedded.portable is False and
host_function.artiq_embedded.syscall is None and
host_function.artiq_embedded.forbidden is False):
self._quote_rpc(function, loc) self._quote_rpc(function, loc)
elif host_function.artiq_embedded.function is not None: elif host_function.artiq_embedded.function is not None:
if host_function.__name__ == "<lambda>": if host_function.__name__ == "<lambda>":

View File

@ -31,6 +31,7 @@ def globals():
# ARTIQ decorators # ARTIQ decorators
"kernel": builtins.fn_kernel(), "kernel": builtins.fn_kernel(),
"portable": builtins.fn_kernel(), "portable": builtins.fn_kernel(),
"rpc": builtins.fn_kernel(),
# ARTIQ context managers # ARTIQ context managers
"parallel": builtins.obj_parallel(), "parallel": builtins.obj_parallel(),

View File

@ -509,7 +509,7 @@ class ASTTypedRewriter(algorithm.Transformer):
visit_DictComp = visit_unsupported visit_DictComp = visit_unsupported
visit_Ellipsis = visit_unsupported visit_Ellipsis = visit_unsupported
visit_GeneratorExp = visit_unsupported visit_GeneratorExp = visit_unsupported
visit_Set = visit_unsupported # visit_Set = visit_unsupported
visit_SetComp = visit_unsupported visit_SetComp = visit_unsupported
visit_Starred = visit_unsupported visit_Starred = visit_unsupported
visit_Yield = visit_unsupported visit_Yield = visit_unsupported

View File

@ -351,6 +351,8 @@ class LLVMIRGenerator:
llty = ll.FunctionType(lli32, [lldouble]) llty = ll.FunctionType(lli32, [lldouble])
elif name == "send_rpc": elif name == "send_rpc":
llty = ll.FunctionType(llvoid, [lli32, llptr, llptrptr]) llty = ll.FunctionType(llvoid, [lli32, llptr, llptrptr])
elif name == "send_async_rpc":
llty = ll.FunctionType(llvoid, [lli32, llptr, llptrptr])
elif name == "recv_rpc": elif name == "recv_rpc":
llty = ll.FunctionType(lli32, [llptr]) llty = ll.FunctionType(lli32, [llptr])
elif name == "now": elif name == "now":
@ -366,7 +368,8 @@ class LLVMIRGenerator:
llglobal = ll.Function(self.llmodule, llty, name) llglobal = ll.Function(self.llmodule, llty, name)
if name in ("__artiq_raise", "__artiq_reraise", "llvm.trap"): if name in ("__artiq_raise", "__artiq_reraise", "llvm.trap"):
llglobal.attributes.add("noreturn") llglobal.attributes.add("noreturn")
if name in ("rtio_log", "send_rpc", "watchdog_set", "watchdog_clear", if name in ("rtio_log", "send_rpc", "send_async_rpc",
"watchdog_set", "watchdog_clear",
self.target.print_function): self.target.print_function):
llglobal.attributes.add("nounwind") llglobal.attributes.add("nounwind")
else: else:
@ -1248,12 +1251,19 @@ class LLVMIRGenerator:
llargptr = self.llbuilder.gep(llargs, [ll.Constant(lli32, index)]) llargptr = self.llbuilder.gep(llargs, [ll.Constant(lli32, index)])
self.llbuilder.store(llargslot, llargptr) self.llbuilder.store(llargslot, llargptr)
if fun_type.async:
self.llbuilder.call(self.llbuiltin("send_async_rpc"),
[llservice, lltag, llargs])
else:
self.llbuilder.call(self.llbuiltin("send_rpc"), self.llbuilder.call(self.llbuiltin("send_rpc"),
[llservice, lltag, llargs]) [llservice, lltag, llargs])
# Don't waste stack space on saved arguments. # Don't waste stack space on saved arguments.
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr]) self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
if fun_type.async:
return ll.Undefined
# T result = { # T result = {
# void *ptr = NULL; # void *ptr = NULL;
# loop: int size = rpc_recv("tag", ptr); # loop: int size = rpc_recv("tag", ptr);

View File

@ -308,20 +308,22 @@ class TRPC(Type):
:ivar ret: (:class:`Type`) :ivar ret: (:class:`Type`)
return type return type
:ivar service: (int) RPC service number :ivar service: (int) RPC service number
:ivar async: (bool) whether the RPC blocks until return
""" """
attributes = OrderedDict() attributes = OrderedDict()
def __init__(self, ret, service): def __init__(self, ret, service, async=False):
assert isinstance(ret, Type) assert isinstance(ret, Type)
self.ret, self.service = ret, service self.ret, self.service, self.async = ret, service, async
def find(self): def find(self):
return self return self
def unify(self, other): def unify(self, other):
if isinstance(other, TRPC) and \ if isinstance(other, TRPC) and \
self.service == other.service: self.service == other.service and \
self.async == other.async:
self.ret.unify(other.ret) self.ret.unify(other.ret)
elif isinstance(other, TVar): elif isinstance(other, TVar):
other.unify(self) other.unify(self)
@ -337,7 +339,8 @@ class TRPC(Type):
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, TRPC) and \ return isinstance(other, TRPC) and \
self.service == other.service self.service == other.service and \
self.async == other.async
def __ne__(self, other): def __ne__(self, other):
return not (self == other) return not (self == other)
@ -727,7 +730,9 @@ class TypePrinter(object):
elif isinstance(typ, TFunction): elif isinstance(typ, TFunction):
return signature return signature
elif isinstance(typ, TRPC): elif isinstance(typ, TRPC):
return "[rpc #{}](...)->{}".format(typ.service, self.name(typ.ret, depth + 1)) return "[rpc{} #{}](...)->{}".format(typ.service,
" async" if typ.async else "",
self.name(typ.ret, depth + 1))
elif isinstance(typ, TBuiltinFunction): elif isinstance(typ, TBuiltinFunction):
return "<function {}>".format(typ.name) return "<function {}>".format(typ.name)
elif isinstance(typ, (TConstructor, TExceptionConstructor)): elif isinstance(typ, (TConstructor, TExceptionConstructor)):

View File

@ -7,7 +7,7 @@ from functools import wraps
import numpy import numpy
__all__ = ["kernel", "portable", "syscall", "host_only", __all__ = ["kernel", "portable", "rpc", "syscall", "host_only",
"set_time_manager", "set_watchdog_factory", "set_time_manager", "set_watchdog_factory",
"TerminationRequested"] "TerminationRequested"]
@ -22,7 +22,7 @@ __all__.extend(kernel_globals)
_ARTIQEmbeddedInfo = namedtuple("_ARTIQEmbeddedInfo", _ARTIQEmbeddedInfo = namedtuple("_ARTIQEmbeddedInfo",
"core_name function syscall forbidden flags") "core_name portable function syscall forbidden flags")
def kernel(arg=None, flags={}): def kernel(arg=None, flags={}):
""" """
@ -53,7 +53,7 @@ def kernel(arg=None, flags={}):
def run_on_core(self, *k_args, **k_kwargs): def run_on_core(self, *k_args, **k_kwargs):
return getattr(self, arg).run(run_on_core, ((self,) + k_args), k_kwargs) return getattr(self, arg).run(run_on_core, ((self,) + k_args), k_kwargs)
run_on_core.artiq_embedded = _ARTIQEmbeddedInfo( run_on_core.artiq_embedded = _ARTIQEmbeddedInfo(
core_name=arg, function=function, syscall=None, core_name=arg, portable=False, function=function, syscall=None,
forbidden=False, flags=set(flags)) forbidden=False, flags=set(flags))
return run_on_core return run_on_core
return inner_decorator return inner_decorator
@ -83,7 +83,23 @@ def portable(arg=None, flags={}):
return inner_decorator return inner_decorator
else: else:
arg.artiq_embedded = \ arg.artiq_embedded = \
_ARTIQEmbeddedInfo(core_name=None, function=arg, syscall=None, _ARTIQEmbeddedInfo(core_name=None, portable=True, function=arg, syscall=None,
forbidden=False, flags=set(flags))
return arg
def rpc(arg=None, flags={}):
"""
This decorator marks a function for execution on the host interpreter.
This is also the default behavior of ARTIQ; however, this decorator allows
specifying additional flags.
"""
if arg is None:
def inner_decorator(function):
return rpc(function, flags)
return inner_decorator
else:
arg.artiq_embedded = \
_ARTIQEmbeddedInfo(core_name=None, portable=False, function=arg, syscall=None,
forbidden=False, flags=set(flags)) forbidden=False, flags=set(flags))
return arg return arg
@ -101,7 +117,7 @@ def syscall(arg=None, flags={}):
if isinstance(arg, str): if isinstance(arg, str):
def inner_decorator(function): def inner_decorator(function):
function.artiq_embedded = \ function.artiq_embedded = \
_ARTIQEmbeddedInfo(core_name=None, function=None, _ARTIQEmbeddedInfo(core_name=None, portable=False, function=None,
syscall=function.__name__, forbidden=False, syscall=function.__name__, forbidden=False,
flags=set(flags)) flags=set(flags))
return function return function
@ -119,7 +135,7 @@ def host_only(function):
in the host Python interpreter. in the host Python interpreter.
""" """
function.artiq_embedded = \ function.artiq_embedded = \
_ARTIQEmbeddedInfo(core_name=None, function=None, syscall=None, _ARTIQEmbeddedInfo(core_name=None, portable=False, function=None, syscall=None,
forbidden=True, flags={}) forbidden=True, flags={})
return function return function

View File

@ -0,0 +1,15 @@
# RUN: env ARTIQ_DUMP_LLVM=%t %python -m artiq.compiler.testbench.embedding +compile %s
# RUN: OutputCheck %s --file-to-check=%t.ll
from artiq.language.core import *
from artiq.language.types import *
# CHECK: call void @send_async_rpc
@rpc(flags={"async"})
def foo():
pass
@kernel
def entrypoint():
foo()

View File

@ -0,0 +1,15 @@
# RUN: %python -m artiq.compiler.testbench.embedding +diag %s 2>%t
# RUN: OutputCheck %s --file-to-check=%t
from artiq.language.core import *
from artiq.language.types import *
# CHECK-L: ${LINE:+2}: fatal: functions that return a value cannot be defined as async RPCs
@rpc(flags={"async"})
def foo() -> TInt32:
pass
@kernel
def entrypoint():
# CHECK-L: ${LINE:+1}: note: function called here
foo()

View File

@ -47,6 +47,19 @@ The Python types correspond to ARTIQ type annotations as follows:
| range | TRange32, TRange64 | | range | TRange32, TRange64 |
+-------------+-------------------------+ +-------------+-------------------------+
Asynchronous RPCs
-----------------
If an RPC returns no value, it can be invoked in a way that does not block until the RPC finishes
execution, but only until it is queued. (Submitting asynchronous RPCs too rapidly, as well as
submitting asynchronous RPCs with arguments that are too large, can still block until completion.)
To define an asynchronous RPC, use the ``@rpc`` annotation with a flag:
@rpc(flags={"async"})
def record_result(x):
self.results.append(x)
Additional optimizations Additional optimizations
------------------------ ------------------------