diff --git a/libasync/src/smoltcp/tcp_stream.rs b/libasync/src/smoltcp/tcp_stream.rs index f6846da..ac5ae04 100644 --- a/libasync/src/smoltcp/tcp_stream.rs +++ b/libasync/src/smoltcp/tcp_stream.rs @@ -3,7 +3,9 @@ //! TODO: implement futures AsyncRead/AsyncWrite/Stream/Sink interfaces use core::{ + cell::RefCell, future::Future, + ops::DerefMut, pin::Pin, task::{Context, Poll}, }; @@ -108,17 +110,18 @@ impl TcpStream { /// number of bytes it consumed, and a user-defined return value of type R. pub async fn recv(&self, f: F) -> Result where - F: Fn(&[u8]) -> (usize, R), + F: FnMut(&[u8]) -> (usize, R), { - struct Recv<'a, F: FnOnce(&[u8]) -> (usize, R), R> { + struct Recv<'a, F: FnMut(&[u8]) -> (usize, R), R> { stream: &'a TcpStream, - f: F, + f: RefCell, } - impl<'a, F: Fn(&[u8]) -> (usize, R), R> Future for Recv<'a, F, R> { + impl<'a, F: FnMut(&[u8]) -> (usize, R), R> Future for Recv<'a, F, R> { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut f = self.f.borrow_mut(); let result = self.stream.with_socket(|mut socket| { if socket_is_handhshaking(&socket) { return Ok(Poll::Pending); @@ -126,7 +129,7 @@ impl TcpStream { socket.recv(|buf| { if buf.len() > 0 { - let (amount, result) = (self.f)(buf); + let (amount, result) = (f.deref_mut())(buf); assert!(amount > 0); (amount, Poll::Ready(Ok(result))) } else { @@ -150,7 +153,7 @@ impl TcpStream { Recv { stream: self, - f, + f: RefCell::new(f), }.await }