From 552d21a1b398ce775857f9fd8e060d3a297152d2 Mon Sep 17 00:00:00 2001 From: occheung Date: Fri, 4 Dec 2020 15:50:37 +0800 Subject: [PATCH] nal: init --- Cargo.toml | 7 ++ src/lib.rs | 16 ++++- src/main.rs | 2 +- src/set.rs | 21 ++++-- src/tcp_stack.rs | 171 +++++++++++++++++++++++++++++++++++++++++++++++ src/tls.rs | 83 ++++++++++++++--------- 6 files changed, 259 insertions(+), 41 deletions(-) create mode 100644 src/tcp_stack.rs diff --git a/Cargo.toml b/Cargo.toml index f847384..7656bd2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -93,6 +93,13 @@ optional = true version = "0.3.1" optional = true +# Support old version of embedded_nal interface only +# It is to operate with crates such as MiniMQ, which still depends on version 0.1.0 +[dependencies.embedded-nal] +version = "0.1.0" +optional = true + [features] default = [] std = [ "rand", "hex-literal", "simple_logger", "rsa/default" ] +nal_stack = [ "embedded-nal" ] diff --git a/src/lib.rs b/src/lib.rs index 8bece10..cbeee54 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,9 @@ pub mod fake_rng; pub mod oid; pub mod set; +#[cfg(feature = "nal_stack")] +pub mod tcp_stack; + // TODO: Implement errors // Details: Encapsulate smoltcp & nom errors #[derive(Debug, Clone)] @@ -49,8 +52,10 @@ use net::phy::Device; use crate::set::TlsSocketSet; // One-call function for polling all sockets within socket set +// Input of vanilla sockets are optional, as one may not feel needed to create them +// TLS socket set is mandatory, otherwise you should just use `EthernetInterface::poll(..)` pub fn poll( - sockets: &mut SocketSet, + sockets: Option<&mut SocketSet>, tls_sockets: &mut TlsSocketSet, iface: &mut EthernetInterface, now: Instant @@ -58,6 +63,11 @@ pub fn poll( where DeviceT: for<'d> Device<'d> { - tls_sockets.polled_by(sockets, iface, now)?; - iface.poll(sockets, now).map_err(Error::PropagatedError) + tls_sockets.polled_by(iface, now)?; + + if let Some(vanilla_sockets) = sockets { + iface.poll(vanilla_sockets, now).map_err(Error::PropagatedError)?; + } + + Ok(true) } diff --git a/src/main.rs b/src/main.rs index ed38fec..5e3769f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -75,7 +75,7 @@ fn main() { }; tls_socket.connect( - &mut sockets, + // &mut sockets, (Ipv4Address::new(192, 168, 1, 125), 1883), 49600 ).unwrap(); diff --git a/src/set.rs b/src/set.rs index f572d22..d80c3dd 100644 --- a/src/set.rs +++ b/src/set.rs @@ -14,6 +14,12 @@ pub struct TlsSocketSet<'a, 'b, 'c> { #[derive(Clone, Copy, Debug)] pub struct TlsSocketHandle(usize); +impl TlsSocketHandle { + pub(crate) fn new(index: usize) -> Self { + Self(index) + } +} + impl<'a, 'b, 'c> TlsSocketSet<'a, 'b, 'c> { pub fn new(tls_sockets: T) -> Self where @@ -50,24 +56,29 @@ impl<'a, 'b, 'c> TlsSocketSet<'a, 'b, 'c> { self.tls_sockets[handle.0].as_mut().unwrap() } + pub fn len(&self) -> usize { + self.tls_sockets.len() + } + pub(crate) fn polled_by( &mut self, - sockets: &mut SocketSet, iface: &mut EthernetInterface, now: Instant ) -> smoltcp::Result where DeviceT: for<'d> Device<'d> { + let mut changed = false; for socket in self.tls_sockets.iter_mut() { if socket.is_some() { - socket.as_mut() - .unwrap() - .update_handshake(iface, now)?; + if socket.as_mut().unwrap().update_handshake(iface, now)? + { + changed = true; + } } } - Ok(true) + Ok(changed) } } diff --git a/src/tcp_stack.rs b/src/tcp_stack.rs new file mode 100644 index 0000000..387b82a --- /dev/null +++ b/src/tcp_stack.rs @@ -0,0 +1,171 @@ +use embedded_nal as nal; +use smoltcp as net; + +use crate::set::TlsSocketHandle as SocketHandle; +use crate::set::TlsSocketSet as SocketSet; +use crate::tls::TlsSocket; + +use nal::{TcpStack, Mode, SocketAddr, nb}; +use net::Error; +use net::iface::EthernetInterface; +use net::time::Instant; +use net::phy::Device; +use heapless::{Vec, consts::*}; + +use core::cell::RefCell; + +#[derive(Debug)] +pub enum NetworkError { + NoSocket, + ConnectionFailure, + ReadFailure, + WriteFailure, +} + +// Structure for implementaion TcpStack interface +pub struct NetworkStack<'a, 'b, 'c> { + sockets: RefCell>, + next_port: RefCell, + unused_handles: RefCell> +} + +impl<'a, 'b, 'c> NetworkStack<'a, 'b, 'c> { + pub fn new(sockets: SocketSet<'a, 'b, 'c>) -> Self { + let mut vec = Vec::new(); + log::info!("socket set size: {:?}", sockets.len()); + for index in 0..sockets.len() { + vec.push( + SocketHandle::new(index) + ).unwrap(); + } + + Self { + sockets: RefCell::new(sockets), + next_port: RefCell::new(49152), + unused_handles: RefCell::new(vec) + } + } + + fn get_ephemeral_port(&self) -> u16 { + // Get the next ephemeral port + let current_port = self.next_port.borrow().clone(); + + let (next, wrap) = self.next_port.borrow().overflowing_add(1); + *self.next_port.borrow_mut() = if wrap { 49152 } else { next }; + + return current_port; + } + + pub fn poll( + &self, + iface: &mut EthernetInterface, + now: Instant, + ) -> bool + where + DeviceT: for <'d> Device<'d> + { + let mut sockets = self.sockets.borrow_mut(); + sockets.polled_by(iface, now).map_or(false, |updated| updated) + } +} + +impl<'a, 'b, 'c> TcpStack for NetworkStack<'a, 'b, 'c> { + type TcpSocket = SocketHandle; + type Error = NetworkError; + + fn open(&self, _: Mode) -> Result { + match self.unused_handles.borrow_mut().pop() { + Some(handle) => { + // Abort any active connections on the handle. + log::info!("Have handle"); + let mut sockets = self.sockets.borrow_mut(); + let mut internal_socket = sockets.get(handle); + internal_socket.close(); + + Ok(handle) + } + None => { + log::info!("Insufficient handles"); + Err(NetworkError::NoSocket) + }, + } + } + + fn connect( + &self, + socket: Self::TcpSocket, + remote: SocketAddr + ) -> Result { + let mut sockets = self.sockets.borrow_mut(); + let internal_socket = sockets.get(socket); + + match remote.ip() { + embedded_nal::IpAddr::V4(addr) => { + let address = { + let octets = addr.octets(); + net::wire::Ipv4Address::new(octets[0], octets[1], octets[2], octets[3]) + }; + internal_socket + .connect((address, remote.port()), self.get_ephemeral_port()) + .map_err(|_| NetworkError::ConnectionFailure)?; + } + embedded_nal::IpAddr::V6(addr) => { + let address = { + let octets = addr.segments(); + net::wire::Ipv6Address::new( + octets[0], octets[1], octets[2], octets[3], octets[4], octets[5], + octets[6], octets[7], + ) + }; + internal_socket + .connect((address, remote.port()), self.get_ephemeral_port()) + .map_err(|_| NetworkError::ConnectionFailure)?; + } + }; + + Ok(socket) + } + + fn is_connected( + &self, + socket: &Self::TcpSocket + ) -> Result { + let mut sockets = self.sockets.borrow_mut(); + let internal_socket = sockets.get(*socket); + Ok(internal_socket.is_connected().unwrap()) + } + + fn write( + &self, + socket: &mut Self::TcpSocket, + buffer: &[u8] + ) -> nb::Result { + let mut sockets = self.sockets.borrow_mut(); + let internal_socket = sockets.get(*socket); + internal_socket.send_slice(buffer) + .map_err(|_| nb::Error::Other(NetworkError::WriteFailure)) + } + + fn read( + &self, + socket: &mut Self::TcpSocket, + buffer: &mut [u8] + ) -> nb::Result { + let mut sockets = self.sockets.borrow_mut(); + let internal_socket = sockets.get(*socket); + internal_socket.recv_slice(buffer) + .map_err(|_| nb::Error::Other(NetworkError::ReadFailure)) + } + + fn close( + &self, + socket: Self::TcpSocket + ) -> Result<(), Self::Error> { + let mut sockets = self.sockets.borrow_mut(); + let internal_socket = sockets.get(socket); + internal_socket.close(); + + self.unused_handles.borrow_mut().push(socket).unwrap(); + Ok(()) + } +} \ No newline at end of file diff --git a/src/tls.rs b/src/tls.rs index 84a7bca..07faab1 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -101,16 +101,28 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { T: Into, U: Into, { - // Start TCP handshake let mut tcp_socket = self.sockets.get::(self.tcp_handle); - tcp_socket.connect(remote_endpoint, local_endpoint)?; - - // Permit TLS handshake as well let mut session = self.session.borrow_mut(); - session.connect( - tcp_socket.remote_endpoint(), - tcp_socket.local_endpoint() - ); + + // Start TCP handshake + if !tcp_socket.is_open() { + tcp_socket.connect(remote_endpoint, local_endpoint)?; + // Start TLS handshake if TCP handshake will commence + session.connect( + tcp_socket.remote_endpoint(), + tcp_socket.local_endpoint() + ); + } else { + // Also start TLS handshake if for some reason TCP is ready, + // and TLS is idle + if session.get_tls_state() == TlsState::DEFAULT { + session.connect( + tcp_socket.remote_endpoint(), + tcp_socket.local_endpoint() + ); + } + } + Ok(()) } @@ -141,7 +153,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { DeviceT: for<'d> Device<'d> { // Poll the TCP socket, no matter what - iface.poll(&mut self.sockets, now)?; + let propagated_poll = iface.poll(&mut self.sockets, now)?; // Handle TLS handshake through TLS states let tls_state = { @@ -159,7 +171,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { // Close TCP socket if necessary if tcp_state == TcpState::Established && tls_state == TlsState::DEFAULT { self.sockets.get::(self.tcp_handle).close(); - return Ok(false); + return Ok(propagated_poll); } // Skip handshake processing if it is already completed @@ -187,7 +199,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { session.get_remote_endpoint(), session.get_local_endpoint() )?; - } + } } // For any other functioning state, the TCP connection being not @@ -196,12 +208,12 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { _ => { let mut session = self.session.borrow_mut(); session.reset_state(); - log::info!("TLS socket resets after TCP socket closed") + log::info!("TLS socket resets after TCP socket closed"); } } // Terminate the procedure, as no processing is necessary - return Ok(false); + return Ok(propagated_poll); } } @@ -312,7 +324,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { // `close()` to the TCP socket self.session.borrow_mut().reset_state(); - return Ok(false); + return Ok(propagated_poll); } // Handle TLS handshake through TLS states @@ -555,7 +567,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { // There is no need to care about handshake if it was completed TlsState::CLIENT_CONNECTED => { - return Ok(true); + return Ok(propagated_poll); } // This state waits for Client Hello handshake from a client @@ -793,7 +805,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { // There is no need to care about handshake if it was completed // This is to prevent accidental dequeing of application data TlsState::SERVER_CONNECTED => { - return Ok(true); + return Ok(propagated_poll); } // Other states @@ -811,14 +823,14 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { // Check if there are bytes enqueued in the recv buffer // No need to do further dequeuing if there are no receivable bytes if !tcp_socket.can_recv() { - return Ok(self.session.borrow().has_completed_handshake()) + return Ok(propagated_poll) } // Peak into the first 5 bytes (TLS record layer) // This tells the length of the entire record let length = match tcp_socket.peek(5) { Ok(bytes) => NetworkEndian::read_u16(&bytes[3..5]), - _ => return Ok(self.session.borrow().has_completed_handshake()) + _ => return Ok(propagated_poll) }; // Recv the entire TLS record @@ -830,7 +842,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { // Parse the bytes representation of a TLS record let (repr_slice, mut repr) = match parse_tls_repr(&tls_repr_vec) { Ok((_, (repr_slice, repr))) => (repr_slice, repr), - _ => return Ok(self.session.borrow().has_completed_handshake()) + _ => return Ok(propagated_poll) }; // Process record base on content type @@ -888,16 +900,14 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { AlertType::UnexpectedMessage, &inner_plaintext[..content_type_index] ); - return Ok(false); + return Ok(propagated_poll); }, TlsContentType::Alert => { self.session.borrow_mut().reset_state(); - return Ok(false); + return Ok(propagated_poll); }, TlsContentType::ApplicationData => { - return Ok( - self.session.borrow().has_completed_handshake() - ); + return Ok(propagated_poll); }, _ => () } @@ -920,7 +930,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { handshake: Some(handshake_repr) } ).is_err() { - return Ok(self.session.borrow().has_completed_handshake()) + return Ok(propagated_poll) } } }, @@ -928,7 +938,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { TlsContentType::ChangeCipherSpec | TlsContentType::Handshake => { if self.process(repr_slice, repr).is_err() { - return Ok(self.session.borrow().has_completed_handshake()) + return Ok(propagated_poll) } log::info!("Processed record"); }, @@ -948,7 +958,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { } } - Ok(self.session.borrow().has_completed_handshake()) + Ok(propagated_poll) } // Process TLS ingress during handshake @@ -1856,15 +1866,16 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { Ok(actual_application_data_length) } - pub fn send_slice(&mut self, data: &[u8]) -> Result<()> { + pub fn send_slice(&mut self, data: &[u8]) -> Result { // If the handshake is not completed, do not push bytes onto the buffer // through TlsSocket.send_slice() // Handshake send should be through TCPSocket directly. let mut session = self.session.borrow_mut(); if session.get_tls_state() != TlsState::CLIENT_CONNECTED && session.get_tls_state() != TlsState::SERVER_CONNECTED { - return Ok(()); + return Ok(0); } + let data_length = data.len(); // Sending order: // 1. Associated data/ TLS Record layer @@ -1877,7 +1888,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { ]; NetworkEndian::write_u16(&mut associated_data[3..5], - u16::try_from(data.len()).unwrap() // Payload length + u16::try_from(data_length).unwrap() // Payload length + 1 // Content type length + 16 // Auth tag length ); @@ -1900,7 +1911,15 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { tcp_socket.send_slice(&vec)?; tcp_socket.send_slice(&tag)?; - Ok(()) + Ok(data_length) + } + + pub fn is_connected(&self) -> Result { + let session = self.session.borrow(); + Ok( + session.get_tls_state() == TlsState::CLIENT_CONNECTED || + session.get_tls_state() == TlsState::SERVER_CONNECTED + ) } // Send `Close notify` alert to remote side @@ -1936,7 +1955,7 @@ use core::fmt; impl<'a, 'b, 'c> fmt::Write for TlsSocket<'a, 'b, 'c> { fn write_str(&mut self, slice: &str) -> fmt::Result { let slice = slice.as_bytes(); - if self.send_slice(slice) == Ok(()) { + if self.send_slice(slice) == Ok(slice.len()) { Ok(()) } else { Err(fmt::Error)