diff --git a/src/runtime/src/comms.rs b/src/runtime/src/comms.rs index 5e01d31..b081821 100644 --- a/src/runtime/src/comms.rs +++ b/src/runtime/src/comms.rs @@ -448,22 +448,23 @@ async fn handle_run_kernel( #[cfg(has_drtio)] kernel::Message::SubkernelMsgRecvRequest { id, timeout } => { let message_received = subkernel::message_await(id, timeout, timer).await; - let status = match message_received { - Ok(_) => kernel::SubkernelStatus::NoError, - Err(SubkernelError::Timeout) => kernel::SubkernelStatus::Timeout, - Err(SubkernelError::IncorrectState) => kernel::SubkernelStatus::IncorrectState, - Err(SubkernelError::CommLost) => kernel::SubkernelStatus::CommLost, - Err(_) => kernel::SubkernelStatus::OtherError, + let (status, count) = match message_received { + Ok(ref message) => (kernel::SubkernelStatus::NoError, message.count), + Err(SubkernelError::Timeout) => (kernel::SubkernelStatus::Timeout, 0), + Err(SubkernelError::IncorrectState) => (kernel::SubkernelStatus::IncorrectState, 0), + Err(SubkernelError::CommLost) => (kernel::SubkernelStatus::CommLost, 0), + Err(_) => (kernel::SubkernelStatus::OtherError, 0), }; control .borrow_mut() .tx - .async_send(kernel::Message::SubkernelMsgRecvReply { status: status }) + .async_send(kernel::Message::SubkernelMsgRecvReply { status: status, count: count }) .await; - if let Ok((tag, data)) = message_received { + if let Ok(message) = message_received { // receive code almost identical to RPC recv, except we are not reading from a stream - let mut reader = Cursor::new(data); - let mut tag: [u8; 1] = [tag]; + let mut reader = Cursor::new(message.data); + let mut tag: [u8; 1] = [message.tag]; + let mut i = 0; loop { // kernel has to consume all arguments in the whole message let slot = match fast_recv(&mut control.borrow_mut().rx).await { @@ -493,10 +494,12 @@ async fn handle_run_kernel( .tx .async_send(kernel::Message::RpcRecvReply(Ok(0))) .await; - match reader.read_u8() { - Ok(0) | Err(_) => break, // reached the end of data, we're done - Ok(t) => tag[0] = t, // update the tag for next read - }; + i += 1; + if i < count { + tag[0] = reader.read_u8()?; + } else { + break; + } } } } diff --git a/src/runtime/src/kernel/mod.rs b/src/runtime/src/kernel/mod.rs index 592d5d9..258e0f5 100644 --- a/src/runtime/src/kernel/mod.rs +++ b/src/runtime/src/kernel/mod.rs @@ -116,6 +116,7 @@ pub enum Message { #[cfg(has_drtio)] SubkernelMsgRecvReply { status: SubkernelStatus, + count: u8, }, } diff --git a/src/runtime/src/kernel/subkernel.rs b/src/runtime/src/kernel/subkernel.rs index 4d7b528..f5fbf49 100644 --- a/src/runtime/src/kernel/subkernel.rs +++ b/src/runtime/src/kernel/subkernel.rs @@ -51,18 +51,20 @@ pub extern "C" fn await_finish(id: u32, timeout: u64) { } } -pub extern "C" fn send_message(id: u32, tag: &CSlice, data: *const *const ()) { +pub extern "C" fn send_message(id: u32, count: u8, tag: &CSlice, data: *const *const ()) { let mut buffer = Vec::::new(); send_args(&mut buffer, 0, tag.as_ref(), data).expect("RPC encoding failed"); + // overwrite service tag, include how many tags are in the message + buffer[3] = count; unsafe { KERNEL_CHANNEL_1TO0.as_mut().unwrap().send(Message::SubkernelMsgSend { id: id, - data: buffer[4..].to_vec(), + data: buffer[3..].to_vec(), }); } } -pub extern "C" fn await_message(id: u32, timeout: u64) { +pub extern "C" fn await_message(id: u32, timeout: u64, min: u8, max: u8) { unsafe { KERNEL_CHANNEL_1TO0 .as_mut() @@ -75,18 +77,27 @@ pub extern "C" fn await_message(id: u32, timeout: u64) { match unsafe { KERNEL_CHANNEL_0TO1.as_mut().unwrap() }.recv() { Message::SubkernelMsgRecvReply { status: SubkernelStatus::NoError, - } => (), + count + } => { + if min > count || count > max { + artiq_raise!("SubkernelError", "Received more or less arguments than required") + } + }, Message::SubkernelMsgRecvReply { status: SubkernelStatus::IncorrectState, + .. } => artiq_raise!("SubkernelError", "Subkernel not running"), Message::SubkernelMsgRecvReply { status: SubkernelStatus::Timeout, + .. } => artiq_raise!("SubkernelError", "Subkernel timed out"), Message::SubkernelMsgRecvReply { status: SubkernelStatus::CommLost, + .. } => artiq_raise!("SubkernelError", "Lost communication with satellite"), Message::SubkernelMsgRecvReply { status: SubkernelStatus::OtherError, + .. } => artiq_raise!("SubkernelError", "An error occurred during subkernel operation"), _ => panic!("expected SubkernelMsgRecvReply after SubkernelMsgRecvRequest"), } diff --git a/src/runtime/src/subkernel.rs b/src/runtime/src/subkernel.rs index 3c86a59..63c24af 100644 --- a/src/runtime/src/subkernel.rs +++ b/src/runtime/src/subkernel.rs @@ -209,8 +209,9 @@ pub async fn await_finish( } } -struct Message { +pub struct Message { from_id: u32, + pub count: u8, pub tag: u8, pub data: Vec, } @@ -234,8 +235,9 @@ pub async fn message_handle_incoming(id: u32, last: bool, length: usize, data: & id, Message { from_id: id, - tag: data[0], - data: data[1..length].to_vec(), + count: data[0], + tag: data[1], + data: data[2..length].to_vec(), }, ); } @@ -249,7 +251,7 @@ pub async fn message_handle_incoming(id: u32, last: bool, length: usize, data: & } } -pub async fn message_await(id: u32, timeout: u64, timer: GlobalTimer) -> Result<(u8, Vec), Error> { +pub async fn message_await(id: u32, timeout: u64, timer: GlobalTimer) -> Result { match SUBKERNELS.async_lock().await.get(&id).unwrap().state { SubkernelState::Finished { status: FinishStatus::CommLost, @@ -265,7 +267,7 @@ pub async fn message_await(id: u32, timeout: u64, timer: GlobalTimer) -> Result< let msg = &message_queue[i]; if msg.from_id == id { let message = message_queue.remove(i); - return Ok((message.tag, message.data)); + return Ok(message); } } }