diff --git a/Cargo.toml b/Cargo.toml index f45a4ec..ee97bc8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,9 @@ cortex-m-rt = { version = "0.6", optional = true } cortex-m-rtic = { version = "0.5.3", optional = true } panic-itm = { version = "0.4", optional = true } log = { version = "0.4", optional = true } +embedded-time = { version = "0.10.1", optional = true } +embedded-nal = { version = "0.1.0", optional = true } +heapless = { version = "0.5.6", optional = true } [features] smoltcp-phy = ["smoltcp"] @@ -31,6 +34,7 @@ smoltcp-phy-all = [ # Example-based features tx_stm32f407 = ["stm32f4xx-hal/stm32f407", "cortex-m", "cortex-m-rtic", "panic-itm", "log"] tcp_stm32f407 = ["stm32f4xx-hal/stm32f407", "cortex-m", "cortex-m-rt", "cortex-m-rtic", "smoltcp-phy-all", "smoltcp/log", "panic-itm", "log"] +nal = [ "embedded-time", "embedded-nal", "smoltcp-phy", "heapless" ] default = [] [[example]] diff --git a/src/lib.rs b/src/lib.rs index 858409d..69bce58 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,9 @@ pub mod tx; #[cfg(feature="smoltcp")] pub mod smoltcp_phy; +#[cfg(feature="nal")] +pub mod nal; + /// Max raw frame array size pub const RAW_FRAME_LENGTH_MAX: usize = 1518; diff --git a/src/nal.rs b/src/nal.rs new file mode 100644 index 0000000..0cb71da --- /dev/null +++ b/src/nal.rs @@ -0,0 +1,245 @@ +use core::cell::RefCell; +use core::convert::TryFrom; +use heapless::{consts, Vec}; +use embedded_nal as nal; +use nal::nb; +use smoltcp as net; +use embedded_hal::{ + blocking::spi::Transfer, + digital::v2::OutputPin +}; +pub use embedded_time as time; +use time::duration::*; + +#[derive(Debug)] +pub enum NetworkError { + NoSocket, + ConnectionFailure, + ReadFailure, + WriteFailure, + Unsupported, + TimeFault, +} + +pub type NetworkInterface = net::iface::EthernetInterface< + 'static, + crate::smoltcp_phy::SmoltcpDevice< + crate::SpiEth + >, +>; + +pub struct NetworkStack<'a, SPI, NSS, IntClock> +where + SPI: 'static + Transfer, + NSS: 'static + OutputPin, + IntClock: time::Clock, +{ + network_interface: RefCell>, + sockets: RefCell>, + next_port: RefCell, + unused_handles: RefCell>, + time_ms: RefCell, + last_update_instant: RefCell>>, + clock: IntClock +} + +impl<'a, SPI, NSS, IntClock> NetworkStack<'a, SPI, NSS, IntClock> +where + SPI: Transfer, + NSS: OutputPin, + IntClock: time::Clock, +{ + pub fn new(interface: NetworkInterface, sockets: net::socket::SocketSet<'a>, clock: IntClock) -> Self { + let mut unused_handles: Vec = Vec::new(); + for socket in sockets.iter() { + unused_handles.push(socket.handle()).unwrap(); + } + NetworkStack { + network_interface: RefCell::new(interface), + sockets: RefCell::new(sockets), + next_port: RefCell::new(49152), + unused_handles: RefCell::new(unused_handles), + time_ms: RefCell::new(0), + last_update_instant: RefCell::new(None), + clock, + } + } + + // Include auto_time_update to allow Instant::now() to not be called + // Instant::now() is not safe to call in `init()` context + pub fn update(&self, auto_time_update: bool) -> Result { + if auto_time_update { + // Check if it is the first time the stack has updated the time itself + let now = match *self.last_update_instant.borrow() { + // If it is the first time, do not advance time + // Simply store the current instant to initiate time updating + None => self.clock.try_now().map_err(|_| NetworkError::TimeFault)?, + // If it was updated before, advance time and update last_update_instant + Some(instant) => { + // Calculate elapsed time + let now = self.clock.try_now().map_err(|_| NetworkError::TimeFault)?; + let duration = now.checked_duration_since(&instant).ok_or(NetworkError::TimeFault)?; + let duration_ms = time::duration::Milliseconds::::try_from(duration).map_err(|_| NetworkError::TimeFault)?; + // Adjust duration into ms (note: decimal point truncated) + self.advance_time(*duration_ms.integer()); + now + } + }; + self.last_update_instant.replace(Some(now)); + } + match self.network_interface.borrow_mut().poll( + &mut self.sockets.borrow_mut(), + net::time::Instant::from_millis(*self.time_ms.borrow() as i64), + ) { + Ok(changed) => Ok(!changed), + Err(_e) => { + Ok(true) + } + } + } + + pub fn advance_time(&self, duration: u32) { + let time = self.time_ms.try_borrow().unwrap().wrapping_add(duration); + self.time_ms.replace(time); + } + + fn get_ephemeral_port(&self) -> u16 { + // Get the next ephemeral port + let current_port = self.next_port.borrow().clone(); + let (next, wrap) = self.next_port.borrow().overflowing_add(1); + *self.next_port.borrow_mut() = if wrap { 49152 } else { next }; + return current_port; + } +} +impl<'a, SPI, NSS, IntClock> nal::TcpStack for NetworkStack<'a, SPI, NSS, IntClock> +where + SPI: Transfer, + NSS: OutputPin, + IntClock: time::Clock, +{ + type TcpSocket = net::socket::SocketHandle; + type Error = NetworkError; + fn open(&self, _mode: nal::Mode) -> Result { + match self.unused_handles.borrow_mut().pop() { + Some(handle) => { + // Abort any active connections on the handle. + let mut sockets = self.sockets.borrow_mut(); + let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(handle); + internal_socket.abort(); + Ok(handle) + } + None => Err(NetworkError::NoSocket), + } + } + + // Ideally connect is only to be performed in `init()` of `main.rs` + // Calling `Instant::now()` of `rtic::cyccnt` would face correctness issue during `init()` + fn connect( + &self, + socket: Self::TcpSocket, + remote: nal::SocketAddr, + ) -> Result { + let address = { + let mut sockets = self.sockets.borrow_mut(); + let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(socket); + // If we're already in the process of connecting, ignore the request silently. + if internal_socket.is_open() { + return Ok(socket); + } + match remote.ip() { + nal::IpAddr::V4(addr) => { + let address = + net::wire::Ipv4Address::from_bytes(&addr.octets()[..]); + internal_socket + .connect((address, remote.port()), self.get_ephemeral_port()) + .map_err(|_| NetworkError::ConnectionFailure)?; + net::wire::IpAddress::Ipv4(address) + } + nal::IpAddr::V6(addr) => { + let address = net::wire::Ipv6Address::from_parts(&addr.segments()[..]); + internal_socket.connect((address, remote.port()), self.get_ephemeral_port()) + .map_err(|_| NetworkError::ConnectionFailure)?; + net::wire::IpAddress::Ipv6(address) + } + } + }; + // Blocking connect + loop { + match self.is_connected(&socket) { + Ok(true) => break, + _ => { + let mut sockets = self.sockets.borrow_mut(); + let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(socket); + // If the connect got ACK->RST, it will end up in Closed TCP state + // Perform reconnection in this case + if internal_socket.state() == net::socket::TcpState::Closed { + internal_socket.close(); + internal_socket + .connect((address, remote.port()), self.get_ephemeral_port()) + .map_err(|_| NetworkError::ConnectionFailure)?; + } + } + } + // Avoid using Instant::now() and Advance time manually + self.update(false)?; + { + self.advance_time(1); + } + } + Ok(socket) + } + + fn is_connected(&self, socket: &Self::TcpSocket) -> Result { + let mut sockets = self.sockets.borrow_mut(); + let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*socket); + Ok(socket.may_send() && socket.may_recv()) + } + + fn write(&self, socket: &mut Self::TcpSocket, buffer: &[u8]) -> nb::Result { + let mut non_queued_bytes = &buffer[..]; + while non_queued_bytes.len() != 0 { + let result = { + let mut sockets = self.sockets.borrow_mut(); + let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*socket); + let result = socket.send_slice(non_queued_bytes); + result + }; + match result { + Ok(num_bytes) => { + // In case the buffer is filled up, push bytes into ethernet driver + if num_bytes != non_queued_bytes.len() { + self.update(true)?; + } + // Process the unwritten bytes again, if any + non_queued_bytes = &non_queued_bytes[num_bytes..] + } + Err(_) => return Err(nb::Error::Other(NetworkError::WriteFailure)), + } + } + Ok(buffer.len()) + } + + fn read( + &self, + socket: &mut Self::TcpSocket, + buffer: &mut [u8], + ) -> nb::Result { + // Enqueue received bytes into the TCP socket buffer + self.update(true)?; + let mut sockets = self.sockets.borrow_mut(); + let socket: &mut net::socket::TcpSocket = &mut *sockets.get(*socket); + let result = socket.recv_slice(buffer); + match result { + Ok(num_bytes) => Ok(num_bytes), + Err(_) => Err(nb::Error::Other(NetworkError::ReadFailure)), + } + } + + fn close(&self, socket: Self::TcpSocket) -> Result<(), Self::Error> { + let mut sockets = self.sockets.borrow_mut(); + let internal_socket: &mut net::socket::TcpSocket = &mut *sockets.get(socket); + internal_socket.close(); + self.unused_handles.borrow_mut().push(socket).unwrap(); + Ok(()) + } +} \ No newline at end of file