From 3ec9788eb1199ca9da7cc52703fc3911aad036ca Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Sun, 19 Jul 2020 16:07:55 +0800 Subject: [PATCH] proto_async: always consume one byte in recv --- src/runtime/src/proto_async.rs | 106 ++++++++++++++++++++++----------- 1 file changed, 70 insertions(+), 36 deletions(-) diff --git a/src/runtime/src/proto_async.rs b/src/runtime/src/proto_async.rs index 953c6f2e..f02e8c1c 100644 --- a/src/runtime/src/proto_async.rs +++ b/src/runtime/src/proto_async.rs @@ -5,23 +5,39 @@ use core::cell::RefCell; use libboard_zynq::smoltcp; use libasync::smoltcp::TcpStream; -// TODO: use byteorder, make it more like libio - type Result = core::result::Result; +enum RecvState { + NeedsMore(usize, T), // bytes consumed so far, partial result + Completed(T), // final result +} + pub async fn expect(stream: &TcpStream, pattern: &[u8]) -> Result { - stream.recv(|buf| { - for (i, b) in buf.iter().enumerate() { - if *b == pattern[i] { - if i + 1 == pattern.len() { - return Poll::Ready((i + 1, Ok(true))); + let mut state = RecvState::NeedsMore(0, true); + loop { + state = stream.recv(|buf| { + let mut consumed = 0; + if let RecvState::NeedsMore(mut cur_index, _) = state { + for b in buf.iter() { + consumed += 1; + if *b == pattern[cur_index] { + if cur_index + 1 == pattern.len() { + return Poll::Ready((consumed, RecvState::Completed(true))); + } + } else { + return Poll::Ready((consumed, RecvState::Completed(false))); + } + cur_index += 1; } + Poll::Ready((consumed, RecvState::NeedsMore(cur_index, true))) } else { - return Poll::Ready((i + 1, Ok(false))); + unreachable!(); } + }).await?; + if let RecvState::Completed(result) = state { + return Ok(result); } - Poll::Pending - }).await? + } } pub async fn read_bool(stream: &TcpStream) -> Result { @@ -37,37 +53,55 @@ pub async fn read_i8(stream: &TcpStream) -> Result { } pub async fn read_i32(stream: &TcpStream) -> Result { - Ok(stream.recv(|buf| { - if buf.len() >= 4 { - let value = - ((buf[0] as i32) << 24) - | ((buf[1] as i32) << 16) - | ((buf[2] as i32) << 8) - | (buf[3] as i32); - Poll::Ready((4, value)) - } else { - Poll::Pending + let mut state = RecvState::NeedsMore(0, 0); + loop { + state = stream.recv(|buf| { + let mut consumed = 0; + if let RecvState::NeedsMore(mut cur_index, mut cur_value) = state { + for b in buf.iter() { + consumed += 1; + cur_index += 1; + cur_value <<= 8; + cur_value |= *b as i32; + if cur_index == 4 { + return Poll::Ready((consumed, RecvState::Completed(cur_value))); + } + } + Poll::Ready((consumed, RecvState::NeedsMore(cur_index, cur_value))) + } else { + unreachable!(); + } + }).await?; + if let RecvState::Completed(result) = state { + return Ok(result); } - }).await?) + } } pub async fn read_i64(stream: &TcpStream) -> Result { - Ok(stream.recv(|buf| { - if buf.len() >= 8 { - let value = - ((buf[0] as i64) << 56) - | ((buf[1] as i64) << 48) - | ((buf[2] as i64) << 40) - | ((buf[3] as i64) << 32) - | ((buf[4] as i64) << 24) - | ((buf[5] as i64) << 16) - | ((buf[6] as i64) << 8) - | (buf[7] as i64); - Poll::Ready((8, value)) - } else { - Poll::Pending + let mut state = RecvState::NeedsMore(0, 0); + loop { + state = stream.recv(|buf| { + let mut consumed = 0; + if let RecvState::NeedsMore(mut cur_index, mut cur_value) = state { + for b in buf.iter() { + consumed += 1; + cur_index += 1; + cur_value <<= 8; + cur_value |= *b as i64; + if cur_index == 8 { + return Poll::Ready((consumed, RecvState::Completed(cur_value))); + } + } + Poll::Ready((consumed, RecvState::NeedsMore(cur_index, cur_value))) + } else { + unreachable!(); + } + }).await?; + if let RecvState::Completed(result) = state { + return Ok(result); } - }).await?) + } } pub async fn read_chunk(stream: &TcpStream, destination: &mut [u8]) -> Result<()> {