mirror of
https://github.com/m-labs/artiq.git
synced 2024-12-29 13:13:34 +08:00
compiler: add support for async RPCs.
This commit is contained in:
parent
2ac85cd40f
commit
cd68577dbc
@ -994,8 +994,24 @@ class Stitcher:
|
||||
else:
|
||||
assert False
|
||||
|
||||
is_async = False
|
||||
if hasattr(host_function, "artiq_embedded") and \
|
||||
"async" in host_function.artiq_embedded.flags:
|
||||
is_async = True
|
||||
|
||||
if not builtins.is_none(ret_type) and is_async:
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"function called here", {},
|
||||
loc)
|
||||
diag = diagnostic.Diagnostic("fatal",
|
||||
"functions that return a value cannot be defined as async RPCs", {},
|
||||
self._function_loc(host_function.artiq_embedded.function),
|
||||
notes=[note])
|
||||
self.engine.process(diag)
|
||||
|
||||
function_type = types.TRPC(ret_type,
|
||||
service=self.embedding_map.store_object(host_function))
|
||||
service=self.embedding_map.store_object(host_function),
|
||||
async=is_async)
|
||||
self.functions[function] = function_type
|
||||
return function_type
|
||||
|
||||
@ -1007,7 +1023,11 @@ class Stitcher:
|
||||
|
||||
if function in self.functions:
|
||||
pass
|
||||
elif not hasattr(host_function, "artiq_embedded"):
|
||||
elif not hasattr(host_function, "artiq_embedded") or \
|
||||
(host_function.artiq_embedded.core_name is None and
|
||||
host_function.artiq_embedded.portable is False and
|
||||
host_function.artiq_embedded.syscall is None and
|
||||
host_function.artiq_embedded.forbidden is False):
|
||||
self._quote_rpc(function, loc)
|
||||
elif host_function.artiq_embedded.function is not None:
|
||||
if host_function.__name__ == "<lambda>":
|
||||
|
@ -31,6 +31,7 @@ def globals():
|
||||
# ARTIQ decorators
|
||||
"kernel": builtins.fn_kernel(),
|
||||
"portable": builtins.fn_kernel(),
|
||||
"rpc": builtins.fn_kernel(),
|
||||
|
||||
# ARTIQ context managers
|
||||
"parallel": builtins.obj_parallel(),
|
||||
|
@ -509,7 +509,7 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||
visit_DictComp = visit_unsupported
|
||||
visit_Ellipsis = visit_unsupported
|
||||
visit_GeneratorExp = visit_unsupported
|
||||
visit_Set = visit_unsupported
|
||||
# visit_Set = visit_unsupported
|
||||
visit_SetComp = visit_unsupported
|
||||
visit_Starred = visit_unsupported
|
||||
visit_Yield = visit_unsupported
|
||||
|
@ -351,6 +351,8 @@ class LLVMIRGenerator:
|
||||
llty = ll.FunctionType(lli32, [lldouble])
|
||||
elif name == "send_rpc":
|
||||
llty = ll.FunctionType(llvoid, [lli32, llptr, llptrptr])
|
||||
elif name == "send_async_rpc":
|
||||
llty = ll.FunctionType(llvoid, [lli32, llptr, llptrptr])
|
||||
elif name == "recv_rpc":
|
||||
llty = ll.FunctionType(lli32, [llptr])
|
||||
elif name == "now":
|
||||
@ -366,7 +368,8 @@ class LLVMIRGenerator:
|
||||
llglobal = ll.Function(self.llmodule, llty, name)
|
||||
if name in ("__artiq_raise", "__artiq_reraise", "llvm.trap"):
|
||||
llglobal.attributes.add("noreturn")
|
||||
if name in ("rtio_log", "send_rpc", "watchdog_set", "watchdog_clear",
|
||||
if name in ("rtio_log", "send_rpc", "send_async_rpc",
|
||||
"watchdog_set", "watchdog_clear",
|
||||
self.target.print_function):
|
||||
llglobal.attributes.add("nounwind")
|
||||
else:
|
||||
@ -1248,12 +1251,19 @@ class LLVMIRGenerator:
|
||||
llargptr = self.llbuilder.gep(llargs, [ll.Constant(lli32, index)])
|
||||
self.llbuilder.store(llargslot, llargptr)
|
||||
|
||||
self.llbuilder.call(self.llbuiltin("send_rpc"),
|
||||
[llservice, lltag, llargs])
|
||||
if fun_type.async:
|
||||
self.llbuilder.call(self.llbuiltin("send_async_rpc"),
|
||||
[llservice, lltag, llargs])
|
||||
else:
|
||||
self.llbuilder.call(self.llbuiltin("send_rpc"),
|
||||
[llservice, lltag, llargs])
|
||||
|
||||
# Don't waste stack space on saved arguments.
|
||||
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
|
||||
|
||||
if fun_type.async:
|
||||
return ll.Undefined
|
||||
|
||||
# T result = {
|
||||
# void *ptr = NULL;
|
||||
# loop: int size = rpc_recv("tag", ptr);
|
||||
|
@ -308,20 +308,22 @@ class TRPC(Type):
|
||||
:ivar ret: (:class:`Type`)
|
||||
return type
|
||||
:ivar service: (int) RPC service number
|
||||
:ivar async: (bool) whether the RPC blocks until return
|
||||
"""
|
||||
|
||||
attributes = OrderedDict()
|
||||
|
||||
def __init__(self, ret, service):
|
||||
def __init__(self, ret, service, async=False):
|
||||
assert isinstance(ret, Type)
|
||||
self.ret, self.service = ret, service
|
||||
self.ret, self.service, self.async = ret, service, async
|
||||
|
||||
def find(self):
|
||||
return self
|
||||
|
||||
def unify(self, other):
|
||||
if isinstance(other, TRPC) and \
|
||||
self.service == other.service:
|
||||
self.service == other.service and \
|
||||
self.async == other.async:
|
||||
self.ret.unify(other.ret)
|
||||
elif isinstance(other, TVar):
|
||||
other.unify(self)
|
||||
@ -337,7 +339,8 @@ class TRPC(Type):
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TRPC) and \
|
||||
self.service == other.service
|
||||
self.service == other.service and \
|
||||
self.async == other.async
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
@ -727,7 +730,9 @@ class TypePrinter(object):
|
||||
elif isinstance(typ, TFunction):
|
||||
return signature
|
||||
elif isinstance(typ, TRPC):
|
||||
return "[rpc #{}](...)->{}".format(typ.service, self.name(typ.ret, depth + 1))
|
||||
return "[rpc{} #{}](...)->{}".format(typ.service,
|
||||
" async" if typ.async else "",
|
||||
self.name(typ.ret, depth + 1))
|
||||
elif isinstance(typ, TBuiltinFunction):
|
||||
return "<function {}>".format(typ.name)
|
||||
elif isinstance(typ, (TConstructor, TExceptionConstructor)):
|
||||
|
@ -7,7 +7,7 @@ from functools import wraps
|
||||
import numpy
|
||||
|
||||
|
||||
__all__ = ["kernel", "portable", "syscall", "host_only",
|
||||
__all__ = ["kernel", "portable", "rpc", "syscall", "host_only",
|
||||
"set_time_manager", "set_watchdog_factory",
|
||||
"TerminationRequested"]
|
||||
|
||||
@ -22,7 +22,7 @@ __all__.extend(kernel_globals)
|
||||
|
||||
|
||||
_ARTIQEmbeddedInfo = namedtuple("_ARTIQEmbeddedInfo",
|
||||
"core_name function syscall forbidden flags")
|
||||
"core_name portable function syscall forbidden flags")
|
||||
|
||||
def kernel(arg=None, flags={}):
|
||||
"""
|
||||
@ -53,7 +53,7 @@ def kernel(arg=None, flags={}):
|
||||
def run_on_core(self, *k_args, **k_kwargs):
|
||||
return getattr(self, arg).run(run_on_core, ((self,) + k_args), k_kwargs)
|
||||
run_on_core.artiq_embedded = _ARTIQEmbeddedInfo(
|
||||
core_name=arg, function=function, syscall=None,
|
||||
core_name=arg, portable=False, function=function, syscall=None,
|
||||
forbidden=False, flags=set(flags))
|
||||
return run_on_core
|
||||
return inner_decorator
|
||||
@ -83,7 +83,23 @@ def portable(arg=None, flags={}):
|
||||
return inner_decorator
|
||||
else:
|
||||
arg.artiq_embedded = \
|
||||
_ARTIQEmbeddedInfo(core_name=None, function=arg, syscall=None,
|
||||
_ARTIQEmbeddedInfo(core_name=None, portable=True, function=arg, syscall=None,
|
||||
forbidden=False, flags=set(flags))
|
||||
return arg
|
||||
|
||||
def rpc(arg=None, flags={}):
|
||||
"""
|
||||
This decorator marks a function for execution on the host interpreter.
|
||||
This is also the default behavior of ARTIQ; however, this decorator allows
|
||||
specifying additional flags.
|
||||
"""
|
||||
if arg is None:
|
||||
def inner_decorator(function):
|
||||
return rpc(function, flags)
|
||||
return inner_decorator
|
||||
else:
|
||||
arg.artiq_embedded = \
|
||||
_ARTIQEmbeddedInfo(core_name=None, portable=False, function=arg, syscall=None,
|
||||
forbidden=False, flags=set(flags))
|
||||
return arg
|
||||
|
||||
@ -101,7 +117,7 @@ def syscall(arg=None, flags={}):
|
||||
if isinstance(arg, str):
|
||||
def inner_decorator(function):
|
||||
function.artiq_embedded = \
|
||||
_ARTIQEmbeddedInfo(core_name=None, function=None,
|
||||
_ARTIQEmbeddedInfo(core_name=None, portable=False, function=None,
|
||||
syscall=function.__name__, forbidden=False,
|
||||
flags=set(flags))
|
||||
return function
|
||||
@ -119,7 +135,7 @@ def host_only(function):
|
||||
in the host Python interpreter.
|
||||
"""
|
||||
function.artiq_embedded = \
|
||||
_ARTIQEmbeddedInfo(core_name=None, function=None, syscall=None,
|
||||
_ARTIQEmbeddedInfo(core_name=None, portable=False, function=None, syscall=None,
|
||||
forbidden=True, flags={})
|
||||
return function
|
||||
|
||||
|
15
artiq/test/lit/embedding/async_rpc.py
Normal file
15
artiq/test/lit/embedding/async_rpc.py
Normal file
@ -0,0 +1,15 @@
|
||||
# RUN: env ARTIQ_DUMP_LLVM=%t %python -m artiq.compiler.testbench.embedding +compile %s
|
||||
# RUN: OutputCheck %s --file-to-check=%t.ll
|
||||
|
||||
from artiq.language.core import *
|
||||
from artiq.language.types import *
|
||||
|
||||
# CHECK: call void @send_async_rpc
|
||||
|
||||
@rpc(flags={"async"})
|
||||
def foo():
|
||||
pass
|
||||
|
||||
@kernel
|
||||
def entrypoint():
|
||||
foo()
|
15
artiq/test/lit/embedding/error_rpc_async_return.py
Normal file
15
artiq/test/lit/embedding/error_rpc_async_return.py
Normal file
@ -0,0 +1,15 @@
|
||||
# RUN: %python -m artiq.compiler.testbench.embedding +diag %s 2>%t
|
||||
# RUN: OutputCheck %s --file-to-check=%t
|
||||
|
||||
from artiq.language.core import *
|
||||
from artiq.language.types import *
|
||||
|
||||
# CHECK-L: ${LINE:+2}: fatal: functions that return a value cannot be defined as async RPCs
|
||||
@rpc(flags={"async"})
|
||||
def foo() -> TInt32:
|
||||
pass
|
||||
|
||||
@kernel
|
||||
def entrypoint():
|
||||
# CHECK-L: ${LINE:+1}: note: function called here
|
||||
foo()
|
@ -47,6 +47,19 @@ The Python types correspond to ARTIQ type annotations as follows:
|
||||
| range | TRange32, TRange64 |
|
||||
+-------------+-------------------------+
|
||||
|
||||
Asynchronous RPCs
|
||||
-----------------
|
||||
|
||||
If an RPC returns no value, it can be invoked in a way that does not block until the RPC finishes
|
||||
execution, but only until it is queued. (Submitting asynchronous RPCs too rapidly, as well as
|
||||
submitting asynchronous RPCs with arguments that are too large, can still block until completion.)
|
||||
|
||||
To define an asynchronous RPC, use the ``@rpc`` annotation with a flag:
|
||||
|
||||
@rpc(flags={"async"})
|
||||
def record_result(x):
|
||||
self.results.append(x)
|
||||
|
||||
Additional optimizations
|
||||
------------------------
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user