master: support optional args

pull/256/head
mwojcik 2023-09-21 17:31:49 +08:00
parent 4b3c9a3d08
commit c696fd826f
4 changed files with 40 additions and 23 deletions

View File

@ -448,22 +448,23 @@ async fn handle_run_kernel(
#[cfg(has_drtio)] #[cfg(has_drtio)]
kernel::Message::SubkernelMsgRecvRequest { id, timeout } => { kernel::Message::SubkernelMsgRecvRequest { id, timeout } => {
let message_received = subkernel::message_await(id, timeout, timer).await; let message_received = subkernel::message_await(id, timeout, timer).await;
let status = match message_received { let (status, count) = match message_received {
Ok(_) => kernel::SubkernelStatus::NoError, Ok(ref message) => (kernel::SubkernelStatus::NoError, message.count),
Err(SubkernelError::Timeout) => kernel::SubkernelStatus::Timeout, Err(SubkernelError::Timeout) => (kernel::SubkernelStatus::Timeout, 0),
Err(SubkernelError::IncorrectState) => kernel::SubkernelStatus::IncorrectState, Err(SubkernelError::IncorrectState) => (kernel::SubkernelStatus::IncorrectState, 0),
Err(SubkernelError::CommLost) => kernel::SubkernelStatus::CommLost, Err(SubkernelError::CommLost) => (kernel::SubkernelStatus::CommLost, 0),
Err(_) => kernel::SubkernelStatus::OtherError, Err(_) => (kernel::SubkernelStatus::OtherError, 0),
}; };
control control
.borrow_mut() .borrow_mut()
.tx .tx
.async_send(kernel::Message::SubkernelMsgRecvReply { status: status }) .async_send(kernel::Message::SubkernelMsgRecvReply { status: status, count: count })
.await; .await;
if let Ok((tag, data)) = message_received { if let Ok(message) = message_received {
// receive code almost identical to RPC recv, except we are not reading from a stream // receive code almost identical to RPC recv, except we are not reading from a stream
let mut reader = Cursor::new(data); let mut reader = Cursor::new(message.data);
let mut tag: [u8; 1] = [tag]; let mut tag: [u8; 1] = [message.tag];
let mut i = 0;
loop { loop {
// kernel has to consume all arguments in the whole message // kernel has to consume all arguments in the whole message
let slot = match fast_recv(&mut control.borrow_mut().rx).await { let slot = match fast_recv(&mut control.borrow_mut().rx).await {
@ -493,10 +494,12 @@ async fn handle_run_kernel(
.tx .tx
.async_send(kernel::Message::RpcRecvReply(Ok(0))) .async_send(kernel::Message::RpcRecvReply(Ok(0)))
.await; .await;
match reader.read_u8() { i += 1;
Ok(0) | Err(_) => break, // reached the end of data, we're done if i < count {
Ok(t) => tag[0] = t, // update the tag for next read tag[0] = reader.read_u8()?;
}; } else {
break;
}
} }
} }
} }

View File

@ -116,6 +116,7 @@ pub enum Message {
#[cfg(has_drtio)] #[cfg(has_drtio)]
SubkernelMsgRecvReply { SubkernelMsgRecvReply {
status: SubkernelStatus, status: SubkernelStatus,
count: u8,
}, },
} }

View File

@ -51,18 +51,20 @@ pub extern "C" fn await_finish(id: u32, timeout: u64) {
} }
} }
pub extern "C" fn send_message(id: u32, tag: &CSlice<u8>, data: *const *const ()) { pub extern "C" fn send_message(id: u32, count: u8, tag: &CSlice<u8>, data: *const *const ()) {
let mut buffer = Vec::<u8>::new(); let mut buffer = Vec::<u8>::new();
send_args(&mut buffer, 0, tag.as_ref(), data).expect("RPC encoding failed"); send_args(&mut buffer, 0, tag.as_ref(), data).expect("RPC encoding failed");
// overwrite service tag, include how many tags are in the message
buffer[3] = count;
unsafe { unsafe {
KERNEL_CHANNEL_1TO0.as_mut().unwrap().send(Message::SubkernelMsgSend { KERNEL_CHANNEL_1TO0.as_mut().unwrap().send(Message::SubkernelMsgSend {
id: id, id: id,
data: buffer[4..].to_vec(), data: buffer[3..].to_vec(),
}); });
} }
} }
pub extern "C" fn await_message(id: u32, timeout: u64) { pub extern "C" fn await_message(id: u32, timeout: u64, min: u8, max: u8) {
unsafe { unsafe {
KERNEL_CHANNEL_1TO0 KERNEL_CHANNEL_1TO0
.as_mut() .as_mut()
@ -75,18 +77,27 @@ pub extern "C" fn await_message(id: u32, timeout: u64) {
match unsafe { KERNEL_CHANNEL_0TO1.as_mut().unwrap() }.recv() { match unsafe { KERNEL_CHANNEL_0TO1.as_mut().unwrap() }.recv() {
Message::SubkernelMsgRecvReply { Message::SubkernelMsgRecvReply {
status: SubkernelStatus::NoError, status: SubkernelStatus::NoError,
} => (), count
} => {
if min > count || count > max {
artiq_raise!("SubkernelError", "Received more or less arguments than required")
}
},
Message::SubkernelMsgRecvReply { Message::SubkernelMsgRecvReply {
status: SubkernelStatus::IncorrectState, status: SubkernelStatus::IncorrectState,
..
} => artiq_raise!("SubkernelError", "Subkernel not running"), } => artiq_raise!("SubkernelError", "Subkernel not running"),
Message::SubkernelMsgRecvReply { Message::SubkernelMsgRecvReply {
status: SubkernelStatus::Timeout, status: SubkernelStatus::Timeout,
..
} => artiq_raise!("SubkernelError", "Subkernel timed out"), } => artiq_raise!("SubkernelError", "Subkernel timed out"),
Message::SubkernelMsgRecvReply { Message::SubkernelMsgRecvReply {
status: SubkernelStatus::CommLost, status: SubkernelStatus::CommLost,
..
} => artiq_raise!("SubkernelError", "Lost communication with satellite"), } => artiq_raise!("SubkernelError", "Lost communication with satellite"),
Message::SubkernelMsgRecvReply { Message::SubkernelMsgRecvReply {
status: SubkernelStatus::OtherError, status: SubkernelStatus::OtherError,
..
} => artiq_raise!("SubkernelError", "An error occurred during subkernel operation"), } => artiq_raise!("SubkernelError", "An error occurred during subkernel operation"),
_ => panic!("expected SubkernelMsgRecvReply after SubkernelMsgRecvRequest"), _ => panic!("expected SubkernelMsgRecvReply after SubkernelMsgRecvRequest"),
} }

View File

@ -209,8 +209,9 @@ pub async fn await_finish(
} }
} }
struct Message { pub struct Message {
from_id: u32, from_id: u32,
pub count: u8,
pub tag: u8, pub tag: u8,
pub data: Vec<u8>, pub data: Vec<u8>,
} }
@ -234,8 +235,9 @@ pub async fn message_handle_incoming(id: u32, last: bool, length: usize, data: &
id, id,
Message { Message {
from_id: id, from_id: id,
tag: data[0], count: data[0],
data: data[1..length].to_vec(), tag: data[1],
data: data[2..length].to_vec(),
}, },
); );
} }
@ -249,7 +251,7 @@ pub async fn message_handle_incoming(id: u32, last: bool, length: usize, data: &
} }
} }
pub async fn message_await(id: u32, timeout: u64, timer: GlobalTimer) -> Result<(u8, Vec<u8>), Error> { pub async fn message_await(id: u32, timeout: u64, timer: GlobalTimer) -> Result<Message, Error> {
match SUBKERNELS.async_lock().await.get(&id).unwrap().state { match SUBKERNELS.async_lock().await.get(&id).unwrap().state {
SubkernelState::Finished { SubkernelState::Finished {
status: FinishStatus::CommLost, status: FinishStatus::CommLost,
@ -265,7 +267,7 @@ pub async fn message_await(id: u32, timeout: u64, timer: GlobalTimer) -> Result<
let msg = &message_queue[i]; let msg = &message_queue[i];
if msg.from_id == id { if msg.from_id == id {
let message = message_queue.remove(i); let message = message_queue.remove(i);
return Ok((message.tag, message.data)); return Ok(message);
} }
} }
} }