nal: Fix loops & socket handle management #7

Merged
harry merged 6 commits from fix-nal into master 2021-04-12 09:28:56 +08:00
1 changed files with 150 additions and 81 deletions

View File

@ -1,5 +1,5 @@
use core::cell::RefCell; use core::cell::RefCell;
use core::convert::TryFrom; use core::convert::TryInto;
use heapless::{consts, Vec}; use heapless::{consts, Vec};
use embedded_nal as nal; use embedded_nal as nal;
use nal::nb; use nal::nb;
@ -40,7 +40,8 @@ where
unused_handles: RefCell<Vec<net::socket::SocketHandle, consts::U16>>, unused_handles: RefCell<Vec<net::socket::SocketHandle, consts::U16>>,
time_ms: RefCell<u32>, time_ms: RefCell<u32>,
last_update_instant: RefCell<Option<time::Instant<IntClock>>>, last_update_instant: RefCell<Option<time::Instant<IntClock>>>,
clock: IntClock clock: IntClock,
connection_timeout_ms: u32,
} }
impl<'a, SPI, NSS, IntClock> NetworkStack<'a, SPI, NSS, IntClock> impl<'a, SPI, NSS, IntClock> NetworkStack<'a, SPI, NSS, IntClock>
@ -49,7 +50,12 @@ where
NSS: OutputPin, NSS: OutputPin,
IntClock: time::Clock<T = u32>, IntClock: time::Clock<T = u32>,
{ {
pub fn new(interface: NetworkInterface<SPI, NSS>, sockets: net::socket::SocketSet<'a>, clock: IntClock) -> Self { pub fn new(
interface: NetworkInterface<SPI, NSS>,
sockets: net::socket::SocketSet<'a>,
clock: IntClock,
connection_timeout_ms: u32,
astro marked this conversation as resolved

Can you try to use an embedded_time::duration type wherever possible?

Can you try to use an `embedded_time::duration` type wherever possible?

I defined the connection timeout parameter as a u32 because it is used to compare with the network stack struct's time_ms "time now", which is also stored as u32. If I'm following @occheung 's logic correctly, we want time_ms to stay as u32 rather than a embedded_time::duration::Milliseconds because every time the Ethernet interface is busy polling, it has to borrow the "time now" value and conversion of the timestamp among u32, the embedded_time duration type and the smoltcp duration type might consume more CPU cycles. Plus, I made update() return a u32 representing the number of microseconds advanced, and so counting the total duration would be simpler.

I defined the connection timeout parameter as a `u32` because it is used to compare with the network stack struct's `time_ms` "time now", which is also stored as `u32`. If I'm following @occheung 's logic correctly, we want `time_ms` to stay as `u32` rather than a `embedded_time::duration::Milliseconds` because every time the Ethernet interface is busy polling, it has to borrow the "time now" value and conversion of the timestamp among `u32`, the `embedded_time` duration type and the `smoltcp` duration type might consume more CPU cycles. Plus, I made `update()` return a `u32` representing the number of microseconds advanced, and so counting the total duration would be simpler.
) -> Self {
let mut unused_handles: Vec<net::socket::SocketHandle, consts::U16> = Vec::new(); let mut unused_handles: Vec<net::socket::SocketHandle, consts::U16> = Vec::new();
for socket in sockets.iter() { for socket in sockets.iter() {
unused_handles.push(socket.handle()).unwrap(); unused_handles.push(socket.handle()).unwrap();
@ -62,34 +68,55 @@ where
time_ms: RefCell::new(0), time_ms: RefCell::new(0),
last_update_instant: RefCell::new(None), last_update_instant: RefCell::new(None),
clock, clock,
connection_timeout_ms,
} }
} }
// Include auto_time_update to allow Instant::now() to not be called // Initiate or advance the timer, and return the duration in ms as u32.
// Instant::now() is not safe to call in `init()` context fn update(&self) -> Result<u32, NetworkError> {
pub fn update(&self, auto_time_update: bool) -> Result<bool, NetworkError> { let mut duration_ms: u32 = 0;
if auto_time_update { // Check if it is the first time the stack has updated the time itself
// Check if it is the first time the stack has updated the time itself let now = match *self.last_update_instant.borrow() {
let now = match *self.last_update_instant.borrow() { // If it is the first time, do not advance time
// If it is the first time, do not advance time // Simply store the current instant to initiate time updating
// Simply store the current instant to initiate time updating None => self.clock.try_now().map_err(|_| NetworkError::TimeFault)?,
None => self.clock.try_now().map_err(|_| NetworkError::TimeFault)?, // If it was updated before, advance time and update last_update_instant
// If it was updated before, advance time and update last_update_instant Some(instant) => {
Some(instant) => { // Calculate elapsed time
// Calculate elapsed time let now = self.clock.try_now().map_err(|_| NetworkError::TimeFault)?;
let now = self.clock.try_now().map_err(|_| NetworkError::TimeFault)?; let mut duration = now.checked_duration_since(&instant);
let duration = now.checked_duration_since(&instant).ok_or(NetworkError::TimeFault)?; // Normally, the wrapping clock should produce a valid duration.
let duration_ms = time::duration::Milliseconds::<u32>::try_from(duration).map_err(|_| NetworkError::TimeFault)?; // However, if `now` is earlier than `instant` (e.g. because the main
// Adjust duration into ms (note: decimal point truncated) // application cannot get a valid epoch time during initialisation,
self.advance_time(*duration_ms.integer()); // we should still produce a duration that is just 1ms.
now if duration.is_none() {
self.time_ms.replace(0);
duration = Some(Milliseconds(1_u32)
.to_generic::<u32>(IntClock::SCALING_FACTOR)
.map_err(|_| NetworkError::TimeFault)?);
} }
}; let duration_ms_time: Milliseconds<u32> = duration.unwrap().try_into()
self.last_update_instant.replace(Some(now)); .map_err(|_| NetworkError::TimeFault)?;
} duration_ms = *duration_ms_time.integer();
// Adjust duration into ms (note: decimal point truncated)
self.advance_time(duration_ms);
now
}
};
self.last_update_instant.replace(Some(now));
Ok(duration_ms)
}
fn advance_time(&self, duration_ms: u32) {
let time = self.time_ms.borrow().wrapping_add(duration_ms);
self.time_ms.replace(time);
}
// Poll on the smoltcp interface
fn poll(&self) -> Result<bool, NetworkError> {
match self.network_interface.borrow_mut().poll( match self.network_interface.borrow_mut().poll(
&mut self.sockets.borrow_mut(), &mut self.sockets.borrow_mut(),
net::time::Instant::from_millis(*self.time_ms.borrow() as i64), net::time::Instant::from_millis(*self.time_ms.borrow() as u32),
) { ) {
Ok(changed) => Ok(!changed), Ok(changed) => Ok(!changed),
Err(_e) => { Err(_e) => {
@ -98,11 +125,6 @@ where
} }
} }
pub fn advance_time(&self, duration: u32) {
let time = self.time_ms.try_borrow().unwrap().wrapping_add(duration);
self.time_ms.replace(time);
}
fn get_ephemeral_port(&self) -> u16 { fn get_ephemeral_port(&self) -> u16 {
// Get the next ephemeral port // Get the next ephemeral port
let current_port = self.next_port.borrow().clone(); let current_port = self.next_port.borrow().clone();
@ -124,122 +146,169 @@ where
Some(handle) => { Some(handle) => {
// Abort any active connections on the handle. // Abort any active connections on the handle.
let mut sockets = self.sockets.borrow_mut(); let mut sockets = self.sockets.borrow_mut();
let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(handle); let socket: &mut net::socket::TcpSocket = &mut *sockets.get(handle);
internal_socket.abort(); socket.abort();
Ok(handle) Ok(handle)
} }
None => Err(NetworkError::NoSocket), None => Err(NetworkError::NoSocket),
} }
} }
// Ideally connect is only to be performed in `init()` of `main.rs`
// Calling `Instant::now()` of `rtic::cyccnt` would face correctness issue during `init()`
fn connect( fn connect(
&self, &self,
socket: Self::TcpSocket, handle: Self::TcpSocket,
remote: nal::SocketAddr, remote: nal::SocketAddr,
) -> Result<Self::TcpSocket, Self::Error> { ) -> Result<Self::TcpSocket, Self::Error> {
let address = { {
// If the socket has already been connected, ignore the connection
// request silently.
let mut sockets = self.sockets.borrow_mut(); let mut sockets = self.sockets.borrow_mut();
let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(socket); let socket: &mut net::socket::TcpSocket = &mut *sockets.get(handle);
// If we're already in the process of connecting, ignore the request silently. if socket.state() == net::socket::TcpState::Established {
if internal_socket.is_open() { return Ok(handle)
return Ok(socket);
} }
}
{
let mut sockets = self.sockets.borrow_mut();
let socket: &mut net::socket::TcpSocket = &mut *sockets.get(handle);
// abort() instead of close() prevents TcpSocket::connect() from
// raising an error
socket.abort();
match remote.ip() { match remote.ip() {
nal::IpAddr::V4(addr) => { nal::IpAddr::V4(addr) => {
let address = let address =
net::wire::Ipv4Address::from_bytes(&addr.octets()[..]); net::wire::Ipv4Address::from_bytes(&addr.octets()[..]);
internal_socket socket
.connect((address, remote.port()), self.get_ephemeral_port()) .connect((address, remote.port()), self.get_ephemeral_port())
.map_err(|_| NetworkError::ConnectionFailure)?; .map_err(|_| NetworkError::ConnectionFailure)?;
net::wire::IpAddress::Ipv4(address) net::wire::IpAddress::Ipv4(address)
} },
nal::IpAddr::V6(addr) => { nal::IpAddr::V6(addr) => {
let address = net::wire::Ipv6Address::from_parts(&addr.segments()[..]); let address =
internal_socket.connect((address, remote.port()), self.get_ephemeral_port()) net::wire::Ipv6Address::from_parts(&addr.segments()[..]);
socket
.connect((address, remote.port()), self.get_ephemeral_port())
.map_err(|_| NetworkError::ConnectionFailure)?; .map_err(|_| NetworkError::ConnectionFailure)?;
net::wire::IpAddress::Ipv6(address) net::wire::IpAddress::Ipv6(address)
} }
} }
}; };
// Blocking connect // Blocking connect
// Loop to wait until the socket is staying established or closed,
// or the connection attempt has timed out.
let mut timeout_ms: u32 = 0;
loop { loop {
match self.is_connected(&socket) { {
Ok(true) => break, let mut sockets = self.sockets.borrow_mut();
_ => { let socket: &mut net::socket::TcpSocket = &mut *sockets.get(handle);
let mut sockets = self.sockets.borrow_mut(); // TCP state at ESTABLISHED means there is connection, so
let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(socket); // simply return the socket.
// If the connect got ACK->RST, it will end up in Closed TCP state if socket.state() == net::socket::TcpState::Established {
// Perform reconnection in this case return Ok(handle)
if internal_socket.state() == net::socket::TcpState::Closed { }
internal_socket.close(); // TCP state at CLOSED implies that the remote rejected connection;
internal_socket // In this case, abort the connection, and then return the socket
.connect((address, remote.port()), self.get_ephemeral_port()) // for re-connection in the future.
.map_err(|_| NetworkError::ConnectionFailure)?; if socket.state() == net::socket::TcpState::Closed {
} socket.abort();
// TODO: Return Err(), but would require changes in quartiq/minimq
return Ok(handle)
} }
} }
// Avoid using Instant::now() and Advance time manually
self.update(false)?; // Any TCP states other than CLOSED and ESTABLISHED are considered
{ // "transient", so this function should keep waiting and let smoltcp poll
self.advance_time(1); // (e.g. for handling echo reqeust/reply) at the same time.
timeout_ms += self.update()?;

While all the other steps are very nicely documented, this fairly important self.poll() disappears to my eyes. How about a pair of newlines around it?

While all the other steps are very nicely documented, this fairly important `self.poll()` disappears to my eyes. How about a pair of newlines around it?

Right. Before, poll() was simply part of update(), and I decided to separate them simply for the sake of clarity.

What do you think: is it good to keep them separated?

Right. Before, `poll()` was simply part of `update()`, and I decided to separate them simply for the sake of clarity. What do you think: is it good to keep them separated?

Smaller chunks with separated logic are always better.

Smaller chunks with separated logic are always better.

In 232a08f110, I added a pair of newlines, one before the comment line (that explains polling), and one after self.poll(). Would that be good?

In 232a08f11012fd749b68882140a8b9dc1a653363, I added a pair of newlines, one before the comment line (that explains polling), and one after `self.poll()`. Would that be good?
self.poll()?;
// Time out, and return the socket for re-connection in the future.
if timeout_ms > self.connection_timeout_ms {
// TODO: Return Err(), but would require changes in quartiq/minimq
return Ok(handle)
} }
} }
Ok(socket)
} }
fn is_connected(&self, socket: &Self::TcpSocket) -> Result<bool, Self::Error> { fn is_connected(&self, handle: &Self::TcpSocket) -> Result<bool, Self::Error> {
let mut sockets = self.sockets.borrow_mut(); let mut sockets = self.sockets.borrow_mut();
let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*socket); let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*handle);
Ok(socket.may_send() && socket.may_recv()) Ok(socket.state() == net::socket::TcpState::Established)
} }
fn write(&self, socket: &mut Self::TcpSocket, buffer: &[u8]) -> nb::Result<usize, Self::Error> { fn write(&self, handle: &mut Self::TcpSocket, buffer: &[u8]) -> nb::Result<usize, Self::Error> {
let mut write_error = false;
let mut non_queued_bytes = &buffer[..]; let mut non_queued_bytes = &buffer[..];
while non_queued_bytes.len() != 0 { while non_queued_bytes.len() != 0 {
let result = { let result = {
let mut sockets = self.sockets.borrow_mut(); let mut sockets = self.sockets.borrow_mut();
let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*socket); let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*handle);
let result = socket.send_slice(non_queued_bytes); let result = socket.send_slice(non_queued_bytes);
result result
}; };
match result { match result {
Ok(num_bytes) => { Ok(num_bytes) => {
// If the buffer is completely filled, close the socket and
// return an error
if num_bytes == 0 {
write_error = true;
break;
}
// In case the buffer is filled up, push bytes into ethernet driver // In case the buffer is filled up, push bytes into ethernet driver
if num_bytes != non_queued_bytes.len() { if num_bytes != non_queued_bytes.len() {
self.update(true)?; self.update()?;
self.poll()?;
} }
// Process the unwritten bytes again, if any // Process the unwritten bytes again, if any
non_queued_bytes = &non_queued_bytes[num_bytes..] non_queued_bytes = &non_queued_bytes[num_bytes..]
} }
Err(_) => return Err(nb::Error::Other(NetworkError::WriteFailure)), Err(_) => {
write_error = true;
break;
}
} }
} }
if write_error {
// Close the socket to push it back to the array, for
// re-opening the socket in the future
self.close(*handle)?;
return Err(nb::Error::Other(NetworkError::WriteFailure))
}
Ok(buffer.len()) Ok(buffer.len())
} }
fn read( fn read(
&self, &self,
socket: &mut Self::TcpSocket, handle: &mut Self::TcpSocket,
buffer: &mut [u8], buffer: &mut [u8],
) -> nb::Result<usize, Self::Error> { ) -> nb::Result<usize, Self::Error> {
// Enqueue received bytes into the TCP socket buffer // Enqueue received bytes into the TCP socket buffer
self.update(true)?; self.update()?;
let mut sockets = self.sockets.borrow_mut(); self.poll()?;
let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*socket); {
let result = socket.recv_slice(buffer); let mut sockets = self.sockets.borrow_mut();
match result { let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*handle);
Ok(num_bytes) => Ok(num_bytes), let result = socket.recv_slice(buffer);
Err(_) => Err(nb::Error::Other(NetworkError::ReadFailure)), match result {
Ok(num_bytes) => { return Ok(num_bytes) },
Err(_) => {},
}
} }
// Close the socket to push it back to the array, for
// re-opening the socket in the future
self.close(*handle)?;
Err(nb::Error::Other(NetworkError::ReadFailure))
} }
fn close(&self, socket: Self::TcpSocket) -> Result<(), Self::Error> { fn close(&self, handle: Self::TcpSocket) -> Result<(), Self::Error> {
let mut sockets = self.sockets.borrow_mut(); let mut sockets = self.sockets.borrow_mut();
let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(socket); let socket: &mut net::socket::TcpSocket = &mut *sockets.get(handle);
internal_socket.close(); socket.close();
self.unused_handles.borrow_mut().push(socket).unwrap(); let mut unused_handles = self.unused_handles.borrow_mut();
if unused_handles.iter().find(|&x| *x == handle).is_none() {
unused_handles.push(handle).unwrap();
}
Ok(()) Ok(())
} }
} }