subkernel messages: check for send/recv pairs

This commit is contained in:
mwojcik 2024-02-01 11:09:59 +08:00 committed by Sébastien Bourdeauducq
parent 849b77fbf2
commit 7fee68ede0
3 changed files with 43 additions and 7 deletions

View File

@ -51,6 +51,8 @@ class SubkernelMessageType:
def __init__(self, name, value_type): def __init__(self, name, value_type):
self.name = name self.name = name
self.value_type = value_type self.value_type = value_type
self.send_loc = None
self.recv_loc = None
class EmbeddingMap: class EmbeddingMap:
def __init__(self, old_embedding_map=None): def __init__(self, old_embedding_map=None):
@ -219,20 +221,35 @@ class EmbeddingMap:
subkernels[k] = v subkernels[k] = v
return subkernels return subkernels
def store_subkernel_message(self, name, value_type): def store_subkernel_message(self, name, value_type, function_type, function_loc):
if name in self.subkernel_message_map: if name in self.subkernel_message_map:
msg_id = self.subkernel_message_map[name] msg_id = self.subkernel_message_map[name]
else: else:
msg_id = self.store_object(SubkernelMessageType(name, value_type)) msg_id = self.store_object(SubkernelMessageType(name, value_type))
self.subkernel_message_map[name] = msg_id self.subkernel_message_map[name] = msg_id
return msg_id, self.retrieve_object(msg_id) subkernel_msg = self.retrieve_object(msg_id)
if function_type == "send":
subkernel_msg.send_loc = function_loc
elif function_type == "recv":
subkernel_msg.recv_loc = function_loc
else:
assert False
return msg_id, subkernel_msg
def subkernel_messages(self): def subkernel_messages(self):
messages = {} messages = {}
for name, msg_id in self.subkernel_message_map.items(): for msg_id in self.subkernel_message_map.values():
messages[msg_id] = self.retrieve_object(msg_id) messages[msg_id] = self.retrieve_object(msg_id)
return messages return messages
def subkernel_messages_unpaired(self):
unpaired = []
for msg_id in self.subkernel_message_map.values():
msg_obj = self.retrieve_object(msg_id)
if msg_obj.send_loc is None or msg_obj.recv_loc is None:
unpaired.append(msg_obj)
return unpaired
def has_rpc(self): def has_rpc(self):
return any(filter( return any(filter(
lambda x: (inspect.isfunction(x) or inspect.ismethod(x)) and \ lambda x: (inspect.isfunction(x) or inspect.ismethod(x)) and \

View File

@ -2566,7 +2566,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
value = self.visit(node.args[2]) value = self.visit(node.args[2])
else: else:
assert False assert False
msg_id, msg = self.embedding_map.store_subkernel_message(name, value.type) msg_id, msg = self.embedding_map.store_subkernel_message(name, value.type, "send", node.loc)
msg_id = ir.Constant(msg_id, builtins.TInt32()) msg_id = ir.Constant(msg_id, builtins.TInt32())
if value.type != msg.value_type: if value.type != msg.value_type:
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
@ -2586,7 +2586,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
timeout = self.visit(node.args[2]) timeout = self.visit(node.args[2])
else: else:
assert False assert False
msg_id, msg = self.embedding_map.store_subkernel_message(name, vartype) msg_id, msg = self.embedding_map.store_subkernel_message(name, vartype, "recv", node.loc)
msg_id = ir.Constant(msg_id, builtins.TInt32()) msg_id = ir.Constant(msg_id, builtins.TInt32())
if vartype != msg.value_type: if vartype != msg.value_type:
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",

View File

@ -195,15 +195,34 @@ class Core:
for sid, subkernel_fn in subkernels.items(): for sid, subkernel_fn in subkernels.items():
if sid in subkernels_compiled: if sid in subkernels_compiled:
continue continue
destination, kernel_library, sub_embedding_map = \ destination, kernel_library, embedding_map = \
self.compile_subkernel(sid, subkernel_fn, embedding_map, self.compile_subkernel(sid, subkernel_fn, embedding_map,
args, subkernel_arg_types, subkernels) args, subkernel_arg_types, subkernels)
self.comm.upload_subkernel(kernel_library, sid, destination) self.comm.upload_subkernel(kernel_library, sid, destination)
new_subkernels.update(sub_embedding_map.subkernels()) new_subkernels.update(embedding_map.subkernels())
subkernels_compiled.append(sid) subkernels_compiled.append(sid)
if new_subkernels == subkernels: if new_subkernels == subkernels:
break break
subkernels.update(new_subkernels) subkernels.update(new_subkernels)
# check for messages without a send/recv pair
unpaired_messages = embedding_map.subkernel_messages_unpaired()
if unpaired_messages:
for unpaired_message in unpaired_messages:
engine = _DiagnosticEngine(all_errors_are_fatal=False)
# errors are non-fatal in order to display
# all unpaired message errors before raising an excption
if unpaired_message.send_loc is None:
diag = diagnostic.Diagnostic("error",
"subkernel message '{name}' only has a receiver but no sender",
{"name": unpaired_message.name},
unpaired_message.recv_loc)
else:
diag = diagnostic.Diagnostic("error",
"subkernel message '{name}' only has a sender but no receiver",
{"name": unpaired_message.name},
unpaired_message.send_loc)
engine.process(diag)
raise ValueError("Found subkernel message(s) without a full send/recv pair")
def precompile(self, function, *args, **kwargs): def precompile(self, function, *args, **kwargs):