nal: init

This commit is contained in:
occheung 2020-12-04 15:50:37 +08:00
parent 5ca7c6b3ff
commit 552d21a1b3
6 changed files with 259 additions and 41 deletions

View File

@ -93,6 +93,13 @@ optional = true
version = "0.3.1" version = "0.3.1"
optional = true optional = true
# Support old version of embedded_nal interface only
# It is to operate with crates such as MiniMQ, which still depends on version 0.1.0
[dependencies.embedded-nal]
version = "0.1.0"
optional = true
[features] [features]
default = [] default = []
std = [ "rand", "hex-literal", "simple_logger", "rsa/default" ] std = [ "rand", "hex-literal", "simple_logger", "rsa/default" ]
nal_stack = [ "embedded-nal" ]

View File

@ -13,6 +13,9 @@ pub mod fake_rng;
pub mod oid; pub mod oid;
pub mod set; pub mod set;
#[cfg(feature = "nal_stack")]
pub mod tcp_stack;
// TODO: Implement errors // TODO: Implement errors
// Details: Encapsulate smoltcp & nom errors // Details: Encapsulate smoltcp & nom errors
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -49,8 +52,10 @@ use net::phy::Device;
use crate::set::TlsSocketSet; use crate::set::TlsSocketSet;
// One-call function for polling all sockets within socket set // One-call function for polling all sockets within socket set
// Input of vanilla sockets are optional, as one may not feel needed to create them
// TLS socket set is mandatory, otherwise you should just use `EthernetInterface::poll(..)`
pub fn poll<DeviceT>( pub fn poll<DeviceT>(
sockets: &mut SocketSet, sockets: Option<&mut SocketSet>,
tls_sockets: &mut TlsSocketSet, tls_sockets: &mut TlsSocketSet,
iface: &mut EthernetInterface<DeviceT>, iface: &mut EthernetInterface<DeviceT>,
now: Instant now: Instant
@ -58,6 +63,11 @@ pub fn poll<DeviceT>(
where where
DeviceT: for<'d> Device<'d> DeviceT: for<'d> Device<'d>
{ {
tls_sockets.polled_by(sockets, iface, now)?; tls_sockets.polled_by(iface, now)?;
iface.poll(sockets, now).map_err(Error::PropagatedError)
if let Some(vanilla_sockets) = sockets {
iface.poll(vanilla_sockets, now).map_err(Error::PropagatedError)?;
}
Ok(true)
} }

View File

@ -75,7 +75,7 @@ fn main() {
}; };
tls_socket.connect( tls_socket.connect(
&mut sockets, // &mut sockets,
(Ipv4Address::new(192, 168, 1, 125), 1883), (Ipv4Address::new(192, 168, 1, 125), 1883),
49600 49600
).unwrap(); ).unwrap();

View File

@ -14,6 +14,12 @@ pub struct TlsSocketSet<'a, 'b, 'c> {
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]
pub struct TlsSocketHandle(usize); pub struct TlsSocketHandle(usize);
impl TlsSocketHandle {
pub(crate) fn new(index: usize) -> Self {
Self(index)
}
}
impl<'a, 'b, 'c> TlsSocketSet<'a, 'b, 'c> { impl<'a, 'b, 'c> TlsSocketSet<'a, 'b, 'c> {
pub fn new<T>(tls_sockets: T) -> Self pub fn new<T>(tls_sockets: T) -> Self
where where
@ -50,24 +56,29 @@ impl<'a, 'b, 'c> TlsSocketSet<'a, 'b, 'c> {
self.tls_sockets[handle.0].as_mut().unwrap() self.tls_sockets[handle.0].as_mut().unwrap()
} }
pub fn len(&self) -> usize {
self.tls_sockets.len()
}
pub(crate) fn polled_by<DeviceT>( pub(crate) fn polled_by<DeviceT>(
&mut self, &mut self,
sockets: &mut SocketSet,
iface: &mut EthernetInterface<DeviceT>, iface: &mut EthernetInterface<DeviceT>,
now: Instant now: Instant
) -> smoltcp::Result<bool> ) -> smoltcp::Result<bool>
where where
DeviceT: for<'d> Device<'d> DeviceT: for<'d> Device<'d>
{ {
let mut changed = false;
for socket in self.tls_sockets.iter_mut() { for socket in self.tls_sockets.iter_mut() {
if socket.is_some() { if socket.is_some() {
socket.as_mut() if socket.as_mut().unwrap().update_handshake(iface, now)?
.unwrap() {
.update_handshake(iface, now)?; changed = true;
}
} }
} }
Ok(true) Ok(changed)
} }
} }

171
src/tcp_stack.rs Normal file
View File

@ -0,0 +1,171 @@
use embedded_nal as nal;
use smoltcp as net;
use crate::set::TlsSocketHandle as SocketHandle;
use crate::set::TlsSocketSet as SocketSet;
use crate::tls::TlsSocket;
use nal::{TcpStack, Mode, SocketAddr, nb};
use net::Error;
use net::iface::EthernetInterface;
use net::time::Instant;
use net::phy::Device;
use heapless::{Vec, consts::*};
use core::cell::RefCell;
#[derive(Debug)]
pub enum NetworkError {
NoSocket,
ConnectionFailure,
ReadFailure,
WriteFailure,
}
// Structure for implementaion TcpStack interface
pub struct NetworkStack<'a, 'b, 'c> {
sockets: RefCell<SocketSet<'a, 'b, 'c>>,
next_port: RefCell<u16>,
unused_handles: RefCell<Vec<SocketHandle, U16>>
}
impl<'a, 'b, 'c> NetworkStack<'a, 'b, 'c> {
pub fn new(sockets: SocketSet<'a, 'b, 'c>) -> Self {
let mut vec = Vec::new();
log::info!("socket set size: {:?}", sockets.len());
for index in 0..sockets.len() {
vec.push(
SocketHandle::new(index)
).unwrap();
}
Self {
sockets: RefCell::new(sockets),
next_port: RefCell::new(49152),
unused_handles: RefCell::new(vec)
}
}
fn get_ephemeral_port(&self) -> u16 {
// Get the next ephemeral port
let current_port = self.next_port.borrow().clone();
let (next, wrap) = self.next_port.borrow().overflowing_add(1);
*self.next_port.borrow_mut() = if wrap { 49152 } else { next };
return current_port;
}
pub fn poll<DeviceT>(
&self,
iface: &mut EthernetInterface<DeviceT>,
now: Instant,
) -> bool
where
DeviceT: for <'d> Device<'d>
{
let mut sockets = self.sockets.borrow_mut();
sockets.polled_by(iface, now).map_or(false, |updated| updated)
}
}
impl<'a, 'b, 'c> TcpStack for NetworkStack<'a, 'b, 'c> {
type TcpSocket = SocketHandle;
type Error = NetworkError;
fn open(&self, _: Mode) -> Result<Self::TcpSocket, Self::Error> {
match self.unused_handles.borrow_mut().pop() {
Some(handle) => {
// Abort any active connections on the handle.
log::info!("Have handle");
let mut sockets = self.sockets.borrow_mut();
let mut internal_socket = sockets.get(handle);
internal_socket.close();
Ok(handle)
}
None => {
log::info!("Insufficient handles");
Err(NetworkError::NoSocket)
},
}
}
fn connect(
&self,
socket: Self::TcpSocket,
remote: SocketAddr
) -> Result<Self::TcpSocket, Self::Error> {
let mut sockets = self.sockets.borrow_mut();
let internal_socket = sockets.get(socket);
match remote.ip() {
embedded_nal::IpAddr::V4(addr) => {
let address = {
let octets = addr.octets();
net::wire::Ipv4Address::new(octets[0], octets[1], octets[2], octets[3])
};
internal_socket
.connect((address, remote.port()), self.get_ephemeral_port())
.map_err(|_| NetworkError::ConnectionFailure)?;
}
embedded_nal::IpAddr::V6(addr) => {
let address = {
let octets = addr.segments();
net::wire::Ipv6Address::new(
octets[0], octets[1], octets[2], octets[3], octets[4], octets[5],
octets[6], octets[7],
)
};
internal_socket
.connect((address, remote.port()), self.get_ephemeral_port())
.map_err(|_| NetworkError::ConnectionFailure)?;
}
};
Ok(socket)
}
fn is_connected(
&self,
socket: &Self::TcpSocket
) -> Result<bool, Self::Error> {
let mut sockets = self.sockets.borrow_mut();
let internal_socket = sockets.get(*socket);
Ok(internal_socket.is_connected().unwrap())
}
fn write(
&self,
socket: &mut Self::TcpSocket,
buffer: &[u8]
) -> nb::Result<usize, Self::Error> {
let mut sockets = self.sockets.borrow_mut();
let internal_socket = sockets.get(*socket);
internal_socket.send_slice(buffer)
.map_err(|_| nb::Error::Other(NetworkError::WriteFailure))
}
fn read(
&self,
socket: &mut Self::TcpSocket,
buffer: &mut [u8]
) -> nb::Result<usize, Self::Error> {
let mut sockets = self.sockets.borrow_mut();
let internal_socket = sockets.get(*socket);
internal_socket.recv_slice(buffer)
.map_err(|_| nb::Error::Other(NetworkError::ReadFailure))
}
fn close(
&self,
socket: Self::TcpSocket
) -> Result<(), Self::Error> {
let mut sockets = self.sockets.borrow_mut();
let internal_socket = sockets.get(socket);
internal_socket.close();
self.unused_handles.borrow_mut().push(socket).unwrap();
Ok(())
}
}

View File

@ -101,16 +101,28 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
T: Into<IpEndpoint>, T: Into<IpEndpoint>,
U: Into<IpEndpoint>, U: Into<IpEndpoint>,
{ {
// Start TCP handshake
let mut tcp_socket = self.sockets.get::<TcpSocket>(self.tcp_handle); let mut tcp_socket = self.sockets.get::<TcpSocket>(self.tcp_handle);
tcp_socket.connect(remote_endpoint, local_endpoint)?;
// Permit TLS handshake as well
let mut session = self.session.borrow_mut(); let mut session = self.session.borrow_mut();
session.connect(
tcp_socket.remote_endpoint(), // Start TCP handshake
tcp_socket.local_endpoint() if !tcp_socket.is_open() {
); tcp_socket.connect(remote_endpoint, local_endpoint)?;
// Start TLS handshake if TCP handshake will commence
session.connect(
tcp_socket.remote_endpoint(),
tcp_socket.local_endpoint()
);
} else {
// Also start TLS handshake if for some reason TCP is ready,
// and TLS is idle
if session.get_tls_state() == TlsState::DEFAULT {
session.connect(
tcp_socket.remote_endpoint(),
tcp_socket.local_endpoint()
);
}
}
Ok(()) Ok(())
} }
@ -141,7 +153,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
DeviceT: for<'d> Device<'d> DeviceT: for<'d> Device<'d>
{ {
// Poll the TCP socket, no matter what // Poll the TCP socket, no matter what
iface.poll(&mut self.sockets, now)?; let propagated_poll = iface.poll(&mut self.sockets, now)?;
// Handle TLS handshake through TLS states // Handle TLS handshake through TLS states
let tls_state = { let tls_state = {
@ -159,7 +171,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
// Close TCP socket if necessary // Close TCP socket if necessary
if tcp_state == TcpState::Established && tls_state == TlsState::DEFAULT { if tcp_state == TcpState::Established && tls_state == TlsState::DEFAULT {
self.sockets.get::<TcpSocket>(self.tcp_handle).close(); self.sockets.get::<TcpSocket>(self.tcp_handle).close();
return Ok(false); return Ok(propagated_poll);
} }
// Skip handshake processing if it is already completed // Skip handshake processing if it is already completed
@ -196,12 +208,12 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
_ => { _ => {
let mut session = self.session.borrow_mut(); let mut session = self.session.borrow_mut();
session.reset_state(); session.reset_state();
log::info!("TLS socket resets after TCP socket closed") log::info!("TLS socket resets after TCP socket closed");
} }
} }
// Terminate the procedure, as no processing is necessary // Terminate the procedure, as no processing is necessary
return Ok(false); return Ok(propagated_poll);
} }
} }
@ -312,7 +324,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
// `close()` to the TCP socket // `close()` to the TCP socket
self.session.borrow_mut().reset_state(); self.session.borrow_mut().reset_state();
return Ok(false); return Ok(propagated_poll);
} }
// Handle TLS handshake through TLS states // Handle TLS handshake through TLS states
@ -555,7 +567,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
// There is no need to care about handshake if it was completed // There is no need to care about handshake if it was completed
TlsState::CLIENT_CONNECTED => { TlsState::CLIENT_CONNECTED => {
return Ok(true); return Ok(propagated_poll);
} }
// This state waits for Client Hello handshake from a client // This state waits for Client Hello handshake from a client
@ -793,7 +805,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
// There is no need to care about handshake if it was completed // There is no need to care about handshake if it was completed
// This is to prevent accidental dequeing of application data // This is to prevent accidental dequeing of application data
TlsState::SERVER_CONNECTED => { TlsState::SERVER_CONNECTED => {
return Ok(true); return Ok(propagated_poll);
} }
// Other states // Other states
@ -811,14 +823,14 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
// Check if there are bytes enqueued in the recv buffer // Check if there are bytes enqueued in the recv buffer
// No need to do further dequeuing if there are no receivable bytes // No need to do further dequeuing if there are no receivable bytes
if !tcp_socket.can_recv() { if !tcp_socket.can_recv() {
return Ok(self.session.borrow().has_completed_handshake()) return Ok(propagated_poll)
} }
// Peak into the first 5 bytes (TLS record layer) // Peak into the first 5 bytes (TLS record layer)
// This tells the length of the entire record // This tells the length of the entire record
let length = match tcp_socket.peek(5) { let length = match tcp_socket.peek(5) {
Ok(bytes) => NetworkEndian::read_u16(&bytes[3..5]), Ok(bytes) => NetworkEndian::read_u16(&bytes[3..5]),
_ => return Ok(self.session.borrow().has_completed_handshake()) _ => return Ok(propagated_poll)
}; };
// Recv the entire TLS record // Recv the entire TLS record
@ -830,7 +842,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
// Parse the bytes representation of a TLS record // Parse the bytes representation of a TLS record
let (repr_slice, mut repr) = match parse_tls_repr(&tls_repr_vec) { let (repr_slice, mut repr) = match parse_tls_repr(&tls_repr_vec) {
Ok((_, (repr_slice, repr))) => (repr_slice, repr), Ok((_, (repr_slice, repr))) => (repr_slice, repr),
_ => return Ok(self.session.borrow().has_completed_handshake()) _ => return Ok(propagated_poll)
}; };
// Process record base on content type // Process record base on content type
@ -888,16 +900,14 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
AlertType::UnexpectedMessage, AlertType::UnexpectedMessage,
&inner_plaintext[..content_type_index] &inner_plaintext[..content_type_index]
); );
return Ok(false); return Ok(propagated_poll);
}, },
TlsContentType::Alert => { TlsContentType::Alert => {
self.session.borrow_mut().reset_state(); self.session.borrow_mut().reset_state();
return Ok(false); return Ok(propagated_poll);
}, },
TlsContentType::ApplicationData => { TlsContentType::ApplicationData => {
return Ok( return Ok(propagated_poll);
self.session.borrow().has_completed_handshake()
);
}, },
_ => () _ => ()
} }
@ -920,7 +930,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
handshake: Some(handshake_repr) handshake: Some(handshake_repr)
} }
).is_err() { ).is_err() {
return Ok(self.session.borrow().has_completed_handshake()) return Ok(propagated_poll)
} }
} }
}, },
@ -928,7 +938,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
TlsContentType::ChangeCipherSpec | TlsContentType::ChangeCipherSpec |
TlsContentType::Handshake => { TlsContentType::Handshake => {
if self.process(repr_slice, repr).is_err() { if self.process(repr_slice, repr).is_err() {
return Ok(self.session.borrow().has_completed_handshake()) return Ok(propagated_poll)
} }
log::info!("Processed record"); log::info!("Processed record");
}, },
@ -948,7 +958,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
} }
} }
Ok(self.session.borrow().has_completed_handshake()) Ok(propagated_poll)
} }
// Process TLS ingress during handshake // Process TLS ingress during handshake
@ -1856,15 +1866,16 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
Ok(actual_application_data_length) Ok(actual_application_data_length)
} }
pub fn send_slice(&mut self, data: &[u8]) -> Result<()> { pub fn send_slice(&mut self, data: &[u8]) -> Result<usize> {
// If the handshake is not completed, do not push bytes onto the buffer // If the handshake is not completed, do not push bytes onto the buffer
// through TlsSocket.send_slice() // through TlsSocket.send_slice()
// Handshake send should be through TCPSocket directly. // Handshake send should be through TCPSocket directly.
let mut session = self.session.borrow_mut(); let mut session = self.session.borrow_mut();
if session.get_tls_state() != TlsState::CLIENT_CONNECTED && if session.get_tls_state() != TlsState::CLIENT_CONNECTED &&
session.get_tls_state() != TlsState::SERVER_CONNECTED { session.get_tls_state() != TlsState::SERVER_CONNECTED {
return Ok(()); return Ok(0);
} }
let data_length = data.len();
// Sending order: // Sending order:
// 1. Associated data/ TLS Record layer // 1. Associated data/ TLS Record layer
@ -1877,7 +1888,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
]; ];
NetworkEndian::write_u16(&mut associated_data[3..5], NetworkEndian::write_u16(&mut associated_data[3..5],
u16::try_from(data.len()).unwrap() // Payload length u16::try_from(data_length).unwrap() // Payload length
+ 1 // Content type length + 1 // Content type length
+ 16 // Auth tag length + 16 // Auth tag length
); );
@ -1900,7 +1911,15 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
tcp_socket.send_slice(&vec)?; tcp_socket.send_slice(&vec)?;
tcp_socket.send_slice(&tag)?; tcp_socket.send_slice(&tag)?;
Ok(()) Ok(data_length)
}
pub fn is_connected(&self) -> Result<bool> {
let session = self.session.borrow();
Ok(
session.get_tls_state() == TlsState::CLIENT_CONNECTED ||
session.get_tls_state() == TlsState::SERVER_CONNECTED
)
} }
// Send `Close notify` alert to remote side // Send `Close notify` alert to remote side
@ -1936,7 +1955,7 @@ use core::fmt;
impl<'a, 'b, 'c> fmt::Write for TlsSocket<'a, 'b, 'c> { impl<'a, 'b, 'c> fmt::Write for TlsSocket<'a, 'b, 'c> {
fn write_str(&mut self, slice: &str) -> fmt::Result { fn write_str(&mut self, slice: &str) -> fmt::Result {
let slice = slice.as_bytes(); let slice = slice.as_bytes();
if self.send_slice(slice) == Ok(()) { if self.send_slice(slice) == Ok(slice.len()) {
Ok(()) Ok(())
} else { } else {
Err(fmt::Error) Err(fmt::Error)