forked from M-Labs/artiq
compiler: add support for async RPCs.
This commit is contained in:
parent
2ac85cd40f
commit
cd68577dbc
|
@ -994,8 +994,24 @@ class Stitcher:
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
|
is_async = False
|
||||||
|
if hasattr(host_function, "artiq_embedded") and \
|
||||||
|
"async" in host_function.artiq_embedded.flags:
|
||||||
|
is_async = True
|
||||||
|
|
||||||
|
if not builtins.is_none(ret_type) and is_async:
|
||||||
|
note = diagnostic.Diagnostic("note",
|
||||||
|
"function called here", {},
|
||||||
|
loc)
|
||||||
|
diag = diagnostic.Diagnostic("fatal",
|
||||||
|
"functions that return a value cannot be defined as async RPCs", {},
|
||||||
|
self._function_loc(host_function.artiq_embedded.function),
|
||||||
|
notes=[note])
|
||||||
|
self.engine.process(diag)
|
||||||
|
|
||||||
function_type = types.TRPC(ret_type,
|
function_type = types.TRPC(ret_type,
|
||||||
service=self.embedding_map.store_object(host_function))
|
service=self.embedding_map.store_object(host_function),
|
||||||
|
async=is_async)
|
||||||
self.functions[function] = function_type
|
self.functions[function] = function_type
|
||||||
return function_type
|
return function_type
|
||||||
|
|
||||||
|
@ -1007,7 +1023,11 @@ class Stitcher:
|
||||||
|
|
||||||
if function in self.functions:
|
if function in self.functions:
|
||||||
pass
|
pass
|
||||||
elif not hasattr(host_function, "artiq_embedded"):
|
elif not hasattr(host_function, "artiq_embedded") or \
|
||||||
|
(host_function.artiq_embedded.core_name is None and
|
||||||
|
host_function.artiq_embedded.portable is False and
|
||||||
|
host_function.artiq_embedded.syscall is None and
|
||||||
|
host_function.artiq_embedded.forbidden is False):
|
||||||
self._quote_rpc(function, loc)
|
self._quote_rpc(function, loc)
|
||||||
elif host_function.artiq_embedded.function is not None:
|
elif host_function.artiq_embedded.function is not None:
|
||||||
if host_function.__name__ == "<lambda>":
|
if host_function.__name__ == "<lambda>":
|
||||||
|
|
|
@ -31,6 +31,7 @@ def globals():
|
||||||
# ARTIQ decorators
|
# ARTIQ decorators
|
||||||
"kernel": builtins.fn_kernel(),
|
"kernel": builtins.fn_kernel(),
|
||||||
"portable": builtins.fn_kernel(),
|
"portable": builtins.fn_kernel(),
|
||||||
|
"rpc": builtins.fn_kernel(),
|
||||||
|
|
||||||
# ARTIQ context managers
|
# ARTIQ context managers
|
||||||
"parallel": builtins.obj_parallel(),
|
"parallel": builtins.obj_parallel(),
|
||||||
|
|
|
@ -509,7 +509,7 @@ class ASTTypedRewriter(algorithm.Transformer):
|
||||||
visit_DictComp = visit_unsupported
|
visit_DictComp = visit_unsupported
|
||||||
visit_Ellipsis = visit_unsupported
|
visit_Ellipsis = visit_unsupported
|
||||||
visit_GeneratorExp = visit_unsupported
|
visit_GeneratorExp = visit_unsupported
|
||||||
visit_Set = visit_unsupported
|
# visit_Set = visit_unsupported
|
||||||
visit_SetComp = visit_unsupported
|
visit_SetComp = visit_unsupported
|
||||||
visit_Starred = visit_unsupported
|
visit_Starred = visit_unsupported
|
||||||
visit_Yield = visit_unsupported
|
visit_Yield = visit_unsupported
|
||||||
|
|
|
@ -351,6 +351,8 @@ class LLVMIRGenerator:
|
||||||
llty = ll.FunctionType(lli32, [lldouble])
|
llty = ll.FunctionType(lli32, [lldouble])
|
||||||
elif name == "send_rpc":
|
elif name == "send_rpc":
|
||||||
llty = ll.FunctionType(llvoid, [lli32, llptr, llptrptr])
|
llty = ll.FunctionType(llvoid, [lli32, llptr, llptrptr])
|
||||||
|
elif name == "send_async_rpc":
|
||||||
|
llty = ll.FunctionType(llvoid, [lli32, llptr, llptrptr])
|
||||||
elif name == "recv_rpc":
|
elif name == "recv_rpc":
|
||||||
llty = ll.FunctionType(lli32, [llptr])
|
llty = ll.FunctionType(lli32, [llptr])
|
||||||
elif name == "now":
|
elif name == "now":
|
||||||
|
@ -366,7 +368,8 @@ class LLVMIRGenerator:
|
||||||
llglobal = ll.Function(self.llmodule, llty, name)
|
llglobal = ll.Function(self.llmodule, llty, name)
|
||||||
if name in ("__artiq_raise", "__artiq_reraise", "llvm.trap"):
|
if name in ("__artiq_raise", "__artiq_reraise", "llvm.trap"):
|
||||||
llglobal.attributes.add("noreturn")
|
llglobal.attributes.add("noreturn")
|
||||||
if name in ("rtio_log", "send_rpc", "watchdog_set", "watchdog_clear",
|
if name in ("rtio_log", "send_rpc", "send_async_rpc",
|
||||||
|
"watchdog_set", "watchdog_clear",
|
||||||
self.target.print_function):
|
self.target.print_function):
|
||||||
llglobal.attributes.add("nounwind")
|
llglobal.attributes.add("nounwind")
|
||||||
else:
|
else:
|
||||||
|
@ -1248,12 +1251,19 @@ class LLVMIRGenerator:
|
||||||
llargptr = self.llbuilder.gep(llargs, [ll.Constant(lli32, index)])
|
llargptr = self.llbuilder.gep(llargs, [ll.Constant(lli32, index)])
|
||||||
self.llbuilder.store(llargslot, llargptr)
|
self.llbuilder.store(llargslot, llargptr)
|
||||||
|
|
||||||
|
if fun_type.async:
|
||||||
|
self.llbuilder.call(self.llbuiltin("send_async_rpc"),
|
||||||
|
[llservice, lltag, llargs])
|
||||||
|
else:
|
||||||
self.llbuilder.call(self.llbuiltin("send_rpc"),
|
self.llbuilder.call(self.llbuiltin("send_rpc"),
|
||||||
[llservice, lltag, llargs])
|
[llservice, lltag, llargs])
|
||||||
|
|
||||||
# Don't waste stack space on saved arguments.
|
# Don't waste stack space on saved arguments.
|
||||||
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
|
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
|
||||||
|
|
||||||
|
if fun_type.async:
|
||||||
|
return ll.Undefined
|
||||||
|
|
||||||
# T result = {
|
# T result = {
|
||||||
# void *ptr = NULL;
|
# void *ptr = NULL;
|
||||||
# loop: int size = rpc_recv("tag", ptr);
|
# loop: int size = rpc_recv("tag", ptr);
|
||||||
|
|
|
@ -308,20 +308,22 @@ class TRPC(Type):
|
||||||
:ivar ret: (:class:`Type`)
|
:ivar ret: (:class:`Type`)
|
||||||
return type
|
return type
|
||||||
:ivar service: (int) RPC service number
|
:ivar service: (int) RPC service number
|
||||||
|
:ivar async: (bool) whether the RPC blocks until return
|
||||||
"""
|
"""
|
||||||
|
|
||||||
attributes = OrderedDict()
|
attributes = OrderedDict()
|
||||||
|
|
||||||
def __init__(self, ret, service):
|
def __init__(self, ret, service, async=False):
|
||||||
assert isinstance(ret, Type)
|
assert isinstance(ret, Type)
|
||||||
self.ret, self.service = ret, service
|
self.ret, self.service, self.async = ret, service, async
|
||||||
|
|
||||||
def find(self):
|
def find(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def unify(self, other):
|
def unify(self, other):
|
||||||
if isinstance(other, TRPC) and \
|
if isinstance(other, TRPC) and \
|
||||||
self.service == other.service:
|
self.service == other.service and \
|
||||||
|
self.async == other.async:
|
||||||
self.ret.unify(other.ret)
|
self.ret.unify(other.ret)
|
||||||
elif isinstance(other, TVar):
|
elif isinstance(other, TVar):
|
||||||
other.unify(self)
|
other.unify(self)
|
||||||
|
@ -337,7 +339,8 @@ class TRPC(Type):
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return isinstance(other, TRPC) and \
|
return isinstance(other, TRPC) and \
|
||||||
self.service == other.service
|
self.service == other.service and \
|
||||||
|
self.async == other.async
|
||||||
|
|
||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
return not (self == other)
|
return not (self == other)
|
||||||
|
@ -727,7 +730,9 @@ class TypePrinter(object):
|
||||||
elif isinstance(typ, TFunction):
|
elif isinstance(typ, TFunction):
|
||||||
return signature
|
return signature
|
||||||
elif isinstance(typ, TRPC):
|
elif isinstance(typ, TRPC):
|
||||||
return "[rpc #{}](...)->{}".format(typ.service, self.name(typ.ret, depth + 1))
|
return "[rpc{} #{}](...)->{}".format(typ.service,
|
||||||
|
" async" if typ.async else "",
|
||||||
|
self.name(typ.ret, depth + 1))
|
||||||
elif isinstance(typ, TBuiltinFunction):
|
elif isinstance(typ, TBuiltinFunction):
|
||||||
return "<function {}>".format(typ.name)
|
return "<function {}>".format(typ.name)
|
||||||
elif isinstance(typ, (TConstructor, TExceptionConstructor)):
|
elif isinstance(typ, (TConstructor, TExceptionConstructor)):
|
||||||
|
|
|
@ -7,7 +7,7 @@ from functools import wraps
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["kernel", "portable", "syscall", "host_only",
|
__all__ = ["kernel", "portable", "rpc", "syscall", "host_only",
|
||||||
"set_time_manager", "set_watchdog_factory",
|
"set_time_manager", "set_watchdog_factory",
|
||||||
"TerminationRequested"]
|
"TerminationRequested"]
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ __all__.extend(kernel_globals)
|
||||||
|
|
||||||
|
|
||||||
_ARTIQEmbeddedInfo = namedtuple("_ARTIQEmbeddedInfo",
|
_ARTIQEmbeddedInfo = namedtuple("_ARTIQEmbeddedInfo",
|
||||||
"core_name function syscall forbidden flags")
|
"core_name portable function syscall forbidden flags")
|
||||||
|
|
||||||
def kernel(arg=None, flags={}):
|
def kernel(arg=None, flags={}):
|
||||||
"""
|
"""
|
||||||
|
@ -53,7 +53,7 @@ def kernel(arg=None, flags={}):
|
||||||
def run_on_core(self, *k_args, **k_kwargs):
|
def run_on_core(self, *k_args, **k_kwargs):
|
||||||
return getattr(self, arg).run(run_on_core, ((self,) + k_args), k_kwargs)
|
return getattr(self, arg).run(run_on_core, ((self,) + k_args), k_kwargs)
|
||||||
run_on_core.artiq_embedded = _ARTIQEmbeddedInfo(
|
run_on_core.artiq_embedded = _ARTIQEmbeddedInfo(
|
||||||
core_name=arg, function=function, syscall=None,
|
core_name=arg, portable=False, function=function, syscall=None,
|
||||||
forbidden=False, flags=set(flags))
|
forbidden=False, flags=set(flags))
|
||||||
return run_on_core
|
return run_on_core
|
||||||
return inner_decorator
|
return inner_decorator
|
||||||
|
@ -83,7 +83,23 @@ def portable(arg=None, flags={}):
|
||||||
return inner_decorator
|
return inner_decorator
|
||||||
else:
|
else:
|
||||||
arg.artiq_embedded = \
|
arg.artiq_embedded = \
|
||||||
_ARTIQEmbeddedInfo(core_name=None, function=arg, syscall=None,
|
_ARTIQEmbeddedInfo(core_name=None, portable=True, function=arg, syscall=None,
|
||||||
|
forbidden=False, flags=set(flags))
|
||||||
|
return arg
|
||||||
|
|
||||||
|
def rpc(arg=None, flags={}):
|
||||||
|
"""
|
||||||
|
This decorator marks a function for execution on the host interpreter.
|
||||||
|
This is also the default behavior of ARTIQ; however, this decorator allows
|
||||||
|
specifying additional flags.
|
||||||
|
"""
|
||||||
|
if arg is None:
|
||||||
|
def inner_decorator(function):
|
||||||
|
return rpc(function, flags)
|
||||||
|
return inner_decorator
|
||||||
|
else:
|
||||||
|
arg.artiq_embedded = \
|
||||||
|
_ARTIQEmbeddedInfo(core_name=None, portable=False, function=arg, syscall=None,
|
||||||
forbidden=False, flags=set(flags))
|
forbidden=False, flags=set(flags))
|
||||||
return arg
|
return arg
|
||||||
|
|
||||||
|
@ -101,7 +117,7 @@ def syscall(arg=None, flags={}):
|
||||||
if isinstance(arg, str):
|
if isinstance(arg, str):
|
||||||
def inner_decorator(function):
|
def inner_decorator(function):
|
||||||
function.artiq_embedded = \
|
function.artiq_embedded = \
|
||||||
_ARTIQEmbeddedInfo(core_name=None, function=None,
|
_ARTIQEmbeddedInfo(core_name=None, portable=False, function=None,
|
||||||
syscall=function.__name__, forbidden=False,
|
syscall=function.__name__, forbidden=False,
|
||||||
flags=set(flags))
|
flags=set(flags))
|
||||||
return function
|
return function
|
||||||
|
@ -119,7 +135,7 @@ def host_only(function):
|
||||||
in the host Python interpreter.
|
in the host Python interpreter.
|
||||||
"""
|
"""
|
||||||
function.artiq_embedded = \
|
function.artiq_embedded = \
|
||||||
_ARTIQEmbeddedInfo(core_name=None, function=None, syscall=None,
|
_ARTIQEmbeddedInfo(core_name=None, portable=False, function=None, syscall=None,
|
||||||
forbidden=True, flags={})
|
forbidden=True, flags={})
|
||||||
return function
|
return function
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
# RUN: env ARTIQ_DUMP_LLVM=%t %python -m artiq.compiler.testbench.embedding +compile %s
|
||||||
|
# RUN: OutputCheck %s --file-to-check=%t.ll
|
||||||
|
|
||||||
|
from artiq.language.core import *
|
||||||
|
from artiq.language.types import *
|
||||||
|
|
||||||
|
# CHECK: call void @send_async_rpc
|
||||||
|
|
||||||
|
@rpc(flags={"async"})
|
||||||
|
def foo():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def entrypoint():
|
||||||
|
foo()
|
|
@ -0,0 +1,15 @@
|
||||||
|
# RUN: %python -m artiq.compiler.testbench.embedding +diag %s 2>%t
|
||||||
|
# RUN: OutputCheck %s --file-to-check=%t
|
||||||
|
|
||||||
|
from artiq.language.core import *
|
||||||
|
from artiq.language.types import *
|
||||||
|
|
||||||
|
# CHECK-L: ${LINE:+2}: fatal: functions that return a value cannot be defined as async RPCs
|
||||||
|
@rpc(flags={"async"})
|
||||||
|
def foo() -> TInt32:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def entrypoint():
|
||||||
|
# CHECK-L: ${LINE:+1}: note: function called here
|
||||||
|
foo()
|
|
@ -47,6 +47,19 @@ The Python types correspond to ARTIQ type annotations as follows:
|
||||||
| range | TRange32, TRange64 |
|
| range | TRange32, TRange64 |
|
||||||
+-------------+-------------------------+
|
+-------------+-------------------------+
|
||||||
|
|
||||||
|
Asynchronous RPCs
|
||||||
|
-----------------
|
||||||
|
|
||||||
|
If an RPC returns no value, it can be invoked in a way that does not block until the RPC finishes
|
||||||
|
execution, but only until it is queued. (Submitting asynchronous RPCs too rapidly, as well as
|
||||||
|
submitting asynchronous RPCs with arguments that are too large, can still block until completion.)
|
||||||
|
|
||||||
|
To define an asynchronous RPC, use the ``@rpc`` annotation with a flag:
|
||||||
|
|
||||||
|
@rpc(flags={"async"})
|
||||||
|
def record_result(x):
|
||||||
|
self.results.append(x)
|
||||||
|
|
||||||
Additional optimizations
|
Additional optimizations
|
||||||
------------------------
|
------------------------
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue