nal: init
This commit is contained in:
parent
5ca7c6b3ff
commit
552d21a1b3
|
@ -93,6 +93,13 @@ optional = true
|
|||
version = "0.3.1"
|
||||
optional = true
|
||||
|
||||
# Support old version of embedded_nal interface only
|
||||
# It is to operate with crates such as MiniMQ, which still depends on version 0.1.0
|
||||
[dependencies.embedded-nal]
|
||||
version = "0.1.0"
|
||||
optional = true
|
||||
|
||||
[features]
|
||||
default = []
|
||||
std = [ "rand", "hex-literal", "simple_logger", "rsa/default" ]
|
||||
nal_stack = [ "embedded-nal" ]
|
||||
|
|
16
src/lib.rs
16
src/lib.rs
|
@ -13,6 +13,9 @@ pub mod fake_rng;
|
|||
pub mod oid;
|
||||
pub mod set;
|
||||
|
||||
#[cfg(feature = "nal_stack")]
|
||||
pub mod tcp_stack;
|
||||
|
||||
// TODO: Implement errors
|
||||
// Details: Encapsulate smoltcp & nom errors
|
||||
#[derive(Debug, Clone)]
|
||||
|
@ -49,8 +52,10 @@ use net::phy::Device;
|
|||
use crate::set::TlsSocketSet;
|
||||
|
||||
// One-call function for polling all sockets within socket set
|
||||
// Input of vanilla sockets are optional, as one may not feel needed to create them
|
||||
// TLS socket set is mandatory, otherwise you should just use `EthernetInterface::poll(..)`
|
||||
pub fn poll<DeviceT>(
|
||||
sockets: &mut SocketSet,
|
||||
sockets: Option<&mut SocketSet>,
|
||||
tls_sockets: &mut TlsSocketSet,
|
||||
iface: &mut EthernetInterface<DeviceT>,
|
||||
now: Instant
|
||||
|
@ -58,6 +63,11 @@ pub fn poll<DeviceT>(
|
|||
where
|
||||
DeviceT: for<'d> Device<'d>
|
||||
{
|
||||
tls_sockets.polled_by(sockets, iface, now)?;
|
||||
iface.poll(sockets, now).map_err(Error::PropagatedError)
|
||||
tls_sockets.polled_by(iface, now)?;
|
||||
|
||||
if let Some(vanilla_sockets) = sockets {
|
||||
iface.poll(vanilla_sockets, now).map_err(Error::PropagatedError)?;
|
||||
}
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
|
|
@ -75,7 +75,7 @@ fn main() {
|
|||
};
|
||||
|
||||
tls_socket.connect(
|
||||
&mut sockets,
|
||||
// &mut sockets,
|
||||
(Ipv4Address::new(192, 168, 1, 125), 1883),
|
||||
49600
|
||||
).unwrap();
|
||||
|
|
21
src/set.rs
21
src/set.rs
|
@ -14,6 +14,12 @@ pub struct TlsSocketSet<'a, 'b, 'c> {
|
|||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct TlsSocketHandle(usize);
|
||||
|
||||
impl TlsSocketHandle {
|
||||
pub(crate) fn new(index: usize) -> Self {
|
||||
Self(index)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'c> TlsSocketSet<'a, 'b, 'c> {
|
||||
pub fn new<T>(tls_sockets: T) -> Self
|
||||
where
|
||||
|
@ -50,24 +56,29 @@ impl<'a, 'b, 'c> TlsSocketSet<'a, 'b, 'c> {
|
|||
self.tls_sockets[handle.0].as_mut().unwrap()
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.tls_sockets.len()
|
||||
}
|
||||
|
||||
pub(crate) fn polled_by<DeviceT>(
|
||||
&mut self,
|
||||
sockets: &mut SocketSet,
|
||||
iface: &mut EthernetInterface<DeviceT>,
|
||||
now: Instant
|
||||
) -> smoltcp::Result<bool>
|
||||
where
|
||||
DeviceT: for<'d> Device<'d>
|
||||
{
|
||||
let mut changed = false;
|
||||
for socket in self.tls_sockets.iter_mut() {
|
||||
if socket.is_some() {
|
||||
socket.as_mut()
|
||||
.unwrap()
|
||||
.update_handshake(iface, now)?;
|
||||
if socket.as_mut().unwrap().update_handshake(iface, now)?
|
||||
{
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(true)
|
||||
Ok(changed)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,171 @@
|
|||
use embedded_nal as nal;
|
||||
use smoltcp as net;
|
||||
|
||||
use crate::set::TlsSocketHandle as SocketHandle;
|
||||
use crate::set::TlsSocketSet as SocketSet;
|
||||
use crate::tls::TlsSocket;
|
||||
|
||||
use nal::{TcpStack, Mode, SocketAddr, nb};
|
||||
use net::Error;
|
||||
use net::iface::EthernetInterface;
|
||||
use net::time::Instant;
|
||||
use net::phy::Device;
|
||||
use heapless::{Vec, consts::*};
|
||||
|
||||
use core::cell::RefCell;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum NetworkError {
|
||||
NoSocket,
|
||||
ConnectionFailure,
|
||||
ReadFailure,
|
||||
WriteFailure,
|
||||
}
|
||||
|
||||
// Structure for implementaion TcpStack interface
|
||||
pub struct NetworkStack<'a, 'b, 'c> {
|
||||
sockets: RefCell<SocketSet<'a, 'b, 'c>>,
|
||||
next_port: RefCell<u16>,
|
||||
unused_handles: RefCell<Vec<SocketHandle, U16>>
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'c> NetworkStack<'a, 'b, 'c> {
|
||||
pub fn new(sockets: SocketSet<'a, 'b, 'c>) -> Self {
|
||||
let mut vec = Vec::new();
|
||||
log::info!("socket set size: {:?}", sockets.len());
|
||||
for index in 0..sockets.len() {
|
||||
vec.push(
|
||||
SocketHandle::new(index)
|
||||
).unwrap();
|
||||
}
|
||||
|
||||
Self {
|
||||
sockets: RefCell::new(sockets),
|
||||
next_port: RefCell::new(49152),
|
||||
unused_handles: RefCell::new(vec)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_ephemeral_port(&self) -> u16 {
|
||||
// Get the next ephemeral port
|
||||
let current_port = self.next_port.borrow().clone();
|
||||
|
||||
let (next, wrap) = self.next_port.borrow().overflowing_add(1);
|
||||
*self.next_port.borrow_mut() = if wrap { 49152 } else { next };
|
||||
|
||||
return current_port;
|
||||
}
|
||||
|
||||
pub fn poll<DeviceT>(
|
||||
&self,
|
||||
iface: &mut EthernetInterface<DeviceT>,
|
||||
now: Instant,
|
||||
) -> bool
|
||||
where
|
||||
DeviceT: for <'d> Device<'d>
|
||||
{
|
||||
let mut sockets = self.sockets.borrow_mut();
|
||||
sockets.polled_by(iface, now).map_or(false, |updated| updated)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'c> TcpStack for NetworkStack<'a, 'b, 'c> {
|
||||
type TcpSocket = SocketHandle;
|
||||
type Error = NetworkError;
|
||||
|
||||
fn open(&self, _: Mode) -> Result<Self::TcpSocket, Self::Error> {
|
||||
match self.unused_handles.borrow_mut().pop() {
|
||||
Some(handle) => {
|
||||
// Abort any active connections on the handle.
|
||||
log::info!("Have handle");
|
||||
let mut sockets = self.sockets.borrow_mut();
|
||||
let mut internal_socket = sockets.get(handle);
|
||||
internal_socket.close();
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
None => {
|
||||
log::info!("Insufficient handles");
|
||||
Err(NetworkError::NoSocket)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn connect(
|
||||
&self,
|
||||
socket: Self::TcpSocket,
|
||||
remote: SocketAddr
|
||||
) -> Result<Self::TcpSocket, Self::Error> {
|
||||
let mut sockets = self.sockets.borrow_mut();
|
||||
let internal_socket = sockets.get(socket);
|
||||
|
||||
match remote.ip() {
|
||||
embedded_nal::IpAddr::V4(addr) => {
|
||||
let address = {
|
||||
let octets = addr.octets();
|
||||
net::wire::Ipv4Address::new(octets[0], octets[1], octets[2], octets[3])
|
||||
};
|
||||
internal_socket
|
||||
.connect((address, remote.port()), self.get_ephemeral_port())
|
||||
.map_err(|_| NetworkError::ConnectionFailure)?;
|
||||
}
|
||||
embedded_nal::IpAddr::V6(addr) => {
|
||||
let address = {
|
||||
let octets = addr.segments();
|
||||
net::wire::Ipv6Address::new(
|
||||
octets[0], octets[1], octets[2], octets[3], octets[4], octets[5],
|
||||
octets[6], octets[7],
|
||||
)
|
||||
};
|
||||
internal_socket
|
||||
.connect((address, remote.port()), self.get_ephemeral_port())
|
||||
.map_err(|_| NetworkError::ConnectionFailure)?;
|
||||
}
|
||||
};
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
fn is_connected(
|
||||
&self,
|
||||
socket: &Self::TcpSocket
|
||||
) -> Result<bool, Self::Error> {
|
||||
let mut sockets = self.sockets.borrow_mut();
|
||||
let internal_socket = sockets.get(*socket);
|
||||
Ok(internal_socket.is_connected().unwrap())
|
||||
}
|
||||
|
||||
fn write(
|
||||
&self,
|
||||
socket: &mut Self::TcpSocket,
|
||||
buffer: &[u8]
|
||||
) -> nb::Result<usize, Self::Error> {
|
||||
let mut sockets = self.sockets.borrow_mut();
|
||||
let internal_socket = sockets.get(*socket);
|
||||
internal_socket.send_slice(buffer)
|
||||
.map_err(|_| nb::Error::Other(NetworkError::WriteFailure))
|
||||
}
|
||||
|
||||
fn read(
|
||||
&self,
|
||||
socket: &mut Self::TcpSocket,
|
||||
buffer: &mut [u8]
|
||||
) -> nb::Result<usize, Self::Error> {
|
||||
let mut sockets = self.sockets.borrow_mut();
|
||||
let internal_socket = sockets.get(*socket);
|
||||
internal_socket.recv_slice(buffer)
|
||||
.map_err(|_| nb::Error::Other(NetworkError::ReadFailure))
|
||||
}
|
||||
|
||||
fn close(
|
||||
&self,
|
||||
socket: Self::TcpSocket
|
||||
) -> Result<(), Self::Error> {
|
||||
let mut sockets = self.sockets.borrow_mut();
|
||||
let internal_socket = sockets.get(socket);
|
||||
internal_socket.close();
|
||||
|
||||
self.unused_handles.borrow_mut().push(socket).unwrap();
|
||||
Ok(())
|
||||
}
|
||||
}
|
83
src/tls.rs
83
src/tls.rs
|
@ -101,16 +101,28 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
|
|||
T: Into<IpEndpoint>,
|
||||
U: Into<IpEndpoint>,
|
||||
{
|
||||
// Start TCP handshake
|
||||
let mut tcp_socket = self.sockets.get::<TcpSocket>(self.tcp_handle);
|
||||
tcp_socket.connect(remote_endpoint, local_endpoint)?;
|
||||
|
||||
// Permit TLS handshake as well
|
||||
let mut session = self.session.borrow_mut();
|
||||
session.connect(
|
||||
tcp_socket.remote_endpoint(),
|
||||
tcp_socket.local_endpoint()
|
||||
);
|
||||
|
||||
// Start TCP handshake
|
||||
if !tcp_socket.is_open() {
|
||||
tcp_socket.connect(remote_endpoint, local_endpoint)?;
|
||||
// Start TLS handshake if TCP handshake will commence
|
||||
session.connect(
|
||||
tcp_socket.remote_endpoint(),
|
||||
tcp_socket.local_endpoint()
|
||||
);
|
||||
} else {
|
||||
// Also start TLS handshake if for some reason TCP is ready,
|
||||
// and TLS is idle
|
||||
if session.get_tls_state() == TlsState::DEFAULT {
|
||||
session.connect(
|
||||
tcp_socket.remote_endpoint(),
|
||||
tcp_socket.local_endpoint()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -141,7 +153,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
|
|||
DeviceT: for<'d> Device<'d>
|
||||
{
|
||||
// Poll the TCP socket, no matter what
|
||||
iface.poll(&mut self.sockets, now)?;
|
||||
let propagated_poll = iface.poll(&mut self.sockets, now)?;
|
||||
|
||||
// Handle TLS handshake through TLS states
|
||||
let tls_state = {
|
||||
|
@ -159,7 +171,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
|
|||
// Close TCP socket if necessary
|
||||
if tcp_state == TcpState::Established && tls_state == TlsState::DEFAULT {
|
||||
self.sockets.get::<TcpSocket>(self.tcp_handle).close();
|
||||
return Ok(false);
|
||||
return Ok(propagated_poll);
|
||||
}
|
||||
|
||||
// Skip handshake processing if it is already completed
|
||||
|
@ -187,7 +199,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
|
|||
session.get_remote_endpoint(),
|
||||
session.get_local_endpoint()
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For any other functioning state, the TCP connection being not
|
||||
|
@ -196,12 +208,12 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
|
|||
_ => {
|
||||
let mut session = self.session.borrow_mut();
|
||||
session.reset_state();
|
||||
log::info!("TLS socket resets after TCP socket closed")
|
||||
log::info!("TLS socket resets after TCP socket closed");
|
||||
}
|
||||
}
|
||||
|
||||
// Terminate the procedure, as no processing is necessary
|
||||
return Ok(false);
|
||||
return Ok(propagated_poll);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -312,7 +324,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
|
|||
// `close()` to the TCP socket
|
||||
self.session.borrow_mut().reset_state();
|
||||
|
||||
return Ok(false);
|
||||
return Ok(propagated_poll);
|
||||
}
|
||||
|
||||
// Handle TLS handshake through TLS states
|
||||
|
@ -555,7 +567,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
|
|||
|
||||
// There is no need to care about handshake if it was completed
|
||||
TlsState::CLIENT_CONNECTED => {
|
||||
return Ok(true);
|
||||
return Ok(propagated_poll);
|
||||
}
|
||||
|
||||
// This state waits for Client Hello handshake from a client
|
||||
|
@ -793,7 +805,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
|
|||
// There is no need to care about handshake if it was completed
|
||||
// This is to prevent accidental dequeing of application data
|
||||
TlsState::SERVER_CONNECTED => {
|
||||
return Ok(true);
|
||||
return Ok(propagated_poll);
|
||||
}
|
||||
|
||||
// Other states
|
||||
|
@ -811,14 +823,14 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
|
|||
// Check if there are bytes enqueued in the recv buffer
|
||||
// No need to do further dequeuing if there are no receivable bytes
|
||||
if !tcp_socket.can_recv() {
|
||||
return Ok(self.session.borrow().has_completed_handshake())
|
||||
return Ok(propagated_poll)
|
||||
}
|
||||
|
||||
// Peak into the first 5 bytes (TLS record layer)
|
||||
// This tells the length of the entire record
|
||||
let length = match tcp_socket.peek(5) {
|
||||
Ok(bytes) => NetworkEndian::read_u16(&bytes[3..5]),
|
||||
_ => return Ok(self.session.borrow().has_completed_handshake())
|
||||
_ => return Ok(propagated_poll)
|
||||
};
|
||||
|
||||
// Recv the entire TLS record
|
||||
|
@ -830,7 +842,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
|
|||
// Parse the bytes representation of a TLS record
|
||||
let (repr_slice, mut repr) = match parse_tls_repr(&tls_repr_vec) {
|
||||
Ok((_, (repr_slice, repr))) => (repr_slice, repr),
|
||||
_ => return Ok(self.session.borrow().has_completed_handshake())
|
||||
_ => return Ok(propagated_poll)
|
||||
};
|
||||
|
||||
// Process record base on content type
|
||||
|
@ -888,16 +900,14 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
|
|||
AlertType::UnexpectedMessage,
|
||||
&inner_plaintext[..content_type_index]
|
||||
);
|
||||
return Ok(false);
|
||||
return Ok(propagated_poll);
|
||||
},
|
||||
TlsContentType::Alert => {
|
||||
self.session.borrow_mut().reset_state();
|
||||
return Ok(false);
|
||||
return Ok(propagated_poll);
|
||||
},
|
||||
TlsContentType::ApplicationData => {
|
||||
return Ok(
|
||||
self.session.borrow().has_completed_handshake()
|
||||
);
|
||||
return Ok(propagated_poll);
|
||||
},
|
||||
_ => ()
|
||||
}
|
||||
|
@ -920,7 +930,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
|
|||
handshake: Some(handshake_repr)
|
||||
}
|
||||
).is_err() {
|
||||
return Ok(self.session.borrow().has_completed_handshake())
|
||||
return Ok(propagated_poll)
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -928,7 +938,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
|
|||
TlsContentType::ChangeCipherSpec |
|
||||
TlsContentType::Handshake => {
|
||||
if self.process(repr_slice, repr).is_err() {
|
||||
return Ok(self.session.borrow().has_completed_handshake())
|
||||
return Ok(propagated_poll)
|
||||
}
|
||||
log::info!("Processed record");
|
||||
},
|
||||
|
@ -948,7 +958,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
|
|||
}
|
||||
}
|
||||
|
||||
Ok(self.session.borrow().has_completed_handshake())
|
||||
Ok(propagated_poll)
|
||||
}
|
||||
|
||||
// Process TLS ingress during handshake
|
||||
|
@ -1856,15 +1866,16 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
|
|||
Ok(actual_application_data_length)
|
||||
}
|
||||
|
||||
pub fn send_slice(&mut self, data: &[u8]) -> Result<()> {
|
||||
pub fn send_slice(&mut self, data: &[u8]) -> Result<usize> {
|
||||
// If the handshake is not completed, do not push bytes onto the buffer
|
||||
// through TlsSocket.send_slice()
|
||||
// Handshake send should be through TCPSocket directly.
|
||||
let mut session = self.session.borrow_mut();
|
||||
if session.get_tls_state() != TlsState::CLIENT_CONNECTED &&
|
||||
session.get_tls_state() != TlsState::SERVER_CONNECTED {
|
||||
return Ok(());
|
||||
return Ok(0);
|
||||
}
|
||||
let data_length = data.len();
|
||||
|
||||
// Sending order:
|
||||
// 1. Associated data/ TLS Record layer
|
||||
|
@ -1877,7 +1888,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
|
|||
];
|
||||
|
||||
NetworkEndian::write_u16(&mut associated_data[3..5],
|
||||
u16::try_from(data.len()).unwrap() // Payload length
|
||||
u16::try_from(data_length).unwrap() // Payload length
|
||||
+ 1 // Content type length
|
||||
+ 16 // Auth tag length
|
||||
);
|
||||
|
@ -1900,7 +1911,15 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
|
|||
tcp_socket.send_slice(&vec)?;
|
||||
tcp_socket.send_slice(&tag)?;
|
||||
|
||||
Ok(())
|
||||
Ok(data_length)
|
||||
}
|
||||
|
||||
pub fn is_connected(&self) -> Result<bool> {
|
||||
let session = self.session.borrow();
|
||||
Ok(
|
||||
session.get_tls_state() == TlsState::CLIENT_CONNECTED ||
|
||||
session.get_tls_state() == TlsState::SERVER_CONNECTED
|
||||
)
|
||||
}
|
||||
|
||||
// Send `Close notify` alert to remote side
|
||||
|
@ -1936,7 +1955,7 @@ use core::fmt;
|
|||
impl<'a, 'b, 'c> fmt::Write for TlsSocket<'a, 'b, 'c> {
|
||||
fn write_str(&mut self, slice: &str) -> fmt::Result {
|
||||
let slice = slice.as_bytes();
|
||||
if self.send_slice(slice) == Ok(()) {
|
||||
if self.send_slice(slice) == Ok(slice.len()) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(fmt::Error)
|
||||
|
|
Loading…
Reference in New Issue