From 0ba0330b53ce13f2c724286b36c9c9ecd7f51d52 Mon Sep 17 00:00:00 2001 From: mwojcik Date: Fri, 26 Jan 2024 16:02:28 +0800 Subject: [PATCH] compiler: support free subkernel message passing --- artiq/compiler/builtins.py | 6 + artiq/compiler/embedding.py | 48 ++++++-- artiq/compiler/prelude.py | 2 + .../compiler/transforms/artiq_ir_generator.py | 36 ++++++ artiq/compiler/transforms/inferencer.py | 51 +++++++++ .../compiler/transforms/llvm_ir_generator.py | 103 +++++++++--------- artiq/coredevice/core.py | 6 +- 7 files changed, 189 insertions(+), 63 deletions(-) diff --git a/artiq/compiler/builtins.py b/artiq/compiler/builtins.py index 64e9b3690..cb0834f71 100644 --- a/artiq/compiler/builtins.py +++ b/artiq/compiler/builtins.py @@ -253,6 +253,12 @@ def fn_subkernel_await(): def fn_subkernel_preload(): return types.TBuiltinFunction("subkernel_preload") +def fn_subkernel_send(): + return types.TBuiltinFunction("subkernel_send") + +def fn_subkernel_recv(): + return types.TBuiltinFunction("subkernel_recv") + # Accessors def is_none(typ): diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 928e9e75d..c46c69da3 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -47,8 +47,13 @@ class SpecializedFunction: return hash((self.instance_type, self.host_function)) +class SubkernelMessageType: + def __init__(self, name, value_type): + self.name = name + self.value_type = value_type + class EmbeddingMap: - def __init__(self, subkernels={}): + def __init__(self, old_embedding_map=None): self.object_current_key = 0 self.object_forward_map = {} self.object_reverse_map = {} @@ -64,13 +69,22 @@ class EmbeddingMap: self.function_map = {} self.str_forward_map = {} self.str_reverse_map = {} - + + # mapping `name` to object ID + self.subkernel_message_map = {} + # subkernels: dict of ID: function, just like object_forward_map # allow the embedding map to be aware of subkernels from other kernels - for key, obj_ref in subkernels.items(): - self.object_forward_map[key] = obj_ref - obj_id = id(obj_ref) - self.object_reverse_map[obj_id] = key + if not old_embedding_map is None: + for key, obj_ref in old_embedding_map.subkernels().items(): + self.object_forward_map[key] = obj_ref + obj_id = id(obj_ref) + self.object_reverse_map[obj_id] = key + for msg_id, msg_type in old_embedding_map.subkernel_messages().items(): + self.object_forward_map[msg_id] = msg_type + obj_id = id(msg_type) + self.subkernel_message_map[msg_type.name] = msg_id + self.object_reverse_map[obj_id] = msg_id self.preallocate_runtime_exception_names(["RuntimeError", "RTIOUnderflow", @@ -174,7 +188,7 @@ class EmbeddingMap: self.object_current_key += 1 while self.object_forward_map.get(self.object_current_key): # make sure there's no collisions with previously inserted subkernels - # their identifiers must be consistent between kernels/subkernels + # their identifiers must be consistent across all kernels/subkernels self.object_current_key += 1 self.object_forward_map[self.object_current_key] = obj_ref @@ -189,7 +203,7 @@ class EmbeddingMap: obj_ref = self.object_forward_map[obj_id] if isinstance(obj_ref, (pytypes.FunctionType, pytypes.MethodType, pytypes.BuiltinFunctionType, pytypes.ModuleType, - SpecializedFunction)): + SpecializedFunction, SubkernelMessageType)): continue elif isinstance(obj_ref, type): _, obj_typ = self.type_map[obj_ref] @@ -205,6 +219,20 @@ class EmbeddingMap: subkernels[k] = v return subkernels + def store_subkernel_message(self, name, value_type): + if name in self.subkernel_message_map: + msg_id = self.subkernel_message_map[name] + else: + msg_id = self.store_object(SubkernelMessageType(name, value_type)) + self.subkernel_message_map[name] = msg_id + return msg_id, self.retrieve_object(msg_id) + + def subkernel_messages(self): + messages = {} + for name, msg_id in self.subkernel_message_map.items(): + messages[msg_id] = self.retrieve_object(msg_id) + return messages + def has_rpc(self): return any(filter( lambda x: (inspect.isfunction(x) or inspect.ismethod(x)) and \ @@ -802,7 +830,7 @@ class TypedtreeHasher(algorithm.Visitor): return hash(tuple(freeze(getattr(node, field_name)) for field_name in fields)) class Stitcher: - def __init__(self, core, dmgr, engine=None, print_as_rpc=True, destination=0, subkernel_arg_types=[], subkernels={}): + def __init__(self, core, dmgr, engine=None, print_as_rpc=True, destination=0, subkernel_arg_types=[], old_embedding_map=None): self.core = core self.dmgr = dmgr if engine is None: @@ -824,7 +852,7 @@ class Stitcher: self.functions = {} - self.embedding_map = EmbeddingMap(subkernels) + self.embedding_map = EmbeddingMap(old_embedding_map) self.value_map = defaultdict(lambda: []) self.definitely_changed = False diff --git a/artiq/compiler/prelude.py b/artiq/compiler/prelude.py index effbca87c..f96b4d0d7 100644 --- a/artiq/compiler/prelude.py +++ b/artiq/compiler/prelude.py @@ -59,4 +59,6 @@ def globals(): # ARTIQ subkernel utility functions "subkernel_await": builtins.fn_subkernel_await(), "subkernel_preload": builtins.fn_subkernel_preload(), + "subkernel_send": builtins.fn_subkernel_send(), + "subkernel_recv": builtins.fn_subkernel_recv(), } diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index 6998e0ddc..fe084caab 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -2559,6 +2559,42 @@ class ARTIQIRGenerator(algorithm.Visitor): sid = ir.Constant(fn.sid, builtins.TInt32()) dest = ir.Constant(fn.destination, builtins.TInt32()) return self.append(ir.Builtin("subkernel_preload", [sid, dest], builtins.TNone())) + elif types.is_builtin(typ, "subkernel_send"): + if len(node.args) == 3 and len(node.keywords) == 0: + dest = self.visit(node.args[0]) + name = node.args[1].s + value = self.visit(node.args[2]) + else: + assert False + msg_id, msg = self.embedding_map.store_subkernel_message(name, value.type) + msg_id = ir.Constant(msg_id, builtins.TInt32()) + if value.type != msg.value_type: + diag = diagnostic.Diagnostic("error", + "type mismatch for subkernel message '{name}', receiver expects {recv} while sending {send}", + {"name": name, "recv": msg.value_type, "send": value.type}, + node.loc) + self.engine.process(diag) + return self.append(ir.Builtin("subkernel_send", [msg_id, dest, value], builtins.TNone())) + elif types.is_builtin(typ, "subkernel_recv"): + if len(node.args) == 2 and len(node.keywords) == 0: + name = node.args[0].s + vartype = node.args[1].value + timeout = ir.Constant(10_000, builtins.TInt64()) + elif len(node.args) == 3 and len(node.keywords) == 0: + name = node.args[0].s + vartype = node.args[1].value + timeout = self.visit(node.args[2]) + else: + assert False + msg_id, msg = self.embedding_map.store_subkernel_message(name, vartype) + msg_id = ir.Constant(msg_id, builtins.TInt32()) + if vartype != msg.value_type: + diag = diagnostic.Diagnostic("error", + "type mismatch for subkernel message '{name}', receiver expects {recv} while sending {send}", + {"name": name, "recv": vartype, "send": msg.value_type}, + node.loc) + self.engine.process(diag) + return self.append(ir.Builtin("subkernel_recv", [msg_id, timeout], vartype)) elif types.is_exn_constructor(typ): return self.alloc_exn(node.type, *[self.visit(arg_node) for arg_node in node.args]) elif types.is_constructor(typ): diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 0b95a60e5..b94985463 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -1343,6 +1343,57 @@ class Inferencer(algorithm.Visitor): node.loc, None) else: diagnose(valid_forms()) + elif types.is_builtin(typ, "subkernel_send"): + valid_forms = lambda: [ + valid_form("subkernel_send(dest: numpy.int?, name: str, value: V) -> None"), + ] + self._unify(node.type, builtins.TNone(), + node.loc, None) + if len(node.args) == 3: + arg0 = node.args[0] + if types.is_var(arg0.type): + pass # undetermined yet + else: + if builtins.is_int(arg0.type): + self._unify(arg0.type, builtins.TInt8(), + arg0.loc, None) + else: + diagnose(valid_forms()) + arg1 = node.args[1] + self._unify(arg1.type, builtins.TStr(), + arg1.loc, None) + else: + diagnose(valid_forms()) + elif types.is_builtin(typ, "subkernel_recv"): + valid_forms = lambda: [ + valid_form("subkernel_recv(name: str, value_type: type) -> value_type"), + valid_form("subkernel_recv(name: str, value_type: type, timeout: numpy.int64) -> value_type"), + ] + if 2 <= len(node.args) <= 3: + arg0 = node.args[0] + if types.is_var(arg0.type): + pass + else: + self._unify(arg0.type, builtins.TStr(), + arg0.loc, None) + arg1 = node.args[1] + if types.is_var(arg1.type): + pass + else: + self._unify(node.type, arg1.value, + node.loc, None) + if len(node.args) == 3: + arg2 = node.args[2] + if types.is_var(arg2.type): + pass + elif builtins.is_int(arg2.type): + # promote to TInt64 + self._unify(arg2.type, builtins.TInt64(), + arg2.loc, None) + else: + diagnose(valid_forms()) + else: + diagnose(valid_forms()) else: assert False diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 3b4e165f3..4f68d27a2 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -1420,6 +1420,20 @@ class LLVMIRGenerator: lldest = ll.Constant(lli8, insn.operands[1].value) return self.llbuilder.call(self.llbuiltin("subkernel_load_run"), [llsid, lldest, ll.Constant(lli1, 0)], name="subkernel.preload") + elif insn.op == "subkernel_send": + llmsgid = self.map(insn.operands[0]) + lldest = self.map(insn.operands[1]) + return self._build_subkernel_message(llmsgid, lldest, [insn.operands[2]]) + elif insn.op == "subkernel_recv": + llmsgid = self.map(insn.operands[0]) + lltimeout = self.map(insn.operands[1]) + lltagptr = self._build_subkernel_tags([insn.type]) + self.llbuilder.call(self.llbuiltin("subkernel_await_message"), + [llmsgid, lltimeout, lltagptr, ll.Constant(lli8, 1), ll.Constant(lli8, 1)], + name="subkernel.await.message") + llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [], + name="subkernel.arg.stack") + return self._build_rpc_recv(insn.type, llstackptr) else: assert False @@ -1580,11 +1594,8 @@ class LLVMIRGenerator: self.llbuilder.branch(llnormalblock) return llret - def _build_rpc(self, fun_loc, fun_type, args, llnormalblock, llunwindblock): - llservice = ll.Constant(lli32, fun_type.service) - + def _build_arg_tag(self, args, call_type): tag = b"" - for arg in args: def arg_error_handler(typ): printer = types.TypePrinter() @@ -1593,12 +1604,18 @@ class LLVMIRGenerator: {"type": printer.name(typ)}, arg.loc) diag = diagnostic.Diagnostic("error", - "type {type} is not supported in remote procedure calls", - {"type": printer.name(arg.type)}, + "type {type} is not supported in {call_type} calls", + {"type": printer.name(arg.type), "call_type": call_type}, arg.loc, notes=[note]) self.engine.process(diag) tag += ir.rpc_tag(arg.type, arg_error_handler) tag += b":" + return tag + + def _build_rpc(self, fun_loc, fun_type, args, llnormalblock, llunwindblock): + llservice = ll.Constant(lli32, fun_type.service) + + tag = self._build_arg_tag(args, call_type="remote procedure") def ret_error_handler(typ): printer = types.TypePrinter() @@ -1662,61 +1679,47 @@ class LLVMIRGenerator: def _build_subkernel_call(self, fun_loc, fun_type, args): llsid = ll.Constant(lli32, fun_type.sid) lldest = ll.Constant(lli8, fun_type.destination) - tag = b"" - - for arg in args: - def arg_error_handler(typ): - printer = types.TypePrinter() - note = diagnostic.Diagnostic("note", - "value of type {type}", - {"type": printer.name(typ)}, - arg.loc) - diag = diagnostic.Diagnostic("error", - "type {type} is not supported in subkernel calls", - {"type": printer.name(arg.type)}, - arg.loc, notes=[note]) - self.engine.process(diag) - tag += ir.rpc_tag(arg.type, arg_error_handler) - tag += b":" - # run the kernel first self.llbuilder.call(self.llbuiltin("subkernel_load_run"), [llsid, lldest, ll.Constant(lli1, 1)]) - # arg sent in the same vein as RPC - llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [], - name="subkernel.stack") + if args: + # only send args if there's anything to send, 'self' is excluded + self._build_subkernel_message(llsid, lldest, args) + return llsid + + def _build_subkernel_message(self, llid, lldest, args): + # args (or messages) are sent in the same vein as RPC + tag = self._build_arg_tag(args, call_type="subkernel") + + llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [], + name="subkernel.stack") lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr())) lltagptr = self.llbuilder.alloca(lltag.type) self.llbuilder.store(lltag, lltagptr) - if args: - # only send args if there's anything to send, 'self' is excluded - llargs = self.llbuilder.alloca(llptr, ll.Constant(lli32, len(args)), - name="subkernel.args") - for index, arg in enumerate(args): - if builtins.is_none(arg.type): - llargslot = self.llbuilder.alloca(llunit, - name="subkernel.arg{}".format(index)) - else: - llarg = self.map(arg) - llargslot = self.llbuilder.alloca(llarg.type, - name="subkernel.arg{}".format(index)) - self.llbuilder.store(llarg, llargslot) - llargslot = self.llbuilder.bitcast(llargslot, llptr) + llargs = self.llbuilder.alloca(llptr, ll.Constant(lli32, len(args)), + name="subkernel.args") + for index, arg in enumerate(args): + if builtins.is_none(arg.type): + llargslot = self.llbuilder.alloca(llunit, + name="subkernel.arg{}".format(index)) + else: + llarg = self.map(arg) + llargslot = self.llbuilder.alloca(llarg.type, + name="subkernel.arg{}".format(index)) + self.llbuilder.store(llarg, llargslot) + llargslot = self.llbuilder.bitcast(llargslot, llptr) - llargptr = self.llbuilder.gep(llargs, [ll.Constant(lli32, index)]) - self.llbuilder.store(llargslot, llargptr) + llargptr = self.llbuilder.gep(llargs, [ll.Constant(lli32, index)]) + self.llbuilder.store(llargslot, llargptr) - llargcount = ll.Constant(lli8, len(args)) + llargcount = ll.Constant(lli8, len(args)) - llisreturn = ll.Constant(lli1, False) - - self.llbuilder.call(self.llbuiltin("subkernel_send_message"), - [llsid, llisreturn, lldest, llargcount, lltagptr, llargs]) - self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr]) - - return llsid + llisreturn = ll.Constant(lli1, False) + self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr]) + return self.llbuilder.call(self.llbuiltin("subkernel_send_message"), + [llid, llisreturn, lldest, llargcount, lltagptr, llargs]) def _build_subkernel_return(self, insn): # builds a remote return. diff --git a/artiq/coredevice/core.py b/artiq/coredevice/core.py index 26d60e92e..4d3ed36b5 100644 --- a/artiq/coredevice/core.py +++ b/artiq/coredevice/core.py @@ -121,14 +121,14 @@ class Core: def compile(self, function, args, kwargs, set_result=None, attribute_writeback=True, print_as_rpc=True, target=None, destination=0, subkernel_arg_types=[], - subkernels={}): + old_embedding_map=None): try: engine = _DiagnosticEngine(all_errors_are_fatal=True) stitcher = Stitcher(engine=engine, core=self, dmgr=self.dmgr, print_as_rpc=print_as_rpc, destination=destination, subkernel_arg_types=subkernel_arg_types, - subkernels=subkernels) + old_embedding_map=old_embedding_map) stitcher.stitch_call(function, args, kwargs, set_result) stitcher.finalize() @@ -182,7 +182,7 @@ class Core: self.compile(subkernel_fn, self_arg, {}, attribute_writeback=False, print_as_rpc=False, target=target, destination=destination, subkernel_arg_types=subkernel_arg_types.get(sid, []), - subkernels=subkernels) + old_embedding_map=embedding_map) if object_map.has_rpc(): raise ValueError("Subkernel must not use RPC") return destination, kernel_library, object_map