diff --git a/libasync/src/smoltcp/tcp_stream.rs b/libasync/src/smoltcp/tcp_stream.rs index 9fb6b60..8c1d136 100644 --- a/libasync/src/smoltcp/tcp_stream.rs +++ b/libasync/src/smoltcp/tcp_stream.rs @@ -19,6 +19,7 @@ use smoltcp::{ use crate::task; use super::Sockets; + /// References a smoltcp TcpSocket pub struct TcpStream { handle: SocketHandle, @@ -281,6 +282,87 @@ impl Drop for TcpStream { } } +// copied from artiq/runtime/sched.rs with modifications for async polling +pub struct TcpListener { + handle: SocketHandle, + buffer_size: usize, + endpoint: IpEndpoint, +} + +impl TcpListener { + fn new_socket_handle(buffer_size: usize) -> SocketHandle { + fn uninit_vec(size: usize) -> Vec { + let mut result = Vec::with_capacity(size); + unsafe { + result.set_len(size); + } + result + } + let rx_buffer = TcpSocketBuffer::new(uninit_vec(rx_bufsize)); + let tx_buffer = TcpSocketBuffer::new(uninit_vec(tx_bufsize)); + let socket = TcpSocket::new(rx_buffer, tx_buffer); + let handle = Sockets::instance().sockets.borrow_mut() + .add(socket); + handle + } + + pub fn new(buffer_size: usize) -> TcpListener { + TcpListener { + handle: Self::new_socket_handle(buffer_size), + buffer_size, + endpoint: IpEndpoint::default(), + } + } + + fn with_lower(&self, f: F) -> R + where + F: FnOnce(SocketRef) -> R, + { + let mut sockets = Sockets::instance().sockets.borrow_mut(); + let mut socket_ref = sockets.get::(self.handle); + f(socket_ref) + } + + pub fn can_accept(&self) -> bool { + self.with_lower(|s| s.is_active()) + } + + pub fn listen>(&self, endpoint: T) -> Result<(), Error> { + let endpoint = endpoint.into(); + self.with_lower(|s| s.listen(endpoint)) + .map(|| { + self.endpoint.set(endpoint) + () + }) + .map_err(|err| err.into()) + } + + pub async fn accept(&self) -> Result { + let handle = self.handle.get(); + + let stream = TcpStream { handle }; + + poll_stream!(&stream, (), |socket| { + if socket.state() != TcpState::Listen { + Poll::Ready(()) + } else { + Poll::Pending + } + }).await; + + self.handle.set(Self::new_lower(self.buffer_size.get())); + match self.listen(self.endpoint.get()) { + Ok(()) => (), + _ => unreachable!() + } + Ok(stream) + } + + pub fn close(&self) { + self.with_lower(|mut s| s.close()) + } +} + fn socket_is_handhshaking(socket: &SocketRef) -> bool { match socket.state() { TcpState::SynSent | TcpState::SynReceived =>