From d9e50bbcb60ab403037d23dde40e90e00bbdac0a Mon Sep 17 00:00:00 2001 From: Harry Ho Date: Fri, 5 Mar 2021 13:27:26 +0800 Subject: [PATCH] nal: Prevent looping until the stack successfully connects to remote * `NetworkStack::connect()`: * Add timeout for connection attempt * Now returns the socket at TCP ESTABLISHED or CLOSED states, or after connection timeout * Split `NetworkStack::update()` into `update()` (for controlling the clock) and `poll()` (for polling the smoltcp EthernetInterface) * Also remove option `auto_time_update`; the main application is responsible for what values `embedded_time::clock::Clock::try_now()` should return --- src/nal.rs | 166 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 104 insertions(+), 62 deletions(-) diff --git a/src/nal.rs b/src/nal.rs index 0cb71da..7a0ed9f 100644 --- a/src/nal.rs +++ b/src/nal.rs @@ -1,5 +1,5 @@ use core::cell::RefCell; -use core::convert::TryFrom; +use core::convert::TryInto; use heapless::{consts, Vec}; use embedded_nal as nal; use nal::nb; @@ -40,7 +40,8 @@ where unused_handles: RefCell>, time_ms: RefCell, last_update_instant: RefCell>>, - clock: IntClock + clock: IntClock, + connection_timeout_ms: u32, } impl<'a, SPI, NSS, IntClock> NetworkStack<'a, SPI, NSS, IntClock> @@ -49,7 +50,12 @@ where NSS: OutputPin, IntClock: time::Clock, { - pub fn new(interface: NetworkInterface, sockets: net::socket::SocketSet<'a>, clock: IntClock) -> Self { + pub fn new( + interface: NetworkInterface, + sockets: net::socket::SocketSet<'a>, + clock: IntClock, + connection_timeout_ms: u32, + ) -> Self { let mut unused_handles: Vec = Vec::new(); for socket in sockets.iter() { unused_handles.push(socket.handle()).unwrap(); @@ -62,34 +68,55 @@ where time_ms: RefCell::new(0), last_update_instant: RefCell::new(None), clock, + connection_timeout_ms, } } - // Include auto_time_update to allow Instant::now() to not be called - // Instant::now() is not safe to call in `init()` context - pub fn update(&self, auto_time_update: bool) -> Result { - if auto_time_update { - // Check if it is the first time the stack has updated the time itself - let now = match *self.last_update_instant.borrow() { - // If it is the first time, do not advance time - // Simply store the current instant to initiate time updating - None => self.clock.try_now().map_err(|_| NetworkError::TimeFault)?, - // If it was updated before, advance time and update last_update_instant - Some(instant) => { - // Calculate elapsed time - let now = self.clock.try_now().map_err(|_| NetworkError::TimeFault)?; - let duration = now.checked_duration_since(&instant).ok_or(NetworkError::TimeFault)?; - let duration_ms = time::duration::Milliseconds::::try_from(duration).map_err(|_| NetworkError::TimeFault)?; - // Adjust duration into ms (note: decimal point truncated) - self.advance_time(*duration_ms.integer()); - now + // Initiate or advance the timer, and return the duration in ms as u32. + fn update(&self) -> Result { + let mut duration_ms: u32 = 0; + // Check if it is the first time the stack has updated the time itself + let now = match *self.last_update_instant.borrow() { + // If it is the first time, do not advance time + // Simply store the current instant to initiate time updating + None => self.clock.try_now().map_err(|_| NetworkError::TimeFault)?, + // If it was updated before, advance time and update last_update_instant + Some(instant) => { + // Calculate elapsed time + let now = self.clock.try_now().map_err(|_| NetworkError::TimeFault)?; + let mut duration = now.checked_duration_since(&instant); + // Normally, the wrapping clock should produce a valid duration. + // However, if `now` is earlier than `instant` (e.g. because the main + // application cannot get a valid epoch time during initialisation, + // we should still produce a duration that is just 1ms. + if duration.is_none() { + self.time_ms.replace(0); + duration = Some(Milliseconds(1_u32) + .to_generic::(IntClock::SCALING_FACTOR) + .map_err(|_| NetworkError::TimeFault)?); } - }; - self.last_update_instant.replace(Some(now)); - } + let duration_ms_time: Milliseconds = duration.unwrap().try_into() + .map_err(|_| NetworkError::TimeFault)?; + duration_ms = *duration_ms_time.integer(); + // Adjust duration into ms (note: decimal point truncated) + self.advance_time(duration_ms); + now + } + }; + self.last_update_instant.replace(Some(now)); + Ok(duration_ms) + } + + fn advance_time(&self, duration_ms: u32) { + let time = self.time_ms.borrow().wrapping_add(duration_ms); + self.time_ms.replace(time); + } + + // Poll on the smoltcp interface + fn poll(&self) -> Result { match self.network_interface.borrow_mut().poll( &mut self.sockets.borrow_mut(), - net::time::Instant::from_millis(*self.time_ms.borrow() as i64), + net::time::Instant::from_millis(*self.time_ms.borrow() as u32), ) { Ok(changed) => Ok(!changed), Err(_e) => { @@ -98,11 +125,6 @@ where } } - pub fn advance_time(&self, duration: u32) { - let time = self.time_ms.try_borrow().unwrap().wrapping_add(duration); - self.time_ms.replace(time); - } - fn get_ephemeral_port(&self) -> u16 { // Get the next ephemeral port let current_port = self.next_port.borrow().clone(); @@ -132,20 +154,27 @@ where } } - // Ideally connect is only to be performed in `init()` of `main.rs` - // Calling `Instant::now()` of `rtic::cyccnt` would face correctness issue during `init()` fn connect( &self, socket: Self::TcpSocket, remote: nal::SocketAddr, ) -> Result { - let address = { + { + // If the socket has already been connected, ignore the connection + // request silently. let mut sockets = self.sockets.borrow_mut(); let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(socket); - // If we're already in the process of connecting, ignore the request silently. - if internal_socket.is_open() { - return Ok(socket); + if internal_socket.state() == net::socket::TcpState::Established { + return Ok(socket) } + } + + { + let mut sockets = self.sockets.borrow_mut(); + let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(socket); + // abort() instead of close() prevents TcpSocket::connect() from + // raising an error + internal_socket.abort(); match remote.ip() { nal::IpAddr::V4(addr) => { let address = @@ -154,45 +183,56 @@ where .connect((address, remote.port()), self.get_ephemeral_port()) .map_err(|_| NetworkError::ConnectionFailure)?; net::wire::IpAddress::Ipv4(address) - } + }, nal::IpAddr::V6(addr) => { - let address = net::wire::Ipv6Address::from_parts(&addr.segments()[..]); - internal_socket.connect((address, remote.port()), self.get_ephemeral_port()) + let address = + net::wire::Ipv6Address::from_parts(&addr.segments()[..]); + internal_socket + .connect((address, remote.port()), self.get_ephemeral_port()) .map_err(|_| NetworkError::ConnectionFailure)?; net::wire::IpAddress::Ipv6(address) } } }; - // Blocking connect + + // Loop to wait until the socket is staying established or closed, + // or the connection attempt has timed out. + let mut timeout_ms: u32 = 0; loop { - match self.is_connected(&socket) { - Ok(true) => break, - _ => { - let mut sockets = self.sockets.borrow_mut(); - let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(socket); - // If the connect got ACK->RST, it will end up in Closed TCP state - // Perform reconnection in this case - if internal_socket.state() == net::socket::TcpState::Closed { - internal_socket.close(); - internal_socket - .connect((address, remote.port()), self.get_ephemeral_port()) - .map_err(|_| NetworkError::ConnectionFailure)?; - } + { + let mut sockets = self.sockets.borrow_mut(); + let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(socket); + // TCP state at ESTABLISHED means there is connection, so + // simply return the socket. + if internal_socket.state() == net::socket::TcpState::Established { + return Ok(socket) + } + // TCP state at CLOSED implies that the remote rejected connection; + // In this case, abort the connection, and then return the socket + // for re-connection in the future. + if internal_socket.state() == net::socket::TcpState::Closed { + internal_socket.abort(); + // TODO: Return Err(), but would require changes in quartiq/minimq + return Ok(socket) } } - // Avoid using Instant::now() and Advance time manually - self.update(false)?; - { - self.advance_time(1); + // Any TCP states other than CLOSED and ESTABLISHED are considered + // "transient", so this function should keep waiting and let smoltcp poll + // (e.g. for handling echo reqeust/reply) at the same time. + timeout_ms += self.update()?; + self.poll()?; + // Time out, and return the socket for re-connection in the future. + if timeout_ms > self.connection_timeout_ms { + // TODO: Return Err(), but would require changes in quartiq/minimq + return Ok(socket) } } - Ok(socket) } fn is_connected(&self, socket: &Self::TcpSocket) -> Result { let mut sockets = self.sockets.borrow_mut(); - let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*socket); - Ok(socket.may_send() && socket.may_recv()) + let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(*socket); + Ok(internal_socket.state() == net::socket::TcpState::Established) } fn write(&self, socket: &mut Self::TcpSocket, buffer: &[u8]) -> nb::Result { @@ -208,7 +248,8 @@ where Ok(num_bytes) => { // In case the buffer is filled up, push bytes into ethernet driver if num_bytes != non_queued_bytes.len() { - self.update(true)?; + self.update()?; + self.poll()?; } // Process the unwritten bytes again, if any non_queued_bytes = &non_queued_bytes[num_bytes..] @@ -225,7 +266,8 @@ where buffer: &mut [u8], ) -> nb::Result { // Enqueue received bytes into the TCP socket buffer - self.update(true)?; + self.update()?; + self.poll()?; let mut sockets = self.sockets.borrow_mut(); let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*socket); let result = socket.recv_slice(buffer);