From 99899e6657a0fdc261652517bc5d7d4cb17193d8 Mon Sep 17 00:00:00 2001 From: Harry Ho Date: Thu, 11 Mar 2021 17:15:39 +0800 Subject: [PATCH] nal: Fix read/write not pushing erroneous socket back to the stack * Based on quartiq's minimq as of https://github.com/quartiq/minimq/commit/933687c2e4bc8a4d972de9a4d1508b0b554a8b38 * In minimq applications, a socket is expected to be returned when `nal::TcpStack::open()` is called * `MqttClient::read()`/`write()` takes away the TCP socket handle (wrapped as an `Option`) from its `RefCell`, and then calls `nal::TcpStack::read()`/`write()`; if NAL returns `nb::Error`, then the MQTT client will propagate and return the error, leaving `None` behind * Afterwards, when `MqttClient::socket_is_connected()` gets called (e.g. while polling the interface), it will detect that the socket handle is `None`, and attempt to call `nal::TcpStack::open()` * Since `open()` pops a socket from the array (`unused_handles`), when implementing this NAL the socket should have been pushed back to the stack, i.e. by `close()`; this prevents any future calls of `open()` from returning `NetworkError::NoSocket` due to emptiness of the array of socket handles --- src/nal.rs | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/src/nal.rs b/src/nal.rs index 7a0ed9f..5bb9fd3 100644 --- a/src/nal.rs +++ b/src/nal.rs @@ -236,6 +236,7 @@ where } fn write(&self, socket: &mut Self::TcpSocket, buffer: &[u8]) -> nb::Result { + let mut write_error = false; let mut non_queued_bytes = &buffer[..]; while non_queued_bytes.len() != 0 { let result = { @@ -254,9 +255,18 @@ where // Process the unwritten bytes again, if any non_queued_bytes = &non_queued_bytes[num_bytes..] } - Err(_) => return Err(nb::Error::Other(NetworkError::WriteFailure)), + Err(_) => { + write_error = true; + break; + } } } + if write_error { + // Close the socket to push it back to the array, for + // re-opening the socket in the future + self.close(*socket)?; + return Err(nb::Error::Other(NetworkError::WriteFailure)) + } Ok(buffer.len()) } @@ -268,13 +278,19 @@ where // Enqueue received bytes into the TCP socket buffer self.update()?; self.poll()?; - let mut sockets = self.sockets.borrow_mut(); - let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*socket); - let result = socket.recv_slice(buffer); - match result { - Ok(num_bytes) => Ok(num_bytes), - Err(_) => Err(nb::Error::Other(NetworkError::ReadFailure)), + { + let mut sockets = self.sockets.borrow_mut(); + let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*socket); + let result = socket.recv_slice(buffer); + match result { + Ok(num_bytes) => { return Ok(num_bytes) }, + Err(_) => {}, + } } + // Close the socket to push it back to the array, for + // re-opening the socket in the future + self.close(*socket)?; + Err(nb::Error::Other(NetworkError::ReadFailure)) } fn close(&self, socket: Self::TcpSocket) -> Result<(), Self::Error> {