compiler: support free subkernel message passing

This commit is contained in:
mwojcik 2024-01-26 16:02:28 +08:00 committed by Sébastien Bourdeauducq
parent 7d3bcc7cac
commit 0ba0330b53
7 changed files with 189 additions and 63 deletions

View File

@ -253,6 +253,12 @@ def fn_subkernel_await():
def fn_subkernel_preload(): def fn_subkernel_preload():
return types.TBuiltinFunction("subkernel_preload") return types.TBuiltinFunction("subkernel_preload")
def fn_subkernel_send():
return types.TBuiltinFunction("subkernel_send")
def fn_subkernel_recv():
return types.TBuiltinFunction("subkernel_recv")
# Accessors # Accessors
def is_none(typ): def is_none(typ):

View File

@ -47,8 +47,13 @@ class SpecializedFunction:
return hash((self.instance_type, self.host_function)) return hash((self.instance_type, self.host_function))
class SubkernelMessageType:
def __init__(self, name, value_type):
self.name = name
self.value_type = value_type
class EmbeddingMap: class EmbeddingMap:
def __init__(self, subkernels={}): def __init__(self, old_embedding_map=None):
self.object_current_key = 0 self.object_current_key = 0
self.object_forward_map = {} self.object_forward_map = {}
self.object_reverse_map = {} self.object_reverse_map = {}
@ -64,13 +69,22 @@ class EmbeddingMap:
self.function_map = {} self.function_map = {}
self.str_forward_map = {} self.str_forward_map = {}
self.str_reverse_map = {} self.str_reverse_map = {}
# mapping `name` to object ID
self.subkernel_message_map = {}
# subkernels: dict of ID: function, just like object_forward_map # subkernels: dict of ID: function, just like object_forward_map
# allow the embedding map to be aware of subkernels from other kernels # allow the embedding map to be aware of subkernels from other kernels
for key, obj_ref in subkernels.items(): if not old_embedding_map is None:
self.object_forward_map[key] = obj_ref for key, obj_ref in old_embedding_map.subkernels().items():
obj_id = id(obj_ref) self.object_forward_map[key] = obj_ref
self.object_reverse_map[obj_id] = key obj_id = id(obj_ref)
self.object_reverse_map[obj_id] = key
for msg_id, msg_type in old_embedding_map.subkernel_messages().items():
self.object_forward_map[msg_id] = msg_type
obj_id = id(msg_type)
self.subkernel_message_map[msg_type.name] = msg_id
self.object_reverse_map[obj_id] = msg_id
self.preallocate_runtime_exception_names(["RuntimeError", self.preallocate_runtime_exception_names(["RuntimeError",
"RTIOUnderflow", "RTIOUnderflow",
@ -174,7 +188,7 @@ class EmbeddingMap:
self.object_current_key += 1 self.object_current_key += 1
while self.object_forward_map.get(self.object_current_key): while self.object_forward_map.get(self.object_current_key):
# make sure there's no collisions with previously inserted subkernels # make sure there's no collisions with previously inserted subkernels
# their identifiers must be consistent between kernels/subkernels # their identifiers must be consistent across all kernels/subkernels
self.object_current_key += 1 self.object_current_key += 1
self.object_forward_map[self.object_current_key] = obj_ref self.object_forward_map[self.object_current_key] = obj_ref
@ -189,7 +203,7 @@ class EmbeddingMap:
obj_ref = self.object_forward_map[obj_id] obj_ref = self.object_forward_map[obj_id]
if isinstance(obj_ref, (pytypes.FunctionType, pytypes.MethodType, if isinstance(obj_ref, (pytypes.FunctionType, pytypes.MethodType,
pytypes.BuiltinFunctionType, pytypes.ModuleType, pytypes.BuiltinFunctionType, pytypes.ModuleType,
SpecializedFunction)): SpecializedFunction, SubkernelMessageType)):
continue continue
elif isinstance(obj_ref, type): elif isinstance(obj_ref, type):
_, obj_typ = self.type_map[obj_ref] _, obj_typ = self.type_map[obj_ref]
@ -205,6 +219,20 @@ class EmbeddingMap:
subkernels[k] = v subkernels[k] = v
return subkernels return subkernels
def store_subkernel_message(self, name, value_type):
if name in self.subkernel_message_map:
msg_id = self.subkernel_message_map[name]
else:
msg_id = self.store_object(SubkernelMessageType(name, value_type))
self.subkernel_message_map[name] = msg_id
return msg_id, self.retrieve_object(msg_id)
def subkernel_messages(self):
messages = {}
for name, msg_id in self.subkernel_message_map.items():
messages[msg_id] = self.retrieve_object(msg_id)
return messages
def has_rpc(self): def has_rpc(self):
return any(filter( return any(filter(
lambda x: (inspect.isfunction(x) or inspect.ismethod(x)) and \ lambda x: (inspect.isfunction(x) or inspect.ismethod(x)) and \
@ -802,7 +830,7 @@ class TypedtreeHasher(algorithm.Visitor):
return hash(tuple(freeze(getattr(node, field_name)) for field_name in fields)) return hash(tuple(freeze(getattr(node, field_name)) for field_name in fields))
class Stitcher: class Stitcher:
def __init__(self, core, dmgr, engine=None, print_as_rpc=True, destination=0, subkernel_arg_types=[], subkernels={}): def __init__(self, core, dmgr, engine=None, print_as_rpc=True, destination=0, subkernel_arg_types=[], old_embedding_map=None):
self.core = core self.core = core
self.dmgr = dmgr self.dmgr = dmgr
if engine is None: if engine is None:
@ -824,7 +852,7 @@ class Stitcher:
self.functions = {} self.functions = {}
self.embedding_map = EmbeddingMap(subkernels) self.embedding_map = EmbeddingMap(old_embedding_map)
self.value_map = defaultdict(lambda: []) self.value_map = defaultdict(lambda: [])
self.definitely_changed = False self.definitely_changed = False

View File

@ -59,4 +59,6 @@ def globals():
# ARTIQ subkernel utility functions # ARTIQ subkernel utility functions
"subkernel_await": builtins.fn_subkernel_await(), "subkernel_await": builtins.fn_subkernel_await(),
"subkernel_preload": builtins.fn_subkernel_preload(), "subkernel_preload": builtins.fn_subkernel_preload(),
"subkernel_send": builtins.fn_subkernel_send(),
"subkernel_recv": builtins.fn_subkernel_recv(),
} }

View File

@ -2559,6 +2559,42 @@ class ARTIQIRGenerator(algorithm.Visitor):
sid = ir.Constant(fn.sid, builtins.TInt32()) sid = ir.Constant(fn.sid, builtins.TInt32())
dest = ir.Constant(fn.destination, builtins.TInt32()) dest = ir.Constant(fn.destination, builtins.TInt32())
return self.append(ir.Builtin("subkernel_preload", [sid, dest], builtins.TNone())) return self.append(ir.Builtin("subkernel_preload", [sid, dest], builtins.TNone()))
elif types.is_builtin(typ, "subkernel_send"):
if len(node.args) == 3 and len(node.keywords) == 0:
dest = self.visit(node.args[0])
name = node.args[1].s
value = self.visit(node.args[2])
else:
assert False
msg_id, msg = self.embedding_map.store_subkernel_message(name, value.type)
msg_id = ir.Constant(msg_id, builtins.TInt32())
if value.type != msg.value_type:
diag = diagnostic.Diagnostic("error",
"type mismatch for subkernel message '{name}', receiver expects {recv} while sending {send}",
{"name": name, "recv": msg.value_type, "send": value.type},
node.loc)
self.engine.process(diag)
return self.append(ir.Builtin("subkernel_send", [msg_id, dest, value], builtins.TNone()))
elif types.is_builtin(typ, "subkernel_recv"):
if len(node.args) == 2 and len(node.keywords) == 0:
name = node.args[0].s
vartype = node.args[1].value
timeout = ir.Constant(10_000, builtins.TInt64())
elif len(node.args) == 3 and len(node.keywords) == 0:
name = node.args[0].s
vartype = node.args[1].value
timeout = self.visit(node.args[2])
else:
assert False
msg_id, msg = self.embedding_map.store_subkernel_message(name, vartype)
msg_id = ir.Constant(msg_id, builtins.TInt32())
if vartype != msg.value_type:
diag = diagnostic.Diagnostic("error",
"type mismatch for subkernel message '{name}', receiver expects {recv} while sending {send}",
{"name": name, "recv": vartype, "send": msg.value_type},
node.loc)
self.engine.process(diag)
return self.append(ir.Builtin("subkernel_recv", [msg_id, timeout], vartype))
elif types.is_exn_constructor(typ): elif types.is_exn_constructor(typ):
return self.alloc_exn(node.type, *[self.visit(arg_node) for arg_node in node.args]) return self.alloc_exn(node.type, *[self.visit(arg_node) for arg_node in node.args])
elif types.is_constructor(typ): elif types.is_constructor(typ):

View File

@ -1343,6 +1343,57 @@ class Inferencer(algorithm.Visitor):
node.loc, None) node.loc, None)
else: else:
diagnose(valid_forms()) diagnose(valid_forms())
elif types.is_builtin(typ, "subkernel_send"):
valid_forms = lambda: [
valid_form("subkernel_send(dest: numpy.int?, name: str, value: V) -> None"),
]
self._unify(node.type, builtins.TNone(),
node.loc, None)
if len(node.args) == 3:
arg0 = node.args[0]
if types.is_var(arg0.type):
pass # undetermined yet
else:
if builtins.is_int(arg0.type):
self._unify(arg0.type, builtins.TInt8(),
arg0.loc, None)
else:
diagnose(valid_forms())
arg1 = node.args[1]
self._unify(arg1.type, builtins.TStr(),
arg1.loc, None)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "subkernel_recv"):
valid_forms = lambda: [
valid_form("subkernel_recv(name: str, value_type: type) -> value_type"),
valid_form("subkernel_recv(name: str, value_type: type, timeout: numpy.int64) -> value_type"),
]
if 2 <= len(node.args) <= 3:
arg0 = node.args[0]
if types.is_var(arg0.type):
pass
else:
self._unify(arg0.type, builtins.TStr(),
arg0.loc, None)
arg1 = node.args[1]
if types.is_var(arg1.type):
pass
else:
self._unify(node.type, arg1.value,
node.loc, None)
if len(node.args) == 3:
arg2 = node.args[2]
if types.is_var(arg2.type):
pass
elif builtins.is_int(arg2.type):
# promote to TInt64
self._unify(arg2.type, builtins.TInt64(),
arg2.loc, None)
else:
diagnose(valid_forms())
else:
diagnose(valid_forms())
else: else:
assert False assert False

View File

@ -1420,6 +1420,20 @@ class LLVMIRGenerator:
lldest = ll.Constant(lli8, insn.operands[1].value) lldest = ll.Constant(lli8, insn.operands[1].value)
return self.llbuilder.call(self.llbuiltin("subkernel_load_run"), [llsid, lldest, ll.Constant(lli1, 0)], return self.llbuilder.call(self.llbuiltin("subkernel_load_run"), [llsid, lldest, ll.Constant(lli1, 0)],
name="subkernel.preload") name="subkernel.preload")
elif insn.op == "subkernel_send":
llmsgid = self.map(insn.operands[0])
lldest = self.map(insn.operands[1])
return self._build_subkernel_message(llmsgid, lldest, [insn.operands[2]])
elif insn.op == "subkernel_recv":
llmsgid = self.map(insn.operands[0])
lltimeout = self.map(insn.operands[1])
lltagptr = self._build_subkernel_tags([insn.type])
self.llbuilder.call(self.llbuiltin("subkernel_await_message"),
[llmsgid, lltimeout, lltagptr, ll.Constant(lli8, 1), ll.Constant(lli8, 1)],
name="subkernel.await.message")
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [],
name="subkernel.arg.stack")
return self._build_rpc_recv(insn.type, llstackptr)
else: else:
assert False assert False
@ -1580,11 +1594,8 @@ class LLVMIRGenerator:
self.llbuilder.branch(llnormalblock) self.llbuilder.branch(llnormalblock)
return llret return llret
def _build_rpc(self, fun_loc, fun_type, args, llnormalblock, llunwindblock): def _build_arg_tag(self, args, call_type):
llservice = ll.Constant(lli32, fun_type.service)
tag = b"" tag = b""
for arg in args: for arg in args:
def arg_error_handler(typ): def arg_error_handler(typ):
printer = types.TypePrinter() printer = types.TypePrinter()
@ -1593,12 +1604,18 @@ class LLVMIRGenerator:
{"type": printer.name(typ)}, {"type": printer.name(typ)},
arg.loc) arg.loc)
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"type {type} is not supported in remote procedure calls", "type {type} is not supported in {call_type} calls",
{"type": printer.name(arg.type)}, {"type": printer.name(arg.type), "call_type": call_type},
arg.loc, notes=[note]) arg.loc, notes=[note])
self.engine.process(diag) self.engine.process(diag)
tag += ir.rpc_tag(arg.type, arg_error_handler) tag += ir.rpc_tag(arg.type, arg_error_handler)
tag += b":" tag += b":"
return tag
def _build_rpc(self, fun_loc, fun_type, args, llnormalblock, llunwindblock):
llservice = ll.Constant(lli32, fun_type.service)
tag = self._build_arg_tag(args, call_type="remote procedure")
def ret_error_handler(typ): def ret_error_handler(typ):
printer = types.TypePrinter() printer = types.TypePrinter()
@ -1662,61 +1679,47 @@ class LLVMIRGenerator:
def _build_subkernel_call(self, fun_loc, fun_type, args): def _build_subkernel_call(self, fun_loc, fun_type, args):
llsid = ll.Constant(lli32, fun_type.sid) llsid = ll.Constant(lli32, fun_type.sid)
lldest = ll.Constant(lli8, fun_type.destination) lldest = ll.Constant(lli8, fun_type.destination)
tag = b""
for arg in args:
def arg_error_handler(typ):
printer = types.TypePrinter()
note = diagnostic.Diagnostic("note",
"value of type {type}",
{"type": printer.name(typ)},
arg.loc)
diag = diagnostic.Diagnostic("error",
"type {type} is not supported in subkernel calls",
{"type": printer.name(arg.type)},
arg.loc, notes=[note])
self.engine.process(diag)
tag += ir.rpc_tag(arg.type, arg_error_handler)
tag += b":"
# run the kernel first # run the kernel first
self.llbuilder.call(self.llbuiltin("subkernel_load_run"), [llsid, lldest, ll.Constant(lli1, 1)]) self.llbuilder.call(self.llbuiltin("subkernel_load_run"), [llsid, lldest, ll.Constant(lli1, 1)])
# arg sent in the same vein as RPC if args:
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [], # only send args if there's anything to send, 'self' is excluded
name="subkernel.stack") self._build_subkernel_message(llsid, lldest, args)
return llsid
def _build_subkernel_message(self, llid, lldest, args):
# args (or messages) are sent in the same vein as RPC
tag = self._build_arg_tag(args, call_type="subkernel")
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [],
name="subkernel.stack")
lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr())) lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr()))
lltagptr = self.llbuilder.alloca(lltag.type) lltagptr = self.llbuilder.alloca(lltag.type)
self.llbuilder.store(lltag, lltagptr) self.llbuilder.store(lltag, lltagptr)
if args: llargs = self.llbuilder.alloca(llptr, ll.Constant(lli32, len(args)),
# only send args if there's anything to send, 'self' is excluded name="subkernel.args")
llargs = self.llbuilder.alloca(llptr, ll.Constant(lli32, len(args)), for index, arg in enumerate(args):
name="subkernel.args") if builtins.is_none(arg.type):
for index, arg in enumerate(args): llargslot = self.llbuilder.alloca(llunit,
if builtins.is_none(arg.type): name="subkernel.arg{}".format(index))
llargslot = self.llbuilder.alloca(llunit, else:
name="subkernel.arg{}".format(index)) llarg = self.map(arg)
else: llargslot = self.llbuilder.alloca(llarg.type,
llarg = self.map(arg) name="subkernel.arg{}".format(index))
llargslot = self.llbuilder.alloca(llarg.type, self.llbuilder.store(llarg, llargslot)
name="subkernel.arg{}".format(index)) llargslot = self.llbuilder.bitcast(llargslot, llptr)
self.llbuilder.store(llarg, llargslot)
llargslot = self.llbuilder.bitcast(llargslot, llptr)
llargptr = self.llbuilder.gep(llargs, [ll.Constant(lli32, index)]) llargptr = self.llbuilder.gep(llargs, [ll.Constant(lli32, index)])
self.llbuilder.store(llargslot, llargptr) self.llbuilder.store(llargslot, llargptr)
llargcount = ll.Constant(lli8, len(args)) llargcount = ll.Constant(lli8, len(args))
llisreturn = ll.Constant(lli1, False) llisreturn = ll.Constant(lli1, False)
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
self.llbuilder.call(self.llbuiltin("subkernel_send_message"), return self.llbuilder.call(self.llbuiltin("subkernel_send_message"),
[llsid, llisreturn, lldest, llargcount, lltagptr, llargs]) [llid, llisreturn, lldest, llargcount, lltagptr, llargs])
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
return llsid
def _build_subkernel_return(self, insn): def _build_subkernel_return(self, insn):
# builds a remote return. # builds a remote return.

View File

@ -121,14 +121,14 @@ class Core:
def compile(self, function, args, kwargs, set_result=None, def compile(self, function, args, kwargs, set_result=None,
attribute_writeback=True, print_as_rpc=True, attribute_writeback=True, print_as_rpc=True,
target=None, destination=0, subkernel_arg_types=[], target=None, destination=0, subkernel_arg_types=[],
subkernels={}): old_embedding_map=None):
try: try:
engine = _DiagnosticEngine(all_errors_are_fatal=True) engine = _DiagnosticEngine(all_errors_are_fatal=True)
stitcher = Stitcher(engine=engine, core=self, dmgr=self.dmgr, stitcher = Stitcher(engine=engine, core=self, dmgr=self.dmgr,
print_as_rpc=print_as_rpc, print_as_rpc=print_as_rpc,
destination=destination, subkernel_arg_types=subkernel_arg_types, destination=destination, subkernel_arg_types=subkernel_arg_types,
subkernels=subkernels) old_embedding_map=old_embedding_map)
stitcher.stitch_call(function, args, kwargs, set_result) stitcher.stitch_call(function, args, kwargs, set_result)
stitcher.finalize() stitcher.finalize()
@ -182,7 +182,7 @@ class Core:
self.compile(subkernel_fn, self_arg, {}, attribute_writeback=False, self.compile(subkernel_fn, self_arg, {}, attribute_writeback=False,
print_as_rpc=False, target=target, destination=destination, print_as_rpc=False, target=target, destination=destination,
subkernel_arg_types=subkernel_arg_types.get(sid, []), subkernel_arg_types=subkernel_arg_types.get(sid, []),
subkernels=subkernels) old_embedding_map=embedding_map)
if object_map.has_rpc(): if object_map.has_rpc():
raise ValueError("Subkernel must not use RPC") raise ValueError("Subkernel must not use RPC")
return destination, kernel_library, object_map return destination, kernel_library, object_map