From 0a750c77e8f432c6eac7b0ee5d247276eeabd393 Mon Sep 17 00:00:00 2001 From: mwojcik Date: Thu, 5 Oct 2023 14:35:50 +0800 Subject: [PATCH] compiler: support subkernels --- artiq/compiler/asttyped.py | 10 +- artiq/compiler/builtins.py | 15 +- artiq/compiler/embedding.py | 109 ++++++- artiq/compiler/ir.py | 66 +++- artiq/compiler/module.py | 2 + artiq/compiler/prelude.py | 5 + artiq/compiler/targets.py | 10 +- .../compiler/transforms/artiq_ir_generator.py | 87 ++++- .../compiler/transforms/asttyped_rewriter.py | 7 +- artiq/compiler/transforms/inferencer.py | 66 +++- .../compiler/transforms/iodelay_estimator.py | 25 +- .../compiler/transforms/llvm_ir_generator.py | 296 ++++++++++++++---- artiq/compiler/types.py | 51 +++ artiq/frontend/artiq_compile.py | 4 +- artiq/language/core.py | 58 +++- 15 files changed, 699 insertions(+), 112 deletions(-) diff --git a/artiq/compiler/asttyped.py b/artiq/compiler/asttyped.py index 10b197fa4..b6fb34274 100644 --- a/artiq/compiler/asttyped.py +++ b/artiq/compiler/asttyped.py @@ -21,13 +21,19 @@ class scoped(object): set of variables resolved as globals """ +class remote(object): + """ + :ivar remote_fn: (bool) whether function is run on a remote device, + meaning arguments are received remotely and return is sent remotely + """ + # Typed versions of untyped nodes class argT(ast.arg, commontyped): pass class ClassDefT(ast.ClassDef): _types = ("constructor_type",) -class FunctionDefT(ast.FunctionDef, scoped): +class FunctionDefT(ast.FunctionDef, scoped, remote): _types = ("signature_type",) class QuotedFunctionDefT(FunctionDefT): """ @@ -58,7 +64,7 @@ class BinOpT(ast.BinOp, commontyped): pass class BoolOpT(ast.BoolOp, commontyped): pass -class CallT(ast.Call, commontyped): +class CallT(ast.Call, commontyped, remote): """ :ivar iodelay: (:class:`iodelay.Expr`) :ivar arg_exprs: (dict of str to :class:`iodelay.Expr`) diff --git a/artiq/compiler/builtins.py b/artiq/compiler/builtins.py index fdd5286e1..64e9b3690 100644 --- 
a/artiq/compiler/builtins.py +++ b/artiq/compiler/builtins.py @@ -38,6 +38,9 @@ class TInt(types.TMono): def one(): return 1 +def TInt8(): + return TInt(types.TValue(8)) + def TInt32(): return TInt(types.TValue(32)) @@ -244,6 +247,12 @@ def fn_at_mu(): def fn_rtio_log(): return types.TBuiltinFunction("rtio_log") +def fn_subkernel_await(): + return types.TBuiltinFunction("subkernel_await") + +def fn_subkernel_preload(): + return types.TBuiltinFunction("subkernel_preload") + # Accessors def is_none(typ): @@ -326,7 +335,7 @@ def get_iterable_elt(typ): # n-dimensional arrays, rather than the n-1 dimensional result of iterating over # the first axis, which makes the name a bit misleading. if is_str(typ) or is_bytes(typ) or is_bytearray(typ): - return TInt(types.TValue(8)) + return TInt8() elif types._is_pointer(typ) or is_iterable(typ): return typ.find()["elt"].find() else: @@ -342,5 +351,5 @@ def is_allocated(typ): is_float(typ) or is_range(typ) or types._is_pointer(typ) or types.is_function(typ) or types.is_external_function(typ) or types.is_rpc(typ) or - types.is_method(typ) or types.is_tuple(typ) or - types.is_value(typ)) + types.is_subkernel(typ) or types.is_method(typ) or + types.is_tuple(typ) or types.is_value(typ)) diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 040fc80ee..9c2f270d8 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -74,7 +74,9 @@ class EmbeddingMap: "CacheError", "SPIError", "0:ZeroDivisionError", - "0:IndexError"]) + "0:IndexError", + "UnwrapNoneError", + "SubkernelError"]) def preallocate_runtime_exception_names(self, names): for i, name in enumerate(names): @@ -183,7 +185,15 @@ class EmbeddingMap: obj_typ, _ = self.type_map[type(obj_ref)] yield obj_id, obj_ref, obj_typ - def has_rpc(self): + def subkernels(self): + subkernels = {} + for k, v in self.object_forward_map.items(): + if hasattr(v, "artiq_embedded"): + if v.artiq_embedded.destination is not None: + subkernels[k] = v + return 
subkernels + + def has_rpc_or_subkernel(self): return any(filter(lambda x: inspect.isfunction(x) or inspect.ismethod(x), self.object_forward_map.values())) @@ -469,7 +479,7 @@ class ASTSynthesizer: return asttyped.QuoteT(value=value, type=instance_type, loc=loc) - def call(self, callee, args, kwargs, callback=None): + def call(self, callee, args, kwargs, callback=None, remote_fn=False): """ Construct an AST fragment calling a function specified by an AST node `function_node`, with given arguments. @@ -513,7 +523,7 @@ class ASTSynthesizer: starargs=None, kwargs=None, type=types.TVar(), iodelay=None, arg_exprs={}, begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None, - loc=callee_node.loc.join(end_loc)) + loc=callee_node.loc.join(end_loc), remote_fn=remote_fn) if callback is not None: node = asttyped.CallT( @@ -548,7 +558,7 @@ class StitchingASTTypedRewriter(ASTTypedRewriter): arg=node.arg, annotation=None, arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc) - def visit_quoted_function(self, node, function): + def visit_quoted_function(self, node, function, remote_fn): extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) extractor.visit(node) @@ -569,7 +579,7 @@ class StitchingASTTypedRewriter(ASTTypedRewriter): body=node.body, decorator_list=node.decorator_list, keyword_loc=node.keyword_loc, name_loc=node.name_loc, arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs, - loc=node.loc) + loc=node.loc, remote_fn=remote_fn) try: self.env_stack.append(node.typing_env) @@ -777,7 +787,7 @@ class TypedtreeHasher(algorithm.Visitor): return hash(tuple(freeze(getattr(node, field_name)) for field_name in fields)) class Stitcher: - def __init__(self, core, dmgr, engine=None, print_as_rpc=True): + def __init__(self, core, dmgr, engine=None, print_as_rpc=True, destination=0, subkernel_arg_types=[]): self.core = core self.dmgr = dmgr if engine is None: @@ -803,11 +813,19 @@ class Stitcher: self.value_map = 
defaultdict(lambda: []) self.definitely_changed = False + self.destination = destination + self.first_call = True + # for non-annotated subkernels: + # main kernel inferencer output with types of arguments + self.subkernel_arg_types = subkernel_arg_types + def stitch_call(self, function, args, kwargs, callback=None): # We synthesize source code for the initial call so that # diagnostics would have something meaningful to display to the user. synthesizer = self._synthesizer(self._function_loc(function.artiq_embedded.function)) - call_node = synthesizer.call(function, args, kwargs, callback) + # first call of a subkernel will get its arguments from remote (DRTIO) + remote_fn = self.destination != 0 + call_node = synthesizer.call(function, args, kwargs, callback, remote_fn=remote_fn) synthesizer.finalize() self.typedtree.append(call_node) @@ -919,6 +937,10 @@ class Stitcher: return [diagnostic.Diagnostic("note", "in kernel function here", {}, call_loc)] + elif fn_kind == 'subkernel': + return [diagnostic.Diagnostic("note", + "in subkernel call here", {}, + call_loc)] else: assert False else: @@ -938,7 +960,7 @@ class Stitcher: self._function_loc(function), notes=self._call_site_note(loc, fn_kind)) self.engine.process(diag) - elif fn_kind == 'rpc' and param.default is not inspect.Parameter.empty: + elif (fn_kind == 'rpc' or fn_kind == 'subkernel') and param.default is not inspect.Parameter.empty: notes = [] notes.append(diagnostic.Diagnostic("note", "expanded from here while trying to infer a type for an" @@ -957,11 +979,18 @@ class Stitcher: Inferencer(engine=self.engine).visit(ast) IntMonomorphizer(engine=self.engine).visit(ast) return ast.type - else: - # Let the rest of the program decide. 
- return types.TVar() + elif fn_kind == 'kernel' and self.first_call and self.destination != 0: + # subkernels do not have access to the main kernel code to infer + # arg types - so these are cached and passed onto subkernel + # compilation, to avoid having to annotate them fully + for name, typ in self.subkernel_arg_types: + if param.name == name: + return typ - def _quote_embedded_function(self, function, flags): + # Let the rest of the program decide. + return types.TVar() + + def _quote_embedded_function(self, function, flags, remote_fn=False): # we are now parsing new functions... definitely changed the type self.definitely_changed = True @@ -1060,7 +1089,7 @@ class Stitcher: engine=self.engine, prelude=self.prelude, globals=self.globals, host_environment=host_environment, quote=self._quote) - function_node = asttyped_rewriter.visit_quoted_function(function_node, embedded_function) + function_node = asttyped_rewriter.visit_quoted_function(function_node, embedded_function, remote_fn) function_node.flags = flags # Add it into our typedtree so that it gets inferenced and codegen'd. 
@@ -1174,7 +1203,6 @@ class Stitcher: signature = inspect.signature(function) arg_types = OrderedDict() - optarg_types = OrderedDict() for param in signature.parameters.values(): if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD: diag = diagnostic.Diagnostic("error", @@ -1212,6 +1240,40 @@ class Stitcher: self.functions[function] = function_type return function_type + def _quote_subkernel(self, function, loc): + if isinstance(function, SpecializedFunction): + host_function = function.host_function + else: + host_function = function + ret_type = builtins.TNone() + signature = inspect.signature(host_function) + + if signature.return_annotation is not inspect.Signature.empty: + ret_type = self._extract_annot(host_function, signature.return_annotation, + "return type", loc, fn_kind='subkernel') + arg_types = OrderedDict() + optarg_types = OrderedDict() + for param in signature.parameters.values(): + if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD: + diag = diagnostic.Diagnostic("error", + "subkernels must only use positional arguments; '{argument}' isn't", + {"argument": param.name}, + self._function_loc(function), + notes=self._call_site_note(loc, fn_kind='subkernel')) + self.engine.process(diag) + + arg_type = self._type_of_param(function, loc, param, fn_kind='subkernel') + if param.default is inspect.Parameter.empty: + arg_types[param.name] = arg_type + else: + optarg_types[param.name] = arg_type + + function_type = types.TSubkernel(arg_types, optarg_types, ret_type, + sid=self.embedding_map.store_object(host_function), + destination=host_function.artiq_embedded.destination) + self.functions[function] = function_type + return function_type + def _quote_rpc(self, function, loc): if isinstance(function, SpecializedFunction): host_function = function.host_function @@ -1271,8 +1333,18 @@ class Stitcher: (host_function.artiq_embedded.core_name is None and host_function.artiq_embedded.portable is False and host_function.artiq_embedded.syscall is None and + 
host_function.artiq_embedded.destination is None and host_function.artiq_embedded.forbidden is False): self._quote_rpc(function, loc) + elif host_function.artiq_embedded.destination is not None and \ + host_function.artiq_embedded.destination != self.destination: + # treat subkernels as kernels if running on the same device + if not 0 < host_function.artiq_embedded.destination <= 255: + diag = diagnostic.Diagnostic("error", + "subkernel destination must be between 1 and 255 (inclusive)", {}, + self._function_loc(host_function)) + self.engine.process(diag) + self._quote_subkernel(function, loc) elif host_function.artiq_embedded.function is not None: if host_function.__name__ == "": note = diagnostic.Diagnostic("note", @@ -1296,8 +1368,13 @@ class Stitcher: notes=[note]) self.engine.process(diag) + destination = host_function.artiq_embedded.destination + # remote_fn only for first call in subkernels + remote_fn = destination is not None and self.first_call self._quote_embedded_function(function, - flags=host_function.artiq_embedded.flags) + flags=host_function.artiq_embedded.flags, + remote_fn=remote_fn) + self.first_call = False elif host_function.artiq_embedded.syscall is not None: # Insert a storage-less global whose type instructs the compiler # to perform a system call instead of a regular call. diff --git a/artiq/compiler/ir.py b/artiq/compiler/ir.py index 88ef3a151..3af11ccd0 100644 --- a/artiq/compiler/ir.py +++ b/artiq/compiler/ir.py @@ -706,6 +706,64 @@ class SetLocal(Instruction): def value(self): return self.operands[1] +class GetArgFromRemote(Instruction): + """ + An instruction that receives function arguments from remote + (ie. 
subkernel in DRTIO context) + + :ivar arg_name: (string) argument name + :ivar arg_type: argument type + """ + + """ + :param arg_name: (string) argument name + :param arg_type: argument type + """ + def __init__(self, arg_name, arg_type, name=""): + assert isinstance(arg_name, str) + super().__init__([], arg_type, name) + self.arg_name = arg_name + self.arg_type = arg_type + + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.arg_name = self.arg_name + self_copy.arg_type = self.arg_type + return self_copy + + def opcode(self): + return "getargfromremote({})".format(repr(self.arg_name)) + +class GetOptArgFromRemote(GetArgFromRemote): + """ + An instruction that may or may not retrieve an optional function argument + from remote, depending on number of values received by firmware. + + :ivar rcv_count: number of received values, + determined by firmware + :ivar index: (integer) index of the current argument, + in reference to remote arguments + """ + + """ + :param rcv_count: number of received values + :param index: (integer) index of the current argument, + in reference to remote arguments + """ + def __init__(self, arg_name, arg_type, rcv_count, index, name=""): + super().__init__(arg_name, arg_type, name) + self.rcv_count = rcv_count + self.index = index + + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.rcv_count = self.rcv_count + self_copy.index = self.index + return self_copy + + def opcode(self): + return "getoptargfromremote({})".format(repr(self.arg_name)) + class GetAttr(Instruction): """ An intruction that loads an attribute from an object, @@ -728,7 +786,7 @@ class GetAttr(Instruction): typ = obj.type.attributes[attr] else: typ = obj.type.constructor.attributes[attr] - if types.is_function(typ) or types.is_rpc(typ): + if types.is_function(typ) or types.is_rpc(typ) or types.is_subkernel(typ): typ = types.TMethod(obj.type, typ) super().__init__([obj], typ, name) self.attr = attr @@ -1190,14 +1248,18 @@ class 
IndirectBranch(Terminator): class Return(Terminator): """ A return instruction. + :param remote_return: (bool) + marks a return in subkernel context, + where the return value is sent back through DRTIO """ """ :param value: (:class:`Value`) return value """ - def __init__(self, value, name=""): + def __init__(self, value, remote_return=False, name=""): assert isinstance(value, Value) super().__init__([value], builtins.TNone(), name) + self.remote_return = remote_return def opcode(self): return "return" diff --git a/artiq/compiler/module.py b/artiq/compiler/module.py index f3bc35cde..cfac4e26e 100644 --- a/artiq/compiler/module.py +++ b/artiq/compiler/module.py @@ -84,6 +84,8 @@ class Module: constant_hoister.process(self.artiq_ir) if remarks: invariant_detection.process(self.artiq_ir) + # for subkernels: main kernel inferencer output, to be passed to further compilations + self.subkernel_arg_types = inferencer.subkernel_arg_types def build_llvm_ir(self, target): """Compile the module to LLVM IR for the specified target.""" diff --git a/artiq/compiler/prelude.py b/artiq/compiler/prelude.py index 13f319650..effbca87c 100644 --- a/artiq/compiler/prelude.py +++ b/artiq/compiler/prelude.py @@ -37,6 +37,7 @@ def globals(): # ARTIQ decorators "kernel": builtins.fn_kernel(), + "subkernel": builtins.fn_kernel(), "portable": builtins.fn_kernel(), "rpc": builtins.fn_kernel(), @@ -54,4 +55,8 @@ def globals(): # ARTIQ utility functions "rtio_log": builtins.fn_rtio_log(), "core_log": builtins.fn_print(), + + # ARTIQ subkernel utility functions + "subkernel_await": builtins.fn_subkernel_await(), + "subkernel_preload": builtins.fn_subkernel_preload(), } diff --git a/artiq/compiler/targets.py b/artiq/compiler/targets.py index 0dd835a0a..5f043eb0e 100644 --- a/artiq/compiler/targets.py +++ b/artiq/compiler/targets.py @@ -94,8 +94,9 @@ class Target: tool_symbolizer = "llvm-symbolizer" tool_cxxfilt = "llvm-cxxfilt" - def __init__(self): + def __init__(self, subkernel_id=None): 
self.llcontext = ll.Context() + self.subkernel_id = subkernel_id def target_machine(self): lltarget = llvm.Target.from_triple(self.triple) @@ -148,7 +149,8 @@ class Target: ir.BasicBlock._dump_loc = False type_printer = types.TypePrinter() - _dump(os.getenv("ARTIQ_DUMP_IR"), "ARTIQ IR", ".txt", + suffix = "_subkernel_{}".format(self.subkernel_id) if self.subkernel_id is not None else "" + _dump(os.getenv("ARTIQ_DUMP_IR"), "ARTIQ IR", suffix + ".txt", lambda: "\n".join(fn.as_entity(type_printer) for fn in module.artiq_ir)) llmod = module.build_llvm_ir(self) @@ -160,12 +162,12 @@ class Target: _dump("", "LLVM IR (broken)", ".ll", lambda: str(llmod)) raise - _dump(os.getenv("ARTIQ_DUMP_UNOPT_LLVM"), "LLVM IR (generated)", "_unopt.ll", + _dump(os.getenv("ARTIQ_DUMP_UNOPT_LLVM"), "LLVM IR (generated)", suffix + "_unopt.ll", lambda: str(llparsedmod)) self.optimize(llparsedmod) - _dump(os.getenv("ARTIQ_DUMP_LLVM"), "LLVM IR (optimized)", ".ll", + _dump(os.getenv("ARTIQ_DUMP_LLVM"), "LLVM IR (optimized)", suffix + ".ll", lambda: str(llparsedmod)) return llparsedmod diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index 7ede45531..489739ba7 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -108,6 +108,7 @@ class ARTIQIRGenerator(algorithm.Visitor): self.current_args = None self.current_assign = None self.current_exception = None + self.current_remote_fn = False self.break_target = None self.continue_target = None self.return_target = None @@ -211,7 +212,8 @@ class ARTIQIRGenerator(algorithm.Visitor): old_priv_env, self.current_private_env = self.current_private_env, priv_env self.generic_visit(node) - self.terminate(ir.Return(ir.Constant(None, builtins.TNone()))) + self.terminate(ir.Return(ir.Constant(None, builtins.TNone()), + remote_return=self.current_remote_fn)) return self.functions finally: @@ -294,6 +296,8 @@ class 
ARTIQIRGenerator(algorithm.Visitor): old_block, self.current_block = self.current_block, entry old_globals, self.current_globals = self.current_globals, node.globals_in_scope + old_remote_fn = self.current_remote_fn + self.current_remote_fn = getattr(node, "remote_fn", False) env_without_globals = \ {var: node.typing_env[var] @@ -326,7 +330,8 @@ class ARTIQIRGenerator(algorithm.Visitor): self.terminate(ir.Return(result)) elif builtins.is_none(typ.ret): if not self.current_block.is_terminated(): - self.current_block.append(ir.Return(ir.Constant(None, builtins.TNone()))) + self.current_block.append(ir.Return(ir.Constant(None, builtins.TNone()), + remote_return=self.current_remote_fn)) else: if not self.current_block.is_terminated(): if len(self.current_block.predecessors()) != 0: @@ -345,6 +350,7 @@ class ARTIQIRGenerator(algorithm.Visitor): self.current_block = old_block self.current_globals = old_globals self.current_env = old_env + self.current_remote_fn = old_remote_fn if not is_lambda: self.current_private_env = old_priv_env @@ -367,7 +373,8 @@ class ARTIQIRGenerator(algorithm.Visitor): return_value = self.visit(node.value) if self.return_target is None: - self.append(ir.Return(return_value)) + self.append(ir.Return(return_value, + remote_return=self.current_remote_fn)) else: self.append(ir.SetLocal(self.current_private_env, "$return", return_value)) self.append(ir.Branch(self.return_target)) @@ -2524,6 +2531,33 @@ class ARTIQIRGenerator(algorithm.Visitor): or types.is_builtin(typ, "at_mu"): return self.append(ir.Builtin(typ.name, [self.visit(arg) for arg in node.args], node.type)) + elif types.is_builtin(typ, "subkernel_await"): + if len(node.args) == 2 and len(node.keywords) == 0: + fn = node.args[0].type + timeout = self.visit(node.args[1]) + elif len(node.args) == 1 and len(node.keywords) == 0: + fn = node.args[0].type + timeout = ir.Constant(10_000, builtins.TInt64()) + else: + assert False + if types.is_method(fn): + fn = types.get_method_function(fn) + 
sid = ir.Constant(fn.sid, builtins.TInt32()) + if not builtins.is_none(fn.ret): + ret = self.append(ir.Builtin("subkernel_retrieve_return", [sid, timeout], fn.ret)) + else: + ret = ir.Constant(None, builtins.TNone()) + self.append(ir.Builtin("subkernel_await_finish", [sid, timeout], builtins.TNone())) + return ret + elif types.is_builtin(typ, "subkernel_preload"): + if len(node.args) == 1 and len(node.keywords) == 0: + fn = node.args[0].type + else: + assert False + if types.is_method(fn): + fn = types.get_method_function(fn) + sid = ir.Constant(fn.sid, builtins.TInt32()) + return self.append(ir.Builtin("subkernel_preload", [sid], builtins.TNone())) elif types.is_exn_constructor(typ): return self.alloc_exn(node.type, *[self.visit(arg_node) for arg_node in node.args]) elif types.is_constructor(typ): @@ -2535,8 +2569,8 @@ class ARTIQIRGenerator(algorithm.Visitor): node.loc) self.engine.process(diag) - def _user_call(self, callee, positional, keywords, arg_exprs={}): - if types.is_function(callee.type) or types.is_rpc(callee.type): + def _user_call(self, callee, positional, keywords, arg_exprs={}, remote_fn=False): + if types.is_function(callee.type) or types.is_rpc(callee.type) or types.is_subkernel(callee.type): func = callee self_arg = None fn_typ = callee.type @@ -2551,16 +2585,50 @@ class ARTIQIRGenerator(algorithm.Visitor): else: assert False - if types.is_rpc(fn_typ): - if self_arg is None: + if types.is_rpc(fn_typ) or types.is_subkernel(fn_typ): + if self_arg is None or types.is_subkernel(fn_typ): + # self is not passed to subkernels by remote args = positional - else: + elif self_arg is not None: args = [self_arg] + positional for keyword in keywords: arg = keywords[keyword] args.append(self.append(ir.Alloc([ir.Constant(keyword, builtins.TStr()), arg], ir.TKeyword(arg.type)))) + elif remote_fn: + assert self_arg is None + assert len(fn_typ.args) >= len(positional) + assert len(keywords) == 0 # no keyword support + args = [None] * fn_typ.arity() + index = 0 + 
# fill in first available args + for arg in positional: + args[index] = arg + index += 1 + + # remaining args are received through DRTIO + if index < len(args): + # min/max args received remotely (minus already filled) + offset = index + min_args = ir.Constant(len(fn_typ.args)-offset, builtins.TInt8()) + max_args = ir.Constant(fn_typ.arity()-offset, builtins.TInt8()) + + rcvd_count = self.append(ir.Builtin("subkernel_await_args", [min_args, max_args], builtins.TNone())) + arg_types = list(fn_typ.args.items())[offset:] + # obligatory arguments + for arg_name, arg_type in arg_types: + args[index] = self.append(ir.GetArgFromRemote(arg_name, arg_type, + name="ARG.{}".format(arg_name))) + index += 1 + + # optional arguments + for optarg_name, optarg_type in fn_typ.optargs.items(): + idx = ir.Constant(index-offset, builtins.TInt8()) + args[index] = \ + self.append(ir.GetOptArgFromRemote(optarg_name, optarg_type, rcvd_count, idx)) + index += 1 + else: args = [None] * (len(fn_typ.args) + len(fn_typ.optargs)) @@ -2646,7 +2714,8 @@ class ARTIQIRGenerator(algorithm.Visitor): else: assert False, "Broadcasting for {} arguments not implemented".format(len) else: - insn = self._user_call(callee, args, keywords, node.arg_exprs) + remote_fn = getattr(node, "remote_fn", False) + insn = self._user_call(callee, args, keywords, node.arg_exprs, remote_fn) if isinstance(node.func, asttyped.AttributeT): attr_node = node.func self.method_map[(attr_node.value.type.find(), diff --git a/artiq/compiler/transforms/asttyped_rewriter.py b/artiq/compiler/transforms/asttyped_rewriter.py index 4c3112be6..07ab9ded2 100644 --- a/artiq/compiler/transforms/asttyped_rewriter.py +++ b/artiq/compiler/transforms/asttyped_rewriter.py @@ -238,7 +238,7 @@ class ASTTypedRewriter(algorithm.Transformer): body=node.body, decorator_list=node.decorator_list, keyword_loc=node.keyword_loc, name_loc=node.name_loc, arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs, - loc=node.loc) + loc=node.loc, 
remote_fn=False) try: self.env_stack.append(node.typing_env) @@ -439,8 +439,9 @@ class ASTTypedRewriter(algorithm.Transformer): def visit_Call(self, node): node = self.generic_visit(node) - node = asttyped.CallT(type=types.TVar(), iodelay=None, arg_exprs={}, - func=node.func, args=node.args, keywords=node.keywords, + node = asttyped.CallT(type=types.TVar(), iodelay=None, arg_exprs={}, + remote_fn=False, func=node.func, + args=node.args, keywords=node.keywords, starargs=node.starargs, kwargs=node.kwargs, star_loc=node.star_loc, dstar_loc=node.dstar_loc, begin_loc=node.begin_loc, end_loc=node.end_loc, loc=node.loc) diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 57bbedf82..0b95a60e5 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -46,6 +46,7 @@ class Inferencer(algorithm.Visitor): self.function = None # currently visited function, for Return inference self.in_loop = False self.has_return = False + self.subkernel_arg_types = dict() def _unify(self, typea, typeb, loca, locb, makenotes=None, when=""): try: @@ -178,7 +179,7 @@ class Inferencer(algorithm.Visitor): # Convert to a method. attr_type = types.TMethod(object_type, attr_type) self._unify_method_self(attr_type, attr_name, attr_loc, loc, value_node.loc) - elif types.is_rpc(attr_type): + elif types.is_rpc(attr_type) or types.is_subkernel(attr_type): # Convert to a method. We don't have to bother typechecking # the self argument, since for RPCs anything goes. attr_type = types.TMethod(object_type, attr_type) @@ -1293,6 +1294,55 @@ class Inferencer(algorithm.Visitor): # Ignored. 
self._unify(node.type, builtins.TNone(), node.loc, None) + elif types.is_builtin(typ, "subkernel_await"): + valid_forms = lambda: [ + valid_form("subkernel_await(f: subkernel) -> f return type"), + valid_form("subkernel_await(f: subkernel, timeout: numpy.int64) -> f return type") + ] + if 1 <= len(node.args) <= 2: + arg0 = node.args[0].type + if types.is_var(arg0): + pass # undetermined yet + else: + if types.is_method(arg0): + fn = types.get_method_function(arg0) + elif types.is_function(arg0) or types.is_subkernel(arg0): + fn = arg0 + else: + diagnose(valid_forms()) + self._unify(node.type, fn.ret, + node.loc, None) + if len(node.args) == 2: + arg1 = node.args[1] + if types.is_var(arg1.type): + pass + elif builtins.is_int(arg1.type): + # promote to TInt64 + self._unify(arg1.type, builtins.TInt64(), + arg1.loc, None) + else: + diagnose(valid_forms()) + else: + diagnose(valid_forms()) + elif types.is_builtin(typ, "subkernel_preload"): + valid_forms = lambda: [ + valid_form("subkernel_preload(f: subkernel) -> None") + ] + if len(node.args) == 1: + arg0 = node.args[0].type + if types.is_var(arg0): + pass # undetermined yet + else: + if types.is_method(arg0): + fn = types.get_method_function(arg0) + elif types.is_function(arg0) or types.is_subkernel(arg0): + fn = arg0 + else: + diagnose(valid_forms()) + self._unify(node.type, builtins.TNone(), + node.loc, None) + else: + diagnose(valid_forms()) else: assert False @@ -1331,6 +1381,7 @@ class Inferencer(algorithm.Visitor): typ_args = typ.args typ_optargs = typ.optargs typ_ret = typ.ret + typ_func = typ else: typ_self = types.get_method_self(typ) typ_func = types.get_method_function(typ) @@ -1388,12 +1439,23 @@ class Inferencer(algorithm.Visitor): other_node=node.args[0]) self._unify(node.type, ret, node.loc, None) return + if types.is_subkernel(typ_func) and typ_func.sid not in self.subkernel_arg_types: + self.subkernel_arg_types[typ_func.sid] = [] for actualarg, (formalname, formaltyp) in \ zip(node.args, list(typ_args.items()) 
+ list(typ_optargs.items())): self._unify(actualarg.type, formaltyp, actualarg.loc, None) passed_args[formalname] = actualarg.loc + if types.is_subkernel(typ_func): + if types.is_instance(actualarg.type): + # objects cannot be passed to subkernels, as rpc code doesn't support them + diag = diagnostic.Diagnostic("error", + "argument '{name}' of type: {typ} is not supported in subkernels", + {"name": formalname, "typ": actualarg.type}, + actualarg.loc, []) + self.engine.process(diag) + self.subkernel_arg_types[typ_func.sid].append((formalname, formaltyp)) for keyword in node.keywords: if keyword.arg in passed_args: @@ -1424,7 +1486,7 @@ class Inferencer(algorithm.Visitor): passed_args[keyword.arg] = keyword.arg_loc for formalname in typ_args: - if formalname not in passed_args: + if formalname not in passed_args and not node.remote_fn: note = diagnostic.Diagnostic("note", "the called function is of type {type}", {"type": types.TypePrinter().name(node.func.type)}, diff --git a/artiq/compiler/transforms/iodelay_estimator.py b/artiq/compiler/transforms/iodelay_estimator.py index 90bfefdb3..fcee126cf 100644 --- a/artiq/compiler/transforms/iodelay_estimator.py +++ b/artiq/compiler/transforms/iodelay_estimator.py @@ -280,7 +280,7 @@ class IODelayEstimator(algorithm.Visitor): context="as an argument for delay_mu()") call_delay = value elif not types.is_builtin(typ): - if types.is_function(typ) or types.is_rpc(typ): + if types.is_function(typ) or types.is_rpc(typ) or types.is_subkernel(typ): offset = 0 elif types.is_method(typ): offset = 1 @@ -288,7 +288,7 @@ class IODelayEstimator(algorithm.Visitor): else: assert False - if types.is_rpc(typ): + if types.is_rpc(typ) or types.is_subkernel(typ): call_delay = iodelay.Const(0) else: delay = typ.find().delay.find() @@ -311,13 +311,20 @@ class IODelayEstimator(algorithm.Visitor): args[arg_name] = arg_node free_vars = delay.duration.free_vars() - node.arg_exprs = { - arg: self.evaluate(args[arg], abort=abort, - context="in the 
expression for argument '{}' " - "that affects I/O delay".format(arg)) - for arg in free_vars - } - call_delay = delay.duration.fold(node.arg_exprs) + try: + node.arg_exprs = { + arg: self.evaluate(args[arg], abort=abort, + context="in the expression for argument '{}' " + "that affects I/O delay".format(arg)) + for arg in free_vars + } + call_delay = delay.duration.fold(node.arg_exprs) + except KeyError as e: + if getattr(node, "remote_fn", False): + note = diagnostic.Diagnostic("note", + "function called here", {}, + node.loc) + self.abort("due to arguments passed remotely", node.loc, note) else: assert False else: diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index e3a554cf3..88412a04c 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -215,7 +215,7 @@ class LLVMIRGenerator: typ = typ.find() if types.is_tuple(typ): return ll.LiteralStructType([self.llty_of_type(eltty) for eltty in typ.elts]) - elif types.is_rpc(typ) or types.is_external_function(typ): + elif types.is_rpc(typ) or types.is_external_function(typ) or types.is_subkernel(typ): if for_return: return llvoid else: @@ -398,6 +398,15 @@ class LLVMIRGenerator: elif name == "rpc_recv": llty = ll.FunctionType(lli32, [llptr]) + elif name == "subkernel_send_message": + llty = ll.FunctionType(llvoid, [lli32, lli8, llsliceptr, llptrptr]) + elif name == "subkernel_load_run": + llty = ll.FunctionType(llvoid, [lli32, lli1]) + elif name == "subkernel_await_finish": + llty = ll.FunctionType(llvoid, [lli32, lli64]) + elif name == "subkernel_await_message": + llty = ll.FunctionType(lli8, [lli32, lli64, lli8, lli8]) + # with now-pinning elif name == "now": llty = lli64 @@ -874,6 +883,53 @@ class LLVMIRGenerator: llvalue = self.llbuilder.bitcast(llvalue, llptr.type.pointee) return self.llbuilder.store(llvalue, llptr) + def process_GetArgFromRemote(self, insn): + llstackptr = 
self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [], + name="subkernel.arg.stack") + llval = self._build_rpc_recv(insn.arg_type, llstackptr) + return llval + + def process_GetOptArgFromRemote(self, insn): + # optarg = index < rcv_count ? Some(rcv_recv()) : None + llhead = self.llbuilder.basic_block + llrcv = self.llbuilder.append_basic_block(name="optarg.get.{}".format(insn.arg_name)) + + # argument received + self.llbuilder.position_at_end(llrcv) + llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [], + name="subkernel.arg.stack") + llval = self._build_rpc_recv(insn.arg_type, llstackptr) + llrpcretblock = self.llbuilder.basic_block # 'return' from rpc_recv, will be needed later + + # create the tail block, needs to be after the rpc recv tail block + lltail = self.llbuilder.append_basic_block(name="optarg.tail.{}".format(insn.arg_name)) + self.llbuilder.branch(lltail) + + # go back to head to add a branch to the tail + self.llbuilder.position_at_end(llhead) + llargrcvd = self.llbuilder.icmp_unsigned("<", self.map(insn.index), self.map(insn.rcv_count)) + self.llbuilder.cbranch(llargrcvd, llrcv, lltail) + + # argument not received/after arg recvd + self.llbuilder.position_at_end(lltail) + + llargtype = self.llty_of_type(insn.arg_type) + + llphi_arg_present = self.llbuilder.phi(lli1, name="optarg.phi.present.{}".format(insn.arg_name)) + llphi_arg = self.llbuilder.phi(llargtype, name="optarg.phi.{}".format(insn.arg_name)) + + llphi_arg_present.add_incoming(ll.Constant(lli1, 0), llhead) + llphi_arg.add_incoming(ll.Constant(llargtype, ll.Undefined), llhead) + + llphi_arg_present.add_incoming(ll.Constant(lli1, 1), llrpcretblock) + llphi_arg.add_incoming(llval, llrpcretblock) + + lloptarg = ll.Constant(ll.LiteralStructType([lli1, llargtype]), ll.Undefined) + lloptarg = self.llbuilder.insert_value(lloptarg, llphi_arg_present, 0) + lloptarg = self.llbuilder.insert_value(lloptarg, llphi_arg, 1) + + return lloptarg + def attr_index(self, typ, attr): 
return list(typ.attributes.keys()).index(attr) @@ -898,8 +954,8 @@ class LLVMIRGenerator: def get_global_closure_ptr(self, typ, attr): closure_type = typ.attributes[attr] assert types.is_constructor(typ) - assert types.is_function(closure_type) or types.is_rpc(closure_type) - if types.is_external_function(closure_type) or types.is_rpc(closure_type): + assert types.is_function(closure_type) or types.is_rpc(closure_type) or types.is_subkernel(closure_type) + if types.is_external_function(closure_type) or types.is_rpc(closure_type) or types.is_subkernel(closure_type): return None llty = self.llty_of_type(typ.attributes[attr]) @@ -1344,6 +1400,29 @@ class LLVMIRGenerator: return self.llbuilder.call(self.llbuiltin("delay_mu"), [llinterval]) elif insn.op == "end_catch": return self.llbuilder.call(self.llbuiltin("__artiq_end_catch"), []) + elif insn.op == "subkernel_await_args": + llmin = self.map(insn.operands[0]) + llmax = self.map(insn.operands[1]) + return self.llbuilder.call(self.llbuiltin("subkernel_await_message"), + [ll.Constant(lli32, 0), ll.Constant(lli64, 10_000), llmin, llmax], + name="subkernel.await.args") + elif insn.op == "subkernel_await_finish": + llsid = self.map(insn.operands[0]) + lltimeout = self.map(insn.operands[1]) + return self.llbuilder.call(self.llbuiltin("subkernel_await_finish"), [llsid, lltimeout], + name="subkernel.await.finish") + elif insn.op == "subkernel_retrieve_return": + llsid = self.map(insn.operands[0]) + lltimeout = self.map(insn.operands[1]) + self.llbuilder.call(self.llbuiltin("subkernel_await_message"), [llsid, lltimeout, ll.Constant(lli8, 1), ll.Constant(lli8, 1)], + name="subkernel.await.message") + llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [], + name="subkernel.arg.stack") + return self._build_rpc_recv(insn.type, llstackptr) + elif insn.op == "subkernel_preload": + llsid = self.map(insn.operands[0]) + return self.llbuilder.call(self.llbuiltin("subkernel_load_run"), [llsid, ll.Constant(lli1, 0)], + 
name="subkernel.preload") else: assert False @@ -1426,6 +1505,58 @@ class LLVMIRGenerator: return llfun, list(llargs), llarg_attrs, llcallstackptr + def _build_rpc_recv(self, ret, llstackptr, llnormalblock=None, llunwindblock=None): + # T result = { + # void *ret_ptr = alloca(sizeof(T)); + # void *ptr = ret_ptr; + # loop: int size = rpc_recv(ptr); + # // Non-zero: Provide `size` bytes of extra storage for variable-length data. + # if(size) { ptr = alloca(size); goto loop; } + # else *(T*)ret_ptr + # } + llprehead = self.llbuilder.basic_block + llhead = self.llbuilder.append_basic_block(name="rpc.head") + if llunwindblock: + llheadu = self.llbuilder.append_basic_block(name="rpc.head.unwind") + llalloc = self.llbuilder.append_basic_block(name="rpc.continue") + lltail = self.llbuilder.append_basic_block(name="rpc.tail") + + llretty = self.llty_of_type(ret) + llslot = self.llbuilder.alloca(llretty, name="rpc.ret.alloc") + llslotgen = self.llbuilder.bitcast(llslot, llptr, name="rpc.ret.ptr") + self.llbuilder.branch(llhead) + + self.llbuilder.position_at_end(llhead) + llphi = self.llbuilder.phi(llslotgen.type, name="rpc.ptr") + llphi.add_incoming(llslotgen, llprehead) + if llunwindblock: + llsize = self.llbuilder.invoke(self.llbuiltin("rpc_recv"), [llphi], + llheadu, llunwindblock, + name="rpc.size.next") + self.llbuilder.position_at_end(llheadu) + else: + llsize = self.llbuilder.call(self.llbuiltin("rpc_recv"), [llphi], + name="rpc.size.next") + lldone = self.llbuilder.icmp_unsigned('==', llsize, ll.Constant(llsize.type, 0), + name="rpc.done") + self.llbuilder.cbranch(lldone, lltail, llalloc) + + self.llbuilder.position_at_end(llalloc) + llalloca = self.llbuilder.alloca(lli8, llsize, name="rpc.alloc") + llalloca.align = self.max_target_alignment + llphi.add_incoming(llalloca, llalloc) + self.llbuilder.branch(llhead) + + self.llbuilder.position_at_end(lltail) + llret = self.llbuilder.load(llslot, name="rpc.ret") + if not ret.fold(False, lambda r, t: r or 
builtins.is_allocated(t)): + # We didn't allocate anything except the slot for the value itself. + # Don't waste stack space. + self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr]) + if llnormalblock: + self.llbuilder.branch(llnormalblock) + return llret + def _build_rpc(self, fun_loc, fun_type, args, llnormalblock, llunwindblock): llservice = ll.Constant(lli32, fun_type.service) @@ -1501,57 +1632,103 @@ class LLVMIRGenerator: return ll.Undefined - # T result = { - # void *ret_ptr = alloca(sizeof(T)); - # void *ptr = ret_ptr; - # loop: int size = rpc_recv(ptr); - # // Non-zero: Provide `size` bytes of extra storage for variable-length data. - # if(size) { ptr = alloca(size); goto loop; } - # else *(T*)ret_ptr - # } - llprehead = self.llbuilder.basic_block - llhead = self.llbuilder.append_basic_block(name="rpc.head") - if llunwindblock: - llheadu = self.llbuilder.append_basic_block(name="rpc.head.unwind") - llalloc = self.llbuilder.append_basic_block(name="rpc.continue") - lltail = self.llbuilder.append_basic_block(name="rpc.tail") + llret = self._build_rpc_recv(fun_type.ret, llstackptr, llnormalblock, llunwindblock) - llretty = self.llty_of_type(fun_type.ret) - llslot = self.llbuilder.alloca(llretty, name="rpc.ret.alloc") - llslotgen = self.llbuilder.bitcast(llslot, llptr, name="rpc.ret.ptr") - self.llbuilder.branch(llhead) - - self.llbuilder.position_at_end(llhead) - llphi = self.llbuilder.phi(llslotgen.type, name="rpc.ptr") - llphi.add_incoming(llslotgen, llprehead) - if llunwindblock: - llsize = self.llbuilder.invoke(self.llbuiltin("rpc_recv"), [llphi], - llheadu, llunwindblock, - name="rpc.size.next") - self.llbuilder.position_at_end(llheadu) - else: - llsize = self.llbuilder.call(self.llbuiltin("rpc_recv"), [llphi], - name="rpc.size.next") - lldone = self.llbuilder.icmp_unsigned('==', llsize, ll.Constant(llsize.type, 0), - name="rpc.done") - self.llbuilder.cbranch(lldone, lltail, llalloc) - - self.llbuilder.position_at_end(llalloc) - llalloca 
= self.llbuilder.alloca(lli8, llsize, name="rpc.alloc") - llalloca.align = self.max_target_alignment - llphi.add_incoming(llalloca, llalloc) - self.llbuilder.branch(llhead) - - self.llbuilder.position_at_end(lltail) - llret = self.llbuilder.load(llslot, name="rpc.ret") - if not fun_type.ret.fold(False, lambda r, t: r or builtins.is_allocated(t)): - # We didn't allocate anything except the slot for the value itself. - # Don't waste stack space. - self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr]) - if llnormalblock: - self.llbuilder.branch(llnormalblock) return llret + def _build_subkernel_call(self, fun_loc, fun_type, args): + llsid = ll.Constant(lli32, fun_type.sid) + tag = b"" + + for arg in args: + def arg_error_handler(typ): + printer = types.TypePrinter() + note = diagnostic.Diagnostic("note", + "value of type {type}", + {"type": printer.name(typ)}, + arg.loc) + diag = diagnostic.Diagnostic("error", + "type {type} is not supported in subkernel calls", + {"type": printer.name(arg.type)}, + arg.loc, notes=[note]) + self.engine.process(diag) + tag += ir.rpc_tag(arg.type, arg_error_handler) + tag += b":" + + # run the kernel first + self.llbuilder.call(self.llbuiltin("subkernel_load_run"), [llsid, ll.Constant(lli1, 1)]) + + # arg sent in the same vein as RPC + llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [], + name="subkernel.stack") + + lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr())) + lltagptr = self.llbuilder.alloca(lltag.type) + self.llbuilder.store(lltag, lltagptr) + + if args: + # only send args if there's anything to send, 'self' is excluded + llargs = self.llbuilder.alloca(llptr, ll.Constant(lli32, len(args)), + name="subkernel.args") + for index, arg in enumerate(args): + if builtins.is_none(arg.type): + llargslot = self.llbuilder.alloca(llunit, + name="subkernel.arg{}".format(index)) + else: + llarg = self.map(arg) + llargslot = self.llbuilder.alloca(llarg.type, + 
name="subkernel.arg{}".format(index)) + self.llbuilder.store(llarg, llargslot) + llargslot = self.llbuilder.bitcast(llargslot, llptr) + + llargptr = self.llbuilder.gep(llargs, [ll.Constant(lli32, index)]) + self.llbuilder.store(llargslot, llargptr) + + llargcount = ll.Constant(lli8, len(args)) + + self.llbuilder.call(self.llbuiltin("subkernel_send_message"), + [llsid, llargcount, lltagptr, llargs]) + self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr]) + + return llsid + + def _build_subkernel_return(self, insn): + # builds a remote return. + # unlike args, return only sends one thing. + if builtins.is_none(insn.value().type): + # do not waste time and bandwidth on Nones + return + + def ret_error_handler(typ): + printer = types.TypePrinter() + note = diagnostic.Diagnostic("note", + "value of type {type}", + {"type": printer.name(typ)}, + fun_loc) + diag = diagnostic.Diagnostic("error", + "return type {type} is not supported in subkernel returns", + {"type": printer.name(fun_type.ret)}, + fun_loc, notes=[note]) + self.engine.process(diag) + tag = ir.rpc_tag(insn.value().type, ret_error_handler) + tag += b":" + lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr())) + lltagptr = self.llbuilder.alloca(lltag.type) + self.llbuilder.store(lltag, lltagptr) + + llrets = self.llbuilder.alloca(llptr, ll.Constant(lli32, 1), + name="subkernel.return") + llret = self.map(insn.value()) + llretslot = self.llbuilder.alloca(llret.type, name="subkernel.retval") + self.llbuilder.store(llret, llretslot) + llretslot = self.llbuilder.bitcast(llretslot, llptr) + self.llbuilder.store(llretslot, llrets) + + llsid = ll.Constant(lli32, 0) # return goes back to master, sid is ignored + lltagcount = ll.Constant(lli8, 1) # only one thing is returned + self.llbuilder.call(self.llbuiltin("subkernel_send_message"), + [llsid, lltagcount, lltagptr, llrets]) + def process_Call(self, insn): functiontyp = insn.target_function().type if types.is_rpc(functiontyp): @@ -1559,6 
+1736,10 @@ class LLVMIRGenerator: functiontyp, insn.arguments(), llnormalblock=None, llunwindblock=None) + elif types.is_subkernel(functiontyp): + return self._build_subkernel_call(insn.target_function().loc, + functiontyp, + insn.arguments()) elif types.is_external_function(functiontyp): llfun, llargs, llarg_attrs, llcallstackptr = self._prepare_ffi_call(insn) else: @@ -1595,6 +1776,11 @@ class LLVMIRGenerator: functiontyp, insn.arguments(), llnormalblock, llunwindblock) + elif types.is_subkernel(functiontyp): + return self._build_subkernel_call(insn.target_function().loc, + functiontyp, + insn.arguments(), + llnormalblock, llunwindblock) elif types.is_external_function(functiontyp): llfun, llargs, llarg_attrs, llcallstackptr = self._prepare_ffi_call(insn) else: @@ -1673,7 +1859,8 @@ class LLVMIRGenerator: attrvalue = getattr(value, attr) is_class_function = (types.is_constructor(typ) and types.is_function(typ.attributes[attr]) and - not types.is_external_function(typ.attributes[attr])) + not types.is_external_function(typ.attributes[attr]) and + not types.is_subkernel(typ.attributes[attr])) if is_class_function: attrvalue = self.embedding_map.specialize_function(typ.instance, attrvalue) if not (types.is_instance(typ) and attr in typ.constant_attributes): @@ -1758,7 +1945,8 @@ class LLVMIRGenerator: llelts = [self._quote(v, t, lambda: path() + [str(i)]) for i, (v, t) in enumerate(zip(value, typ.elts))] return ll.Constant(llty, llelts) - elif types.is_rpc(typ) or types.is_external_function(typ) or types.is_builtin_function(typ): + elif types.is_rpc(typ) or types.is_external_function(typ) or \ + types.is_builtin_function(typ) or types.is_subkernel(typ): # RPC, C and builtin functions have no runtime representation. 
return ll.Constant(llty, ll.Undefined) elif types.is_function(typ): @@ -1813,6 +2001,8 @@ class LLVMIRGenerator: return llinsn def process_Return(self, insn): + if insn.remote_return: + self._build_subkernel_return(insn) if builtins.is_none(insn.value().type): return self.llbuilder.ret_void() else: diff --git a/artiq/compiler/types.py b/artiq/compiler/types.py index 1d9336b4d..7f397d308 100644 --- a/artiq/compiler/types.py +++ b/artiq/compiler/types.py @@ -385,6 +385,50 @@ class TRPC(Type): def __hash__(self): return hash(self.service) +class TSubkernel(TFunction): + """ + A kernel to be run on a satellite. + + :ivar args: (:class:`collections.OrderedDict` of string to :class:`Type`) + function arguments + :ivar ret: (:class:`Type`) + return type + :ivar sid: (int) subkernel ID number + :ivar destination: (int) satellite destination number + """ + + attributes = OrderedDict() + + def __init__(self, args, optargs, ret, sid, destination): + assert isinstance(ret, Type) + super().__init__(args, optargs, ret) + self.sid, self.destination = sid, destination + self.delay = TFixedDelay(iodelay.Const(0)) + + def unify(self, other): + if other is self: + return + if isinstance(other, TSubkernel) and \ + self.sid == other.sid and \ + self.destination == other.destination: + self.ret.unify(other.ret) + elif isinstance(other, TVar): + other.unify(self) + else: + raise UnificationError(self, other) + + def __repr__(self): + if getattr(builtins, "__in_sphinx__", False): + return str(self) + return "artiq.compiler.types.TSubkernel({})".format(repr(self.ret)) + + def __eq__(self, other): + return isinstance(other, TSubkernel) and \ + self.sid == other.sid + + def __hash__(self): + return hash(self.sid) + class TBuiltin(Type): """ An instance of builtin type. 
Every instance of a builtin @@ -644,6 +688,9 @@ def is_function(typ): def is_rpc(typ): return isinstance(typ.find(), TRPC) +def is_subkernel(typ): + return isinstance(typ.find(), TSubkernel) + def is_external_function(typ, name=None): typ = typ.find() if name is None: @@ -810,6 +857,10 @@ class TypePrinter(object): return "[rpc{} #{}](...)->{}".format(typ.service, " async" if typ.is_async else "", self.name(typ.ret, depth + 1)) + elif isinstance(typ, TSubkernel): + return "->{}".format(typ.sid, + typ.destination, + self.name(typ.ret, depth + 1)) elif isinstance(typ, TBuiltinFunction): return "".format(typ.name) elif isinstance(typ, (TConstructor, TExceptionConstructor)): diff --git a/artiq/frontend/artiq_compile.py b/artiq/frontend/artiq_compile.py index fcba5297d..938f5b787 100755 --- a/artiq/frontend/artiq_compile.py +++ b/artiq/frontend/artiq_compile.py @@ -73,8 +73,8 @@ def main(): finally: dataset_db.close_db() - if object_map.has_rpc(): - raise ValueError("Experiment must not use RPC") + if object_map.has_rpc_or_subkernel(): + raise ValueError("Experiment must not use RPC or subkernels") output = args.output if output is None: diff --git a/artiq/language/core.py b/artiq/language/core.py index 5560398dd..2aff914a9 100644 --- a/artiq/language/core.py +++ b/artiq/language/core.py @@ -7,7 +7,7 @@ from functools import wraps import numpy -__all__ = ["kernel", "portable", "rpc", "syscall", "host_only", +__all__ = ["kernel", "portable", "rpc", "subkernel", "syscall", "host_only", "kernel_from_string", "set_time_manager", "set_watchdog_factory", "TerminationRequested"] @@ -21,7 +21,7 @@ __all__.extend(kernel_globals) _ARTIQEmbeddedInfo = namedtuple("_ARTIQEmbeddedInfo", - "core_name portable function syscall forbidden flags") + "core_name portable function syscall forbidden destination flags") def kernel(arg=None, flags={}): """ @@ -54,7 +54,7 @@ def kernel(arg=None, flags={}): return getattr(self, arg).run(run_on_core, ((self,) + k_args), k_kwargs) 
run_on_core.artiq_embedded = _ARTIQEmbeddedInfo( core_name=arg, portable=False, function=function, syscall=None, - forbidden=False, flags=set(flags)) + forbidden=False, destination=None, flags=set(flags)) return run_on_core return inner_decorator elif arg is None: @@ -64,6 +64,50 @@ def kernel(arg=None, flags={}): else: return kernel("core", flags)(arg) +def subkernel(arg=None, destination=0, flags={}): + """ + This decorator marks an object's method or function for execution on a satellite device. + Destination must be given, and it must be between 1 and 255 (inclusive). + + Subkernels behave similarly to kernels, with a few key differences: + + - they are started from main kernels, + - they do not support RPCs, or running subsequent subkernels on other devices, + - but they can call other kernels or subkernels with the same destination. + + Subkernels can accept arguments and return values. However, they must be fully + annotated with ARTIQ types. + + To call a subkernel, call it like a normal function. + + To await its finishing execution, call ``subkernel_await(subkernel, [timeout])``. + The timeout parameter is optional, and by default is equal to 10000 (milliseconds). + This time can be adjusted for subkernels that take a long time to execute. + + The compiled subkernel is copied to satellites, but not yet to the kernel core + until it's called. For bigger subkernels it may take some time before they + actually start running. To help with that, subkernels can be preloaded with the + ``subkernel_preload(subkernel)`` function. A call to a preloaded subkernel + will take less time, but only one subkernel can be preloaded at a time.
+ """ + if isinstance(arg, str): + def inner_decorator(function): + @wraps(function) + def run_subkernel(self, *k_args, **k_kwargs): + sid = getattr(self, arg).prepare_subkernel(destination, run_subkernel, ((self,) + k_args), k_kwargs) + getattr(self, arg).run_subkernel(sid) + run_subkernel.artiq_embedded = _ARTIQEmbeddedInfo( + core_name=arg, portable=False, function=function, syscall=None, + forbidden=False, destination=destination, flags=set(flags)) + return run_subkernel + return inner_decorator + elif arg is None: + def inner_decorator(function): + return subkernel(function, destination, flags) + return inner_decorator + else: + return subkernel("core", destination, flags)(arg) + def portable(arg=None, flags={}): """ This decorator marks a function for execution on the same device as its @@ -84,7 +128,7 @@ def portable(arg=None, flags={}): else: arg.artiq_embedded = \ _ARTIQEmbeddedInfo(core_name=None, portable=True, function=arg, syscall=None, - forbidden=False, flags=set(flags)) + forbidden=False, destination=None, flags=set(flags)) return arg def rpc(arg=None, flags={}): @@ -100,7 +144,7 @@ def rpc(arg=None, flags={}): else: arg.artiq_embedded = \ _ARTIQEmbeddedInfo(core_name=None, portable=False, function=arg, syscall=None, - forbidden=False, flags=set(flags)) + forbidden=False, destination=None, flags=set(flags)) return arg def syscall(arg=None, flags={}): @@ -118,7 +162,7 @@ def syscall(arg=None, flags={}): def inner_decorator(function): function.artiq_embedded = \ _ARTIQEmbeddedInfo(core_name=None, portable=False, function=None, - syscall=arg, forbidden=False, + syscall=arg, forbidden=False, destination=None, flags=set(flags)) return function return inner_decorator @@ -136,7 +180,7 @@ def host_only(function): """ function.artiq_embedded = \ _ARTIQEmbeddedInfo(core_name=None, portable=False, function=None, syscall=None, - forbidden=True, flags={}) + forbidden=True, destination=None, flags={}) return function