nal: Fix loops & socket handle management #7

Merged
harry merged 6 commits from fix-nal into master 2021-04-12 09:28:56 +08:00
1 changed files with 150 additions and 81 deletions

View File

@ -1,5 +1,5 @@
use core::cell::RefCell; use core::cell::RefCell;
use core::convert::TryFrom; use core::convert::TryInto;
use heapless::{consts, Vec}; use heapless::{consts, Vec};
use embedded_nal as nal; use embedded_nal as nal;
use nal::nb; use nal::nb;
@ -40,7 +40,8 @@ where
unused_handles: RefCell<Vec<net::socket::SocketHandle, consts::U16>>, unused_handles: RefCell<Vec<net::socket::SocketHandle, consts::U16>>,
time_ms: RefCell<u32>, time_ms: RefCell<u32>,
last_update_instant: RefCell<Option<time::Instant<IntClock>>>, last_update_instant: RefCell<Option<time::Instant<IntClock>>>,
clock: IntClock clock: IntClock,
connection_timeout_ms: u32,
} }
impl<'a, SPI, NSS, IntClock> NetworkStack<'a, SPI, NSS, IntClock> impl<'a, SPI, NSS, IntClock> NetworkStack<'a, SPI, NSS, IntClock>
@ -49,7 +50,12 @@ where
NSS: OutputPin, NSS: OutputPin,
IntClock: time::Clock<T = u32>, IntClock: time::Clock<T = u32>,
{ {
pub fn new(interface: NetworkInterface<SPI, NSS>, sockets: net::socket::SocketSet<'a>, clock: IntClock) -> Self { pub fn new(
interface: NetworkInterface<SPI, NSS>,
sockets: net::socket::SocketSet<'a>,
clock: IntClock,
connection_timeout_ms: u32,
astro marked this conversation as resolved

Can you try to use an embedded_time::duration type wherever possible?

Can you try to use an `embedded_time::duration` type wherever possible?

I defined the connection timeout parameter as a u32 because it is used to compare with the network stack struct's time_ms "time now", which is also stored as u32. If I'm following @occheung 's logic correctly, we want time_ms to stay as u32 rather than a embedded_time::duration::Milliseconds because every time the Ethernet interface is busy polling, it has to borrow the "time now" value and conversion of the timestamp among u32, the embedded_time duration type and the smoltcp duration type might consume more CPU cycles. Plus, I made update() return a u32 representing the number of microseconds advanced, and so counting the total duration would be simpler.

I defined the connection timeout parameter as a `u32` because it is used to compare with the network stack struct's `time_ms` "time now", which is also stored as `u32`. If I'm following @occheung 's logic correctly, we want `time_ms` to stay as `u32` rather than a `embedded_time::duration::Milliseconds` because every time the Ethernet interface is busy polling, it has to borrow the "time now" value and conversion of the timestamp among `u32`, the `embedded_time` duration type and the `smoltcp` duration type might consume more CPU cycles. Plus, I made `update()` return a `u32` representing the number of microseconds advanced, and so counting the total duration would be simpler.
) -> Self {
let mut unused_handles: Vec<net::socket::SocketHandle, consts::U16> = Vec::new(); let mut unused_handles: Vec<net::socket::SocketHandle, consts::U16> = Vec::new();
for socket in sockets.iter() { for socket in sockets.iter() {
unused_handles.push(socket.handle()).unwrap(); unused_handles.push(socket.handle()).unwrap();
@ -62,34 +68,55 @@ where
time_ms: RefCell::new(0), time_ms: RefCell::new(0),
last_update_instant: RefCell::new(None), last_update_instant: RefCell::new(None),
clock, clock,
connection_timeout_ms,
} }
} }
// Include auto_time_update to allow Instant::now() to not be called // Initiate or advance the timer, and return the duration in ms as u32.
// Instant::now() is not safe to call in `init()` context fn update(&self) -> Result<u32, NetworkError> {
pub fn update(&self, auto_time_update: bool) -> Result<bool, NetworkError> { let mut duration_ms: u32 = 0;
if auto_time_update { // Check if it is the first time the stack has updated the time itself
// Check if it is the first time the stack has updated the time itself let now = match *self.last_update_instant.borrow() {
let now = match *self.last_update_instant.borrow() { // If it is the first time, do not advance time
// If it is the first time, do not advance time // Simply store the current instant to initiate time updating
// Simply store the current instant to initiate time updating None => self.clock.try_now().map_err(|_| NetworkError::TimeFault)?,
None => self.clock.try_now().map_err(|_| NetworkError::TimeFault)?, // If it was updated before, advance time and update last_update_instant
// If it was updated before, advance time and update last_update_instant Some(instant) => {
Some(instant) => { // Calculate elapsed time
// Calculate elapsed time let now = self.clock.try_now().map_err(|_| NetworkError::TimeFault)?;
let now = self.clock.try_now().map_err(|_| NetworkError::TimeFault)?; let mut duration = now.checked_duration_since(&instant);
let duration = now.checked_duration_since(&instant).ok_or(NetworkError::TimeFault)?; // Normally, the wrapping clock should produce a valid duration.
let duration_ms = time::duration::Milliseconds::<u32>::try_from(duration).map_err(|_| NetworkError::TimeFault)?; // However, if `now` is earlier than `instant` (e.g. because the main
// Adjust duration into ms (note: decimal point truncated) // application cannot get a valid epoch time during initialisation,
self.advance_time(*duration_ms.integer()); // we should still produce a duration that is just 1ms.
now if duration.is_none() {
self.time_ms.replace(0);
duration = Some(Milliseconds(1_u32)
.to_generic::<u32>(IntClock::SCALING_FACTOR)
.map_err(|_| NetworkError::TimeFault)?);
} }
}; let duration_ms_time: Milliseconds<u32> = duration.unwrap().try_into()
self.last_update_instant.replace(Some(now)); .map_err(|_| NetworkError::TimeFault)?;
} duration_ms = *duration_ms_time.integer();
// Adjust duration into ms (note: decimal point truncated)
self.advance_time(duration_ms);
now
}
};
self.last_update_instant.replace(Some(now));
Ok(duration_ms)
}
fn advance_time(&self, duration_ms: u32) {
let time = self.time_ms.borrow().wrapping_add(duration_ms);
self.time_ms.replace(time);
}
// Poll on the smoltcp interface
fn poll(&self) -> Result<bool, NetworkError> {
match self.network_interface.borrow_mut().poll( match self.network_interface.borrow_mut().poll(
&mut self.sockets.borrow_mut(), &mut self.sockets.borrow_mut(),
net::time::Instant::from_millis(*self.time_ms.borrow() as i64), net::time::Instant::from_millis(*self.time_ms.borrow() as u32),
) { ) {
Ok(changed) => Ok(!changed), Ok(changed) => Ok(!changed),
Err(_e) => { Err(_e) => {
@ -98,11 +125,6 @@ where
} }
} }
pub fn advance_time(&self, duration: u32) {
let time = self.time_ms.try_borrow().unwrap().wrapping_add(duration);
self.time_ms.replace(time);
}
fn get_ephemeral_port(&self) -> u16 { fn get_ephemeral_port(&self) -> u16 {
// Get the next ephemeral port // Get the next ephemeral port
let current_port = self.next_port.borrow().clone(); let current_port = self.next_port.borrow().clone();
@ -124,122 +146,169 @@ where
Some(handle) => { Some(handle) => {
// Abort any active connections on the handle. // Abort any active connections on the handle.
let mut sockets = self.sockets.borrow_mut(); let mut sockets = self.sockets.borrow_mut();
let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(handle); let socket: &mut net::socket::TcpSocket = &mut *sockets.get(handle);
internal_socket.abort(); socket.abort();
Ok(handle) Ok(handle)
} }
None => Err(NetworkError::NoSocket), None => Err(NetworkError::NoSocket),
} }
} }
// Ideally connect is only to be performed in `init()` of `main.rs`
// Calling `Instant::now()` of `rtic::cyccnt` would face correctness issue during `init()`
fn connect( fn connect(
&self, &self,
socket: Self::TcpSocket, handle: Self::TcpSocket,
remote: nal::SocketAddr, remote: nal::SocketAddr,
) -> Result<Self::TcpSocket, Self::Error> { ) -> Result<Self::TcpSocket, Self::Error> {
let address = { {
// If the socket has already been connected, ignore the connection
// request silently.
let mut sockets = self.sockets.borrow_mut(); let mut sockets = self.sockets.borrow_mut();
let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(socket); let socket: &mut net::socket::TcpSocket = &mut *sockets.get(handle);
// If we're already in the process of connecting, ignore the request silently. if socket.state() == net::socket::TcpState::Established {
if internal_socket.is_open() { return Ok(handle)
return Ok(socket);
} }
}
{
let mut sockets = self.sockets.borrow_mut();
let socket: &mut net::socket::TcpSocket = &mut *sockets.get(handle);

This isn't "internal" as in "some inner struct". I think it is more obvious to name this to what is behind this SocketRef::T, a tcp_socket.

This isn't "internal" as in "some inner struct". I think it is more obvious to name this to what is behind this `SocketRef::T`, a `tcp_socket`.

Perhaps @occheung meant that the internal_socket is the socket stored "internally" in the smoltcp::socket::SocketSet within the NetworkStack, which returns the socket needed using the socket parameter passed to the method, which is actually a smoltcp::socket::SocketHandle that was previously popped from the array of unused handles, done "externally" by calling open()?

Anyway, since Rust doesn't do matching of names of method parameters between the trait and the implementation, in 66c3aa534f I renamed the socket to handle and internal_socket to socket for clarity.

Perhaps @occheung meant that the `internal_socket` is the socket stored "internally" in the `smoltcp::socket::SocketSet` within the `NetworkStack`, which returns the socket needed using the `socket` parameter passed to the method, which is actually a `smoltcp::socket::SocketHandle` that was previously popped from the array of unused handles, done "externally" by calling `open()`? Anyway, since Rust doesn't do matching of names of method parameters between the trait and the implementation, in 66c3aa534f4b4a7628dae70700017164a9c5fc1d I renamed the `socket` to `handle` and `internal_socket` to `socket` for clarity.
// abort() instead of close() prevents TcpSocket::connect() from
// raising an error
socket.abort();
match remote.ip() { match remote.ip() {
nal::IpAddr::V4(addr) => { nal::IpAddr::V4(addr) => {
let address = let address =
net::wire::Ipv4Address::from_bytes(&addr.octets()[..]); net::wire::Ipv4Address::from_bytes(&addr.octets()[..]);
internal_socket socket
.connect((address, remote.port()), self.get_ephemeral_port()) .connect((address, remote.port()), self.get_ephemeral_port())
.map_err(|_| NetworkError::ConnectionFailure)?; .map_err(|_| NetworkError::ConnectionFailure)?;
net::wire::IpAddress::Ipv4(address) net::wire::IpAddress::Ipv4(address)
} },
nal::IpAddr::V6(addr) => { nal::IpAddr::V6(addr) => {
let address = net::wire::Ipv6Address::from_parts(&addr.segments()[..]); let address =
internal_socket.connect((address, remote.port()), self.get_ephemeral_port()) net::wire::Ipv6Address::from_parts(&addr.segments()[..]);
socket
.connect((address, remote.port()), self.get_ephemeral_port())
.map_err(|_| NetworkError::ConnectionFailure)?; .map_err(|_| NetworkError::ConnectionFailure)?;
net::wire::IpAddress::Ipv6(address) net::wire::IpAddress::Ipv6(address)
} }
} }
}; };
// Blocking connect // Blocking connect

Please retain the wording blocking as it is exactly what this is.

Please retain the wording **blocking** as it is exactly what this is.

Resolved in 232a08f110 by copying this original line of comment verbatim.

Resolved in 232a08f11012fd749b68882140a8b9dc1a653363 by copying this original line of comment verbatim.
// Loop to wait until the socket is staying established or closed,
// or the connection attempt has timed out.
let mut timeout_ms: u32 = 0;
loop { loop {
match self.is_connected(&socket) { {
Ok(true) => break, let mut sockets = self.sockets.borrow_mut();
_ => { let socket: &mut net::socket::TcpSocket = &mut *sockets.get(handle);
let mut sockets = self.sockets.borrow_mut(); // TCP state at ESTABLISHED means there is connection, so
let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(socket); // simply return the socket.
// If the connect got ACK->RST, it will end up in Closed TCP state if socket.state() == net::socket::TcpState::Established {
// Perform reconnection in this case return Ok(handle)
if internal_socket.state() == net::socket::TcpState::Closed { }
internal_socket.close(); // TCP state at CLOSED implies that the remote rejected connection;
internal_socket // In this case, abort the connection, and then return the socket
.connect((address, remote.port()), self.get_ephemeral_port()) // for re-connection in the future.
.map_err(|_| NetworkError::ConnectionFailure)?; if socket.state() == net::socket::TcpState::Closed {
} socket.abort();
// TODO: Return Err(), but would require changes in quartiq/minimq
return Ok(handle)
} }
} }
// Avoid using Instant::now() and Advance time manually
self.update(false)?; // Any TCP states other than CLOSED and ESTABLISHED are considered
{ // "transient", so this function should keep waiting and let smoltcp poll
self.advance_time(1); // (e.g. for handling echo reqeust/reply) at the same time.
timeout_ms += self.update()?;

While all the other steps are very nicely documented, this fairly important self.poll() disappears to my eyes. How about a pair of newlines around it?

While all the other steps are very nicely documented, this fairly important `self.poll()` disappears to my eyes. How about a pair of newlines around it?

Right. Before, poll() was simply part of update(), and I decided to separate them simply for the sake of clarity.

What do you think: is it good to keep them separated?

Right. Before, `poll()` was simply part of `update()`, and I decided to separate them simply for the sake of clarity. What do you think: is it good to keep them separated?

Smaller chunks with separated logic are always better.

Smaller chunks with separated logic are always better.

In 232a08f110, I added a pair of newlines, one before the comment line (that explains polling), and one after self.poll(). Would that be good?

In 232a08f11012fd749b68882140a8b9dc1a653363, I added a pair of newlines, one before the comment line (that explains polling), and one after `self.poll()`. Would that be good?
self.poll()?;
// Time out, and return the socket for re-connection in the future.
if timeout_ms > self.connection_timeout_ms {
// TODO: Return Err(), but would require changes in quartiq/minimq
return Ok(handle)
} }
} }
Ok(socket)
} }
fn is_connected(&self, socket: &Self::TcpSocket) -> Result<bool, Self::Error> { fn is_connected(&self, handle: &Self::TcpSocket) -> Result<bool, Self::Error> {
let mut sockets = self.sockets.borrow_mut(); let mut sockets = self.sockets.borrow_mut();
let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*socket); let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*handle);
Ok(socket.may_send() && socket.may_recv()) Ok(socket.state() == net::socket::TcpState::Established)
} }
fn write(&self, socket: &mut Self::TcpSocket, buffer: &[u8]) -> nb::Result<usize, Self::Error> { fn write(&self, handle: &mut Self::TcpSocket, buffer: &[u8]) -> nb::Result<usize, Self::Error> {
let mut write_error = false;
let mut non_queued_bytes = &buffer[..]; let mut non_queued_bytes = &buffer[..];
while non_queued_bytes.len() != 0 { while non_queued_bytes.len() != 0 {
let result = { let result = {
let mut sockets = self.sockets.borrow_mut(); let mut sockets = self.sockets.borrow_mut();
let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*socket); let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*handle);
let result = socket.send_slice(non_queued_bytes); let result = socket.send_slice(non_queued_bytes);
result result
}; };
match result { match result {
Ok(num_bytes) => { Ok(num_bytes) => {
// If the buffer is completely filled, close the socket and
// return an error
if num_bytes == 0 {
write_error = true;
break;
}
// In case the buffer is filled up, push bytes into ethernet driver // In case the buffer is filled up, push bytes into ethernet driver
if num_bytes != non_queued_bytes.len() { if num_bytes != non_queued_bytes.len() {
self.update(true)?; self.update()?;
self.poll()?;
} }
// Process the unwritten bytes again, if any // Process the unwritten bytes again, if any
non_queued_bytes = &non_queued_bytes[num_bytes..] non_queued_bytes = &non_queued_bytes[num_bytes..]
} }
Err(_) => return Err(nb::Error::Other(NetworkError::WriteFailure)), Err(_) => {
write_error = true;
break;
}
} }
} }
if write_error {
// Close the socket to push it back to the array, for
// re-opening the socket in the future
self.close(*handle)?;
return Err(nb::Error::Other(NetworkError::WriteFailure))
}
Ok(buffer.len()) Ok(buffer.len())
} }
fn read( fn read(
&self, &self,
socket: &mut Self::TcpSocket, handle: &mut Self::TcpSocket,
buffer: &mut [u8], buffer: &mut [u8],
) -> nb::Result<usize, Self::Error> { ) -> nb::Result<usize, Self::Error> {
// Enqueue received bytes into the TCP socket buffer // Enqueue received bytes into the TCP socket buffer
self.update(true)?; self.update()?;
let mut sockets = self.sockets.borrow_mut(); self.poll()?;
let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*socket); {
let result = socket.recv_slice(buffer); let mut sockets = self.sockets.borrow_mut();
match result { let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*handle);
Ok(num_bytes) => Ok(num_bytes), let result = socket.recv_slice(buffer);
Err(_) => Err(nb::Error::Other(NetworkError::ReadFailure)), match result {
Ok(num_bytes) => { return Ok(num_bytes) },
Err(_) => {},
}
} }
// Close the socket to push it back to the array, for
// re-opening the socket in the future
self.close(*handle)?;
Err(nb::Error::Other(NetworkError::ReadFailure))
} }
fn close(&self, socket: Self::TcpSocket) -> Result<(), Self::Error> { fn close(&self, handle: Self::TcpSocket) -> Result<(), Self::Error> {
let mut sockets = self.sockets.borrow_mut(); let mut sockets = self.sockets.borrow_mut();
let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(socket); let socket: &mut net::socket::TcpSocket = &mut *sockets.get(handle);
internal_socket.close(); socket.close();
self.unused_handles.borrow_mut().push(socket).unwrap(); let mut unused_handles = self.unused_handles.borrow_mut();
if unused_handles.iter().find(|&x| *x == handle).is_none() {
unused_handles.push(handle).unwrap();
}
Ok(()) Ok(())
} }
} }