diff --git a/Cargo.toml b/Cargo.toml index 68a84e9..270be88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,9 @@ cortex-m-rt = { version = "0.6", optional = true } cortex-m-rtic = { version = "0.5.3", optional = true } panic-itm = { version = "0.4", optional = true } log = { version = "0.4", optional = true } +# Optional dependencies for using NAL with somltcp and cortex_m +heapless = { version = "0.5.6", optional = true } +embedded-nal = { version = "0.1.0", optional = true } [features] smoltcp-phy = ["smoltcp"] @@ -31,6 +34,7 @@ smoltcp-phy-all = [ # Example-based features tx_stm32f407 = ["stm32f4xx-hal/stm32f407", "cortex-m", "cortex-m-rtic", "panic-itm", "log"] tcp_stm32f407 = ["stm32f4xx-hal/stm32f407", "cortex-m", "cortex-m-rt", "cortex-m-rtic", "smoltcp-phy-all", "smoltcp/log", "panic-itm", "log"] +nal = [ "smoltcp-phy", "heapless", "embedded-nal", "cortex-m", "cortex-m-rtic" ] default = [] [[example]] diff --git a/src/lib.rs b/src/lib.rs index baf065f..d9b392e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,6 +15,9 @@ pub mod tx; #[cfg(feature="smoltcp")] pub mod smoltcp_phy; +#[cfg(all(feature="smoltcp", feature="nal"))] +pub mod nal; + /// Max raw frame array size pub const RAW_FRAME_LENGTH_MAX: usize = 1518; diff --git a/src/nal.rs b/src/nal.rs new file mode 100644 index 0000000..8dc06a8 --- /dev/null +++ b/src/nal.rs @@ -0,0 +1,267 @@ +use core::cell::RefCell; +use heapless::{consts, Vec}; +use embedded_nal as nal; +use nal::nb; +use smoltcp as net; +use embedded_hal::{ + blocking::spi::Transfer, + blocking::delay::DelayUs, + digital::v2::OutputPin +}; +use rtic::cyccnt::Instant; + + +#[derive(Debug)] +pub enum NetworkError { + NoSocket, + ConnectionFailure, + ReadFailure, + WriteFailure, + Unsupported, +} + +type NetworkInterface = net::iface::EthernetInterface< + 'static, + 'static, + 'static, + crate::smoltcp_phy::SmoltcpDevice< + crate::SpiEth + >, +>; + +pub struct NetworkStack<'a, 'b, 'c, SPI, NSS, Delay> +where + SPI: 'static + Transfer, + NSS: 'static + OutputPin, + Delay: 'static + DelayUs +{ + network_interface: RefCell>, + sockets: RefCell>, + next_port: RefCell, + unused_handles: RefCell>, + time_ms: RefCell, + last_update_instant: RefCell>, +} + +impl<'a, 'b, 'c, SPI, NSS, Delay> NetworkStack<'a, 'b, 'c, SPI, NSS, Delay> +where + SPI: Transfer, + NSS: OutputPin, + Delay: DelayUs +{ + pub fn new(interface: NetworkInterface, sockets: net::socket::SocketSet<'a, 'b, 'c>) -> Self { + let mut unused_handles: Vec = Vec::new(); + for socket in sockets.iter() { + unused_handles.push(socket.handle()).unwrap(); + } + + NetworkStack { + network_interface: RefCell::new(interface), + sockets: RefCell::new(sockets), + next_port: RefCell::new(49152), + unused_handles: RefCell::new(unused_handles), + time_ms: RefCell::new(0), + last_update_instant: RefCell::new(None), + } + } + + // Include auto_time_update to allow Instant::now() to not be called + // Instant::now() is not safe to call in `init()` context + pub fn update(&self, auto_time_update: bool) -> bool { + if auto_time_update { + // Check if it is the first time the stack has updated the time itself + let now = match *self.last_update_instant.borrow() { + // If it is the first time, do not advance time + // Simply store the current instant to initiate time updating + None => Instant::now(), + + // If it was updated before, advance time and update last_update_instant + Some(instant) => { + // Calculate elapsed time + let now = Instant::now(); + let duration = now.duration_since(instant); + // Adjust duration into ms (note: decimal point truncated) + self.advance_time(duration.as_cycles() / 168_000); + now + } + }; + + self.last_update_instant.replace(Some(now)); + } + + match self.network_interface.borrow_mut().poll( + &mut self.sockets.borrow_mut(), + net::time::Instant::from_millis(*self.time_ms.borrow() as i64), + ) { + Ok(changed) => changed == false, + Err(_e) => { + true + } + } + } + + pub fn advance_time(&self, duration: u32) { + *self.time_ms.borrow_mut() += duration; + } + + fn get_ephemeral_port(&self) -> u16 { + // Get the next ephemeral port + let current_port = self.next_port.borrow().clone(); + + let (next, wrap) = self.next_port.borrow().overflowing_add(1); + *self.next_port.borrow_mut() = if wrap { 49152 } else { next }; + + return current_port; + } +} + +impl<'a, 'b, 'c, SPI, NSS, Delay> nal::TcpStack for NetworkStack<'a, 'b, 'c, SPI, NSS, Delay> +where + SPI: Transfer, + NSS: OutputPin, + Delay: DelayUs +{ + type TcpSocket = net::socket::SocketHandle; + type Error = NetworkError; + + fn open(&self, _mode: nal::Mode) -> Result { + match self.unused_handles.borrow_mut().pop() { + Some(handle) => { + // Abort any active connections on the handle. + let mut sockets = self.sockets.borrow_mut(); + let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(handle); + internal_socket.abort(); + + Ok(handle) + } + None => Err(NetworkError::NoSocket), + } + } + + // Ideally connect is only to be performed in `init()` of `main.rs` + // Calling `Instant::now()` of `rtic::cyccnt` would face correctness issue during `init()` + fn connect( + &self, + socket: Self::TcpSocket, + remote: nal::SocketAddr, + ) -> Result { + let address = { + let mut sockets = self.sockets.borrow_mut(); + let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(socket); + + // If we're already in the process of connecting, ignore the request silently. + if internal_socket.is_open() { + return Ok(socket); + } + + match remote.ip() { + nal::IpAddr::V4(addr) => { + let octets = addr.octets(); + let address = + net::wire::Ipv4Address::new(octets[0], octets[1], octets[2], octets[3]); + internal_socket + .connect((address, remote.port()), self.get_ephemeral_port()) + .map_err(|_| NetworkError::ConnectionFailure)?; + address + } + nal::IpAddr::V6(_) => { + // Match W5500 behavior: Reject the use of IPV6 + return Err(NetworkError::Unsupported); + } + } + }; + + // Match W5500 behavior: Poll until connected + loop { + match self.is_connected(&socket) { + Ok(true) => break, + _ => { + let mut sockets = self.sockets.borrow_mut(); + let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(socket); + // If the connect got ACK->RST, it will end up in Closed TCP state + // Perform reconnection in this case + // In all other scenario, simply wait for TCP connection to be established + if internal_socket.state() == net::socket::TcpState::Closed { + internal_socket.close(); + internal_socket + .connect((address, remote.port()), self.get_ephemeral_port()) + .map_err(|_| NetworkError::ConnectionFailure)?; + } + } + } + + // Avoid using Instant::now() and Advance time manually + self.update(false); + + // Delay for 1 ms, minimum time unit of smoltcp + // TODO: Allow clock configuration, if supported in main + cortex_m::asm::delay(168_000_000 / 1_000); + { + self.advance_time(1); + } + } + + Ok(socket) + } + + fn is_connected(&self, socket: &Self::TcpSocket) -> Result { + let mut sockets = self.sockets.borrow_mut(); + let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*socket); + + Ok(socket.may_send() && socket.may_recv()) + } + + fn write(&self, socket: &mut Self::TcpSocket, buffer: &[u8]) -> nb::Result { + let mut non_queued_bytes = &buffer[..]; + while non_queued_bytes.len() != 0 { + let result = { + let mut sockets = self.sockets.borrow_mut(); + let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*socket); + let result = socket.send_slice(non_queued_bytes); + result + }; + + match result { + Ok(num_bytes) => { + // In case the buffer is filled up, push bytes into ethernet driver + if num_bytes != non_queued_bytes.len() { + self.update(true); + } + + // Process the unwritten bytes again, if any + non_queued_bytes = &non_queued_bytes[num_bytes..] + } + Err(_) => return Err(nb::Error::Other(NetworkError::WriteFailure)), + } + } + + Ok(buffer.len()) + } + + fn read( + &self, + socket: &mut Self::TcpSocket, + buffer: &mut [u8], + ) -> nb::Result { + // Enqueue received bytes into the TCP socket buffer + self.update(true); + + let mut sockets = self.sockets.borrow_mut(); + let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*socket); + + let result = socket.recv_slice(buffer); + match result { + Ok(num_bytes) => Ok(num_bytes), + Err(_) => Err(nb::Error::Other(NetworkError::ReadFailure)), + } + } + + fn close(&self, socket: Self::TcpSocket) -> Result<(), Self::Error> { + let mut sockets = self.sockets.borrow_mut(); + let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(socket); + internal_socket.close(); + + self.unused_handles.borrow_mut().push(socket).unwrap(); + Ok(()) + } +}