forked from M-Labs/artiq
compiler: support free subkernel message passing
This commit is contained in:
parent
7d3bcc7cac
commit
0ba0330b53
@ -253,6 +253,12 @@ def fn_subkernel_await():
|
||||
def fn_subkernel_preload():
|
||||
return types.TBuiltinFunction("subkernel_preload")
|
||||
|
||||
def fn_subkernel_send():
|
||||
return types.TBuiltinFunction("subkernel_send")
|
||||
|
||||
def fn_subkernel_recv():
|
||||
return types.TBuiltinFunction("subkernel_recv")
|
||||
|
||||
# Accessors
|
||||
|
||||
def is_none(typ):
|
||||
|
@ -47,8 +47,13 @@ class SpecializedFunction:
|
||||
return hash((self.instance_type, self.host_function))
|
||||
|
||||
|
||||
class SubkernelMessageType:
|
||||
def __init__(self, name, value_type):
|
||||
self.name = name
|
||||
self.value_type = value_type
|
||||
|
||||
class EmbeddingMap:
|
||||
def __init__(self, subkernels={}):
|
||||
def __init__(self, old_embedding_map=None):
|
||||
self.object_current_key = 0
|
||||
self.object_forward_map = {}
|
||||
self.object_reverse_map = {}
|
||||
@ -65,12 +70,21 @@ class EmbeddingMap:
|
||||
self.str_forward_map = {}
|
||||
self.str_reverse_map = {}
|
||||
|
||||
# mapping `name` to object ID
|
||||
self.subkernel_message_map = {}
|
||||
|
||||
# subkernels: dict of ID: function, just like object_forward_map
|
||||
# allow the embedding map to be aware of subkernels from other kernels
|
||||
for key, obj_ref in subkernels.items():
|
||||
if not old_embedding_map is None:
|
||||
for key, obj_ref in old_embedding_map.subkernels().items():
|
||||
self.object_forward_map[key] = obj_ref
|
||||
obj_id = id(obj_ref)
|
||||
self.object_reverse_map[obj_id] = key
|
||||
for msg_id, msg_type in old_embedding_map.subkernel_messages().items():
|
||||
self.object_forward_map[msg_id] = msg_type
|
||||
obj_id = id(msg_type)
|
||||
self.subkernel_message_map[msg_type.name] = msg_id
|
||||
self.object_reverse_map[obj_id] = msg_id
|
||||
|
||||
self.preallocate_runtime_exception_names(["RuntimeError",
|
||||
"RTIOUnderflow",
|
||||
@ -174,7 +188,7 @@ class EmbeddingMap:
|
||||
self.object_current_key += 1
|
||||
while self.object_forward_map.get(self.object_current_key):
|
||||
# make sure there's no collisions with previously inserted subkernels
|
||||
# their identifiers must be consistent between kernels/subkernels
|
||||
# their identifiers must be consistent across all kernels/subkernels
|
||||
self.object_current_key += 1
|
||||
|
||||
self.object_forward_map[self.object_current_key] = obj_ref
|
||||
@ -189,7 +203,7 @@ class EmbeddingMap:
|
||||
obj_ref = self.object_forward_map[obj_id]
|
||||
if isinstance(obj_ref, (pytypes.FunctionType, pytypes.MethodType,
|
||||
pytypes.BuiltinFunctionType, pytypes.ModuleType,
|
||||
SpecializedFunction)):
|
||||
SpecializedFunction, SubkernelMessageType)):
|
||||
continue
|
||||
elif isinstance(obj_ref, type):
|
||||
_, obj_typ = self.type_map[obj_ref]
|
||||
@ -205,6 +219,20 @@ class EmbeddingMap:
|
||||
subkernels[k] = v
|
||||
return subkernels
|
||||
|
||||
def store_subkernel_message(self, name, value_type):
|
||||
if name in self.subkernel_message_map:
|
||||
msg_id = self.subkernel_message_map[name]
|
||||
else:
|
||||
msg_id = self.store_object(SubkernelMessageType(name, value_type))
|
||||
self.subkernel_message_map[name] = msg_id
|
||||
return msg_id, self.retrieve_object(msg_id)
|
||||
|
||||
def subkernel_messages(self):
|
||||
messages = {}
|
||||
for name, msg_id in self.subkernel_message_map.items():
|
||||
messages[msg_id] = self.retrieve_object(msg_id)
|
||||
return messages
|
||||
|
||||
def has_rpc(self):
|
||||
return any(filter(
|
||||
lambda x: (inspect.isfunction(x) or inspect.ismethod(x)) and \
|
||||
@ -802,7 +830,7 @@ class TypedtreeHasher(algorithm.Visitor):
|
||||
return hash(tuple(freeze(getattr(node, field_name)) for field_name in fields))
|
||||
|
||||
class Stitcher:
|
||||
def __init__(self, core, dmgr, engine=None, print_as_rpc=True, destination=0, subkernel_arg_types=[], subkernels={}):
|
||||
def __init__(self, core, dmgr, engine=None, print_as_rpc=True, destination=0, subkernel_arg_types=[], old_embedding_map=None):
|
||||
self.core = core
|
||||
self.dmgr = dmgr
|
||||
if engine is None:
|
||||
@ -824,7 +852,7 @@ class Stitcher:
|
||||
|
||||
self.functions = {}
|
||||
|
||||
self.embedding_map = EmbeddingMap(subkernels)
|
||||
self.embedding_map = EmbeddingMap(old_embedding_map)
|
||||
self.value_map = defaultdict(lambda: [])
|
||||
self.definitely_changed = False
|
||||
|
||||
|
@ -59,4 +59,6 @@ def globals():
|
||||
# ARTIQ subkernel utility functions
|
||||
"subkernel_await": builtins.fn_subkernel_await(),
|
||||
"subkernel_preload": builtins.fn_subkernel_preload(),
|
||||
"subkernel_send": builtins.fn_subkernel_send(),
|
||||
"subkernel_recv": builtins.fn_subkernel_recv(),
|
||||
}
|
||||
|
@ -2559,6 +2559,42 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
sid = ir.Constant(fn.sid, builtins.TInt32())
|
||||
dest = ir.Constant(fn.destination, builtins.TInt32())
|
||||
return self.append(ir.Builtin("subkernel_preload", [sid, dest], builtins.TNone()))
|
||||
elif types.is_builtin(typ, "subkernel_send"):
|
||||
if len(node.args) == 3 and len(node.keywords) == 0:
|
||||
dest = self.visit(node.args[0])
|
||||
name = node.args[1].s
|
||||
value = self.visit(node.args[2])
|
||||
else:
|
||||
assert False
|
||||
msg_id, msg = self.embedding_map.store_subkernel_message(name, value.type)
|
||||
msg_id = ir.Constant(msg_id, builtins.TInt32())
|
||||
if value.type != msg.value_type:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"type mismatch for subkernel message '{name}', receiver expects {recv} while sending {send}",
|
||||
{"name": name, "recv": msg.value_type, "send": value.type},
|
||||
node.loc)
|
||||
self.engine.process(diag)
|
||||
return self.append(ir.Builtin("subkernel_send", [msg_id, dest, value], builtins.TNone()))
|
||||
elif types.is_builtin(typ, "subkernel_recv"):
|
||||
if len(node.args) == 2 and len(node.keywords) == 0:
|
||||
name = node.args[0].s
|
||||
vartype = node.args[1].value
|
||||
timeout = ir.Constant(10_000, builtins.TInt64())
|
||||
elif len(node.args) == 3 and len(node.keywords) == 0:
|
||||
name = node.args[0].s
|
||||
vartype = node.args[1].value
|
||||
timeout = self.visit(node.args[2])
|
||||
else:
|
||||
assert False
|
||||
msg_id, msg = self.embedding_map.store_subkernel_message(name, vartype)
|
||||
msg_id = ir.Constant(msg_id, builtins.TInt32())
|
||||
if vartype != msg.value_type:
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"type mismatch for subkernel message '{name}', receiver expects {recv} while sending {send}",
|
||||
{"name": name, "recv": vartype, "send": msg.value_type},
|
||||
node.loc)
|
||||
self.engine.process(diag)
|
||||
return self.append(ir.Builtin("subkernel_recv", [msg_id, timeout], vartype))
|
||||
elif types.is_exn_constructor(typ):
|
||||
return self.alloc_exn(node.type, *[self.visit(arg_node) for arg_node in node.args])
|
||||
elif types.is_constructor(typ):
|
||||
|
@ -1343,6 +1343,57 @@ class Inferencer(algorithm.Visitor):
|
||||
node.loc, None)
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
elif types.is_builtin(typ, "subkernel_send"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("subkernel_send(dest: numpy.int?, name: str, value: V) -> None"),
|
||||
]
|
||||
self._unify(node.type, builtins.TNone(),
|
||||
node.loc, None)
|
||||
if len(node.args) == 3:
|
||||
arg0 = node.args[0]
|
||||
if types.is_var(arg0.type):
|
||||
pass # undetermined yet
|
||||
else:
|
||||
if builtins.is_int(arg0.type):
|
||||
self._unify(arg0.type, builtins.TInt8(),
|
||||
arg0.loc, None)
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
arg1 = node.args[1]
|
||||
self._unify(arg1.type, builtins.TStr(),
|
||||
arg1.loc, None)
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
elif types.is_builtin(typ, "subkernel_recv"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("subkernel_recv(name: str, value_type: type) -> value_type"),
|
||||
valid_form("subkernel_recv(name: str, value_type: type, timeout: numpy.int64) -> value_type"),
|
||||
]
|
||||
if 2 <= len(node.args) <= 3:
|
||||
arg0 = node.args[0]
|
||||
if types.is_var(arg0.type):
|
||||
pass
|
||||
else:
|
||||
self._unify(arg0.type, builtins.TStr(),
|
||||
arg0.loc, None)
|
||||
arg1 = node.args[1]
|
||||
if types.is_var(arg1.type):
|
||||
pass
|
||||
else:
|
||||
self._unify(node.type, arg1.value,
|
||||
node.loc, None)
|
||||
if len(node.args) == 3:
|
||||
arg2 = node.args[2]
|
||||
if types.is_var(arg2.type):
|
||||
pass
|
||||
elif builtins.is_int(arg2.type):
|
||||
# promote to TInt64
|
||||
self._unify(arg2.type, builtins.TInt64(),
|
||||
arg2.loc, None)
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
@ -1420,6 +1420,20 @@ class LLVMIRGenerator:
|
||||
lldest = ll.Constant(lli8, insn.operands[1].value)
|
||||
return self.llbuilder.call(self.llbuiltin("subkernel_load_run"), [llsid, lldest, ll.Constant(lli1, 0)],
|
||||
name="subkernel.preload")
|
||||
elif insn.op == "subkernel_send":
|
||||
llmsgid = self.map(insn.operands[0])
|
||||
lldest = self.map(insn.operands[1])
|
||||
return self._build_subkernel_message(llmsgid, lldest, [insn.operands[2]])
|
||||
elif insn.op == "subkernel_recv":
|
||||
llmsgid = self.map(insn.operands[0])
|
||||
lltimeout = self.map(insn.operands[1])
|
||||
lltagptr = self._build_subkernel_tags([insn.type])
|
||||
self.llbuilder.call(self.llbuiltin("subkernel_await_message"),
|
||||
[llmsgid, lltimeout, lltagptr, ll.Constant(lli8, 1), ll.Constant(lli8, 1)],
|
||||
name="subkernel.await.message")
|
||||
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [],
|
||||
name="subkernel.arg.stack")
|
||||
return self._build_rpc_recv(insn.type, llstackptr)
|
||||
else:
|
||||
assert False
|
||||
|
||||
@ -1580,11 +1594,8 @@ class LLVMIRGenerator:
|
||||
self.llbuilder.branch(llnormalblock)
|
||||
return llret
|
||||
|
||||
def _build_rpc(self, fun_loc, fun_type, args, llnormalblock, llunwindblock):
|
||||
llservice = ll.Constant(lli32, fun_type.service)
|
||||
|
||||
def _build_arg_tag(self, args, call_type):
|
||||
tag = b""
|
||||
|
||||
for arg in args:
|
||||
def arg_error_handler(typ):
|
||||
printer = types.TypePrinter()
|
||||
@ -1593,12 +1604,18 @@ class LLVMIRGenerator:
|
||||
{"type": printer.name(typ)},
|
||||
arg.loc)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"type {type} is not supported in remote procedure calls",
|
||||
{"type": printer.name(arg.type)},
|
||||
"type {type} is not supported in {call_type} calls",
|
||||
{"type": printer.name(arg.type), "call_type": call_type},
|
||||
arg.loc, notes=[note])
|
||||
self.engine.process(diag)
|
||||
tag += ir.rpc_tag(arg.type, arg_error_handler)
|
||||
tag += b":"
|
||||
return tag
|
||||
|
||||
def _build_rpc(self, fun_loc, fun_type, args, llnormalblock, llunwindblock):
|
||||
llservice = ll.Constant(lli32, fun_type.service)
|
||||
|
||||
tag = self._build_arg_tag(args, call_type="remote procedure")
|
||||
|
||||
def ret_error_handler(typ):
|
||||
printer = types.TypePrinter()
|
||||
@ -1662,36 +1679,25 @@ class LLVMIRGenerator:
|
||||
def _build_subkernel_call(self, fun_loc, fun_type, args):
|
||||
llsid = ll.Constant(lli32, fun_type.sid)
|
||||
lldest = ll.Constant(lli8, fun_type.destination)
|
||||
tag = b""
|
||||
|
||||
for arg in args:
|
||||
def arg_error_handler(typ):
|
||||
printer = types.TypePrinter()
|
||||
note = diagnostic.Diagnostic("note",
|
||||
"value of type {type}",
|
||||
{"type": printer.name(typ)},
|
||||
arg.loc)
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"type {type} is not supported in subkernel calls",
|
||||
{"type": printer.name(arg.type)},
|
||||
arg.loc, notes=[note])
|
||||
self.engine.process(diag)
|
||||
tag += ir.rpc_tag(arg.type, arg_error_handler)
|
||||
tag += b":"
|
||||
|
||||
# run the kernel first
|
||||
self.llbuilder.call(self.llbuiltin("subkernel_load_run"), [llsid, lldest, ll.Constant(lli1, 1)])
|
||||
|
||||
# arg sent in the same vein as RPC
|
||||
if args:
|
||||
# only send args if there's anything to send, 'self' is excluded
|
||||
self._build_subkernel_message(llsid, lldest, args)
|
||||
|
||||
return llsid
|
||||
|
||||
def _build_subkernel_message(self, llid, lldest, args):
|
||||
# args (or messages) are sent in the same vein as RPC
|
||||
tag = self._build_arg_tag(args, call_type="subkernel")
|
||||
|
||||
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [],
|
||||
name="subkernel.stack")
|
||||
|
||||
lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr()))
|
||||
lltagptr = self.llbuilder.alloca(lltag.type)
|
||||
self.llbuilder.store(lltag, lltagptr)
|
||||
|
||||
if args:
|
||||
# only send args if there's anything to send, 'self' is excluded
|
||||
llargs = self.llbuilder.alloca(llptr, ll.Constant(lli32, len(args)),
|
||||
name="subkernel.args")
|
||||
for index, arg in enumerate(args):
|
||||
@ -1711,12 +1717,9 @@ class LLVMIRGenerator:
|
||||
llargcount = ll.Constant(lli8, len(args))
|
||||
|
||||
llisreturn = ll.Constant(lli1, False)
|
||||
|
||||
self.llbuilder.call(self.llbuiltin("subkernel_send_message"),
|
||||
[llsid, llisreturn, lldest, llargcount, lltagptr, llargs])
|
||||
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
|
||||
|
||||
return llsid
|
||||
return self.llbuilder.call(self.llbuiltin("subkernel_send_message"),
|
||||
[llid, llisreturn, lldest, llargcount, lltagptr, llargs])
|
||||
|
||||
def _build_subkernel_return(self, insn):
|
||||
# builds a remote return.
|
||||
|
@ -121,14 +121,14 @@ class Core:
|
||||
def compile(self, function, args, kwargs, set_result=None,
|
||||
attribute_writeback=True, print_as_rpc=True,
|
||||
target=None, destination=0, subkernel_arg_types=[],
|
||||
subkernels={}):
|
||||
old_embedding_map=None):
|
||||
try:
|
||||
engine = _DiagnosticEngine(all_errors_are_fatal=True)
|
||||
|
||||
stitcher = Stitcher(engine=engine, core=self, dmgr=self.dmgr,
|
||||
print_as_rpc=print_as_rpc,
|
||||
destination=destination, subkernel_arg_types=subkernel_arg_types,
|
||||
subkernels=subkernels)
|
||||
old_embedding_map=old_embedding_map)
|
||||
stitcher.stitch_call(function, args, kwargs, set_result)
|
||||
stitcher.finalize()
|
||||
|
||||
@ -182,7 +182,7 @@ class Core:
|
||||
self.compile(subkernel_fn, self_arg, {}, attribute_writeback=False,
|
||||
print_as_rpc=False, target=target, destination=destination,
|
||||
subkernel_arg_types=subkernel_arg_types.get(sid, []),
|
||||
subkernels=subkernels)
|
||||
old_embedding_map=embedding_map)
|
||||
if object_map.has_rpc():
|
||||
raise ValueError("Subkernel must not use RPC")
|
||||
return destination, kernel_library, object_map
|
||||
|
Loading…
Reference in New Issue
Block a user