socket: owns tcp socket

master
occheung 2020-12-02 11:20:31 +08:00
parent 4f8b273e86
commit 25cdd23406
4 changed files with 281 additions and 180 deletions

View File

@ -58,6 +58,6 @@ pub fn poll<DeviceT>(
where
DeviceT: for<'d> Device<'d>
{
tls_sockets.polled_by(sockets)?;
tls_sockets.polled_by(sockets, iface, now)?;
iface.poll(sockets, now).map_err(Error::PropagatedError)
}

View File

@ -66,10 +66,9 @@ fn main() {
let mut tls_socket = unsafe {
let tx_buffer = TcpSocketBuffer::new(&mut tx_storage[..]);
let rx_buffer = TcpSocketBuffer::new(&mut rx_storage[..]);
let tcp_socket = smoltcp::socket::TcpSocket::new(rx_buffer, tx_buffer);
TlsSocket::new(
&mut sockets,
rx_buffer,
tx_buffer,
tcp_socket,
&mut RNG,
None
)

View File

@ -3,25 +3,28 @@ use smoltcp as net;
use managed::ManagedSlice;
use crate::tls::TlsSocket;
use net::socket::SocketSet;
use net::phy::Device;
use net::iface::EthernetInterface;
use net::time::Instant;
pub struct TlsSocketSet<'a> {
tls_sockets: ManagedSlice<'a, Option<TlsSocket<'a>>>
pub struct TlsSocketSet<'a, 'b, 'c> {
tls_sockets: ManagedSlice<'a, Option<TlsSocket<'a, 'b, 'c>>>
}
#[derive(Clone, Copy, Debug)]
pub struct TlsSocketHandle(usize);
impl<'a> TlsSocketSet<'a> {
impl<'a, 'b, 'c> TlsSocketSet<'a, 'b, 'c> {
pub fn new<T>(tls_sockets: T) -> Self
where
T: Into<ManagedSlice<'a, Option<TlsSocket<'a>>>>
T: Into<ManagedSlice<'a, Option<TlsSocket<'a, 'b, 'c>>>>
{
Self {
tls_sockets: tls_sockets.into()
}
}
pub fn add(&mut self, socket: TlsSocket<'a>) -> TlsSocketHandle
pub fn add(&mut self, socket: TlsSocket<'a, 'b, 'c>) -> TlsSocketHandle
{
for (index, slot) in self.tls_sockets.iter_mut().enumerate() {
if slot.is_none() {
@ -43,20 +46,24 @@ impl<'a> TlsSocketSet<'a> {
}
}
pub fn get(&mut self, handle: TlsSocketHandle) -> &mut TlsSocket<'a> {
pub fn get(&mut self, handle: TlsSocketHandle) -> &mut TlsSocket<'a, 'b, 'c> {
self.tls_sockets[handle.0].as_mut().unwrap()
}
pub(crate) fn polled_by(
pub(crate) fn polled_by<DeviceT>(
&mut self,
sockets: &mut SocketSet
sockets: &mut SocketSet,
iface: &mut EthernetInterface<DeviceT>,
now: Instant
) -> smoltcp::Result<bool>
where
DeviceT: for<'d> Device<'d>
{
for socket in self.tls_sockets.iter_mut() {
if socket.is_some() {
socket.as_mut()
.unwrap()
.update_handshake(sockets)?;
.update_handshake(iface, now)?;
}
}

View File

@ -6,6 +6,9 @@ use smoltcp::socket::TcpSocketBuffer;
use smoltcp::wire::IpEndpoint;
use smoltcp::Result;
use smoltcp::Error;
use smoltcp::phy::Device;
use smoltcp::iface::EthernetInterface;
use smoltcp::time::Instant;
use byteorder::{ByteOrder, NetworkEndian};
use generic_array::GenericArray;
@ -57,30 +60,30 @@ pub(crate) enum TlsState {
SERVER_CONNECTED,
}
pub struct TlsSocket<'b>
pub struct TlsSocket<'a, 'b, 'c>
{
// Locally owned SocketSet, solely containing 1 TCP socket
sockets: SocketSet<'a, 'b, 'c>,
tcp_handle: SocketHandle,
rng: &'b mut dyn crate::TlsRng,
session: RefCell<Session<'b>>,
}
impl<'b> TlsSocket<'b> {
pub fn new<'a, 'c>(
sockets: &mut SocketSet<'a, 'b, 'c>,
rx_buffer: TcpSocketBuffer<'b>,
tx_buffer: TcpSocketBuffer<'b>,
impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
pub fn new(
tcp_socket: TcpSocket<'b>,
rng: &'b mut dyn crate::TlsRng,
certificate_with_key: Option<(
crate::session::CertificatePrivateKey,
Vec<&'b [u8]>
)>
) -> Self
where
'b: 'c,
{
let tcp_socket = TcpSocket::new(rx_buffer, tx_buffer);
let socket_set_entries: [_; 1] = Default::default();
let mut sockets = SocketSet::new(socket_set_entries);
let tcp_handle = sockets.add(tcp_socket);
TlsSocket {
sockets,
tcp_handle,
rng,
session: RefCell::new(
@ -89,26 +92,8 @@ impl<'b> TlsSocket<'b> {
}
}
// pub fn from_tcp_handle(
// tcp_handle: SocketHandle,
// rng: &'s mut dyn crate::TlsRng,
// certificate_with_key: Option<(
// crate::session::CertificatePrivateKey,
// Vec<&'s [u8]>
// )>
// ) -> Self {
// TlsSocket {
// tcp_handle,
// rng,
// session: RefCell::new(
// Session::new(TlsRole::Client, certificate_with_key)
// ),
// }
// }
pub fn connect<T, U>(
&mut self,
sockets: &mut SocketSet,
remote_endpoint: T,
local_endpoint: U,
) -> Result<()>
@ -117,7 +102,7 @@ impl<'b> TlsSocket<'b> {
U: Into<IpEndpoint>,
{
// Start TCP handshake
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
let mut tcp_socket = self.sockets.get::<TcpSocket>(self.tcp_handle);
tcp_socket.connect(remote_endpoint, local_endpoint)?;
// Permit TLS handshake as well
@ -131,14 +116,13 @@ impl<'b> TlsSocket<'b> {
pub fn listen<T>(
&mut self,
sockets: &mut SocketSet,
local_endpoint: T
) -> Result<()>
where
T: Into<IpEndpoint>
{
// Listen from TCP socket
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
let mut tcp_socket = self.sockets.get::<TcpSocket>(self.tcp_handle);
tcp_socket.listen(local_endpoint)?;
// Update tls session to server_start
@ -148,7 +132,17 @@ impl<'b> TlsSocket<'b> {
Ok(())
}
pub fn update_handshake(&mut self, sockets: &mut SocketSet) -> Result<bool> {
pub fn update_handshake<DeviceT>(
&mut self,
iface: &mut EthernetInterface<DeviceT>,
now: Instant
) -> Result<bool>
where
DeviceT: for<'d> Device<'d>
{
// Poll the TCP socket, no matter what
iface.poll(&mut self.sockets, now)?;
// Handle TLS handshake through TLS states
let tls_state = {
self.session.borrow().get_tls_state()
@ -156,8 +150,7 @@ impl<'b> TlsSocket<'b> {
// Check TCP socket/ TLS session
{
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
let tls_socket = self.session.borrow();
let tcp_state = self.sockets.get::<TcpSocket>(self.tcp_handle).state();
// // Check if it should connect to client or not
// if tls_socket.get_session_role() != crate::session::TlsRole::Client {
@ -168,21 +161,20 @@ impl<'b> TlsSocket<'b> {
// Skip handshake processing if it is already completed
// However, redo TCP handshake if TLS socket is trying to connect and
// TCP socket is not connected
if tcp_socket.state() != TcpState::Established {
if tcp_state != TcpState::Established {
if tls_state == TlsState::CLIENT_START {
// Restart TCP handshake is it is closed for some reason
let mut tcp_socket = self.sockets.get::<TcpSocket>(self.tcp_handle);
let session = self.session.borrow();
if !tcp_socket.is_open() {
tcp_socket.connect(
tls_socket.get_remote_endpoint(),
tls_socket.get_local_endpoint()
session.get_remote_endpoint(),
session.get_local_endpoint()
)?;
}
return Ok(false);
} else {
// Do nothing, either handshake failed or the socket closed
// after finishing the handshake
return Ok(false);
}
// Terminate the procedure, as no processing is necessary
return Ok(false);
}
}
// Handle TLS handshake through TLS states
@ -204,7 +196,8 @@ impl<'b> TlsSocket<'b> {
.client_hello(&ecdh_secret, &x25519_secret, random, session_id.clone());
{
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
let mut tcp_socket = self.sockets.get::<TcpSocket>(self.tcp_handle);
let mut session = self.session.borrow_mut();
tcp_socket.send(
|data| {
// Enqueue tls representation without extra allocation
@ -218,7 +211,6 @@ impl<'b> TlsSocket<'b> {
// No sequence number calculation in CH
// because there is no encryption
// Still, data needs to be hashed
let mut session = self.session.borrow_mut();
session.client_update_for_ch(
ecdh_secret,
x25519_secret,
@ -331,7 +323,7 @@ impl<'b> TlsSocket<'b> {
(certificates_total_length, buffer_vec)
};
self.send_application_slice(sockets, &mut buffer_vec.clone())?;
self.send_application_slice(&mut buffer_vec.clone())?;
// Update session
let buffer_vec_length = buffer_vec.len();
@ -388,7 +380,7 @@ impl<'b> TlsSocket<'b> {
// Push content byte (handshake: 22)
verify_buffer_vec.push(22);
self.send_application_slice(sockets, &mut verify_buffer_vec.clone())?;
self.send_application_slice(&mut verify_buffer_vec.clone())?;
// Update session
let cert_verify_len = verify_buffer_vec.len();
@ -416,7 +408,7 @@ impl<'b> TlsSocket<'b> {
buffer.push(22).unwrap();
buffer
};
self.send_application_slice(sockets, &mut inner_plaintext.clone())?;
self.send_application_slice(&mut inner_plaintext.clone())?;
let inner_plaintext_length = inner_plaintext.len();
self.session.borrow_mut()
@ -476,7 +468,7 @@ impl<'b> TlsSocket<'b> {
ecdhe_public_key
);
{
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
let mut tcp_socket = self.sockets.get::<TcpSocket>(self.tcp_handle);
let mut session = self.session.borrow_mut();
tcp_socket.send(
|data| {
@ -499,6 +491,8 @@ impl<'b> TlsSocket<'b> {
)?;
}
log::info!("sent server hello");
// Construct and send minimalistic EE
let inner_plaintext: [u8; 7] = [
0x08, // EE type
@ -506,7 +500,7 @@ impl<'b> TlsSocket<'b> {
0x00, 0x00, // Length of extensions: 0
22 // Content type of InnerPlainText
];
self.send_application_slice(sockets, &mut inner_plaintext.clone())?;
self.send_application_slice(&mut inner_plaintext.clone())?;
let inner_plaintext_length = inner_plaintext.len();
{
@ -516,6 +510,8 @@ impl<'b> TlsSocket<'b> {
);
}
log::info!("sent encrypted extension");
// TODO: Option to allow a certificate request
// Construct and send server certificate handshake content
@ -579,13 +575,14 @@ impl<'b> TlsSocket<'b> {
inner_plaintext
};
self.send_application_slice(sockets, &mut inner_plaintext.clone())?;
self.send_application_slice(&mut inner_plaintext.clone())?;
let inner_plaintext_length = inner_plaintext.len();
// Update session
{
self.session.borrow_mut()
.server_update_for_sent_certificate(&inner_plaintext[..(inner_plaintext_length-1)]);
}
log::info!("sent certificate");
// Construct and send certificate verify
let mut inner_plaintext = {
@ -619,10 +616,7 @@ impl<'b> TlsSocket<'b> {
inner_plaintext
};
self.send_application_slice(
sockets,
&mut inner_plaintext.clone()
)?;
self.send_application_slice(&mut inner_plaintext.clone())?;
let inner_plaintext_length = inner_plaintext.len();
{
@ -631,6 +625,7 @@ impl<'b> TlsSocket<'b> {
&inner_plaintext[..(inner_plaintext_length-1)]
);
}
log::info!("sent certificate verify");
// Construct and send server finished
let inner_plaintext: HeaplessVec<u8, U64> = {
@ -647,13 +642,14 @@ impl<'b> TlsSocket<'b> {
buffer.push(22).unwrap();
buffer
};
self.send_application_slice(sockets, &mut inner_plaintext.clone())?;
self.send_application_slice(&mut inner_plaintext.clone())?;
let inner_plaintext_length = inner_plaintext.len();
{
self.session.borrow_mut()
.server_update_for_server_finished(&inner_plaintext[..(inner_plaintext_length-1)]);
}
log::info!("sent client finished");
}
// There is no need to care about handshake if it was completed
@ -669,116 +665,215 @@ impl<'b> TlsSocket<'b> {
// Read for TLS packet
// Proposition: Decouple all data from TLS record layer before processing
// Recouple a brand new TLS record wrapper
// Use recv to avoid buffer allocation
// Use peek & recv to avoid buffer allocation
{
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
tcp_socket.recv(
|buffer| {
// log::info!("Received Buffer: {:?}", buffer);
let buffer_size = buffer.len();
let tls_repr_vec = {
let mut tcp_socket = self.sockets.get::<TcpSocket>(self.tcp_handle);
// Provide a way to end the process early
if buffer_size == 0 {
return (0, ())
}
log::info!("Received something");
let mut tls_repr_vec: Vec<(&[u8], TlsRepr)> = Vec::new();
let mut bytes = &buffer[..buffer_size];
// Sequentially push reprs into vec
loop {
match parse_tls_repr(bytes) {
Ok((rest, (repr_slice, repr))) => {
tls_repr_vec.push(
(repr_slice, repr)
);
if rest.len() == 0 {
break;
} else {
bytes = rest;
}
},
// Dequeue everything and abort processing if it is malformed
_ => return (buffer_size, ())
};
}
// Sequencially process the representations in vector
// Decrypt and split the handshake if necessary
let tls_repr_vec_size = tls_repr_vec.len();
for _index in 0..tls_repr_vec_size {
let (repr_slice, mut repr) = tls_repr_vec.remove(0);
// Process record base on content type
log::info!("Record type: {:?}", repr.content_type);
if repr.content_type == TlsContentType::ApplicationData {
log::info!("Found application data");
// Take the payload out of TLS Record and decrypt
let mut app_data = repr.payload.take().unwrap();
let mut associated_data = [0; 5];
associated_data[0] = repr.content_type.into();
NetworkEndian::write_u16(
&mut associated_data[1..3],
repr.version.into()
);
NetworkEndian::write_u16(
&mut associated_data[3..5],
repr.length
);
{
let mut session = self.session.borrow_mut();
session.decrypt_in_place_detached(
&associated_data,
&mut app_data
).unwrap();
session.increment_remote_sequence_number();
}
// Discard last 16 bytes (auth tag)
let inner_plaintext = &app_data[..app_data.len()-16];
let (inner_content_type, _) = get_content_type_inner_plaintext(
inner_plaintext
);
if inner_content_type != TlsContentType::Handshake {
// Silently ignore non-handshakes
continue;
}
let (_, mut inner_handshakes) = complete(
parse_inner_plaintext_for_handshake
)(inner_plaintext).unwrap();
// Sequentially process all handshakes
let num_of_handshakes = inner_handshakes.len();
for _ in 0..num_of_handshakes {
let (handshake_slice, handshake_repr) = inner_handshakes.remove(0);
if self.process(
handshake_slice,
TlsRepr {
content_type: TlsContentType::Handshake,
version: repr.version,
length: u16::try_from(handshake_repr.length).unwrap() + 4,
payload: None,
handshake: Some(handshake_repr)
}
).is_err() {
return (buffer_size, ())
}
}
}
else {
if self.process(repr_slice, repr).is_err() {
return (buffer_size, ())
}
log::info!("Processed record");
}
}
(buffer_size, ())
// Check if there are bytes enqueued in the recv buffer
// No need to do further dequeuing if there are no receivable bytes
if !tcp_socket.can_recv() {
return Ok(self.session.borrow().has_completed_handshake())
}
)?;
log::info!("Tls can recv");
// Peak into the first 5 bytes (TLS record layer)
// This tells the length of the entire record
let length = match tcp_socket.peek(5) {
Ok(bytes) => NetworkEndian::read_u16(&bytes[3..5]),
_ => return Ok(self.session.borrow().has_completed_handshake())
};
log::info!("tls record length: {:?}", length);
// Recv the entire TLS record
tcp_socket.recv(
|buffer| ((length + 5).into(), Vec::from(&buffer[..(length + 5).into()]))
).unwrap()
};
// Parse the bytes representation of a TLS record
let (repr_slice, mut repr) = match parse_tls_repr(&tls_repr_vec) {
Ok((_, (repr_slice, repr))) => (repr_slice, repr),
_ => return Ok(self.session.borrow().has_completed_handshake())
};
// Process record base on content type
log::info!("Record type: {:?}", repr.content_type);
if repr.content_type == TlsContentType::ApplicationData {
log::info!("Found application data");
// Take the payload out of TLS Record and decrypt
let mut app_data = repr.payload.take().unwrap();
let mut associated_data = [0; 5];
associated_data[0] = repr.content_type.into();
NetworkEndian::write_u16(
&mut associated_data[1..3],
repr.version.into()
);
NetworkEndian::write_u16(
&mut associated_data[3..5],
repr.length
);
{
let mut session = self.session.borrow_mut();
session.decrypt_in_place_detached(
&associated_data,
&mut app_data
).unwrap();
session.increment_remote_sequence_number();
}
// Discard last 16 bytes (auth tag)
let inner_plaintext = &app_data[..app_data.len()-16];
let (inner_content_type, _) = get_content_type_inner_plaintext(
inner_plaintext
);
if inner_content_type != TlsContentType::Handshake {
// Silently ignore non-handshakes
return Ok(self.session.borrow().has_completed_handshake())
}
let (_, mut inner_handshakes) = complete(
parse_inner_plaintext_for_handshake
)(inner_plaintext).unwrap();
// Sequentially process all handshakes
let num_of_handshakes = inner_handshakes.len();
for _ in 0..num_of_handshakes {
let (handshake_slice, handshake_repr) = inner_handshakes.remove(0);
if self.process(
handshake_slice,
TlsRepr {
content_type: TlsContentType::Handshake,
version: repr.version,
length: u16::try_from(handshake_repr.length).unwrap() + 4,
payload: None,
handshake: Some(handshake_repr)
}
).is_err() {
return Ok(self.session.borrow().has_completed_handshake())
}
}
} else {
if self.process(repr_slice, repr).is_err() {
return Ok(self.session.borrow().has_completed_handshake())
}
log::info!("Processed record");
}
// // Finally dequeue the record from buffer
// if tcp_socket.recv(|_| (length.into(), ())).is_err() {
// return Ok(self.session.borrow().has_completed_handshake())
// }
// tcp_socket.recv(
// |buffer| {
// // log::info!("Received Buffer: {:?}", buffer);
// let buffer_size = buffer.len();
// // Provide a way to end the process early
// if buffer_size == 0 {
// return (0, ())
// }
// log::info!("Received something");
// let mut tls_repr_vec: Vec<(&[u8], TlsRepr)> = Vec::new();
// let mut bytes = &buffer[..buffer_size];
// // Sequentially push reprs into vec
// loop {
// match parse_tls_repr(bytes) {
// Ok((rest, (repr_slice, repr))) => {
// tls_repr_vec.push(
// (repr_slice, repr)
// );
// if rest.len() == 0 {
// break;
// } else {
// bytes = rest;
// }
// },
// // Dequeue everything and abort processing if it is malformed
// _ => return (buffer_size, ())
// };
// }
// // Sequencially process the representations in vector
// // Decrypt and split the handshake if necessary
// let tls_repr_vec_size = tls_repr_vec.len();
// for _index in 0..tls_repr_vec_size {
// let (repr_slice, mut repr) = tls_repr_vec.remove(0);
// // Process record base on content type
// log::info!("Record type: {:?}", repr.content_type);
// if repr.content_type == TlsContentType::ApplicationData {
// log::info!("Found application data");
// // Take the payload out of TLS Record and decrypt
// let mut app_data = repr.payload.take().unwrap();
// let mut associated_data = [0; 5];
// associated_data[0] = repr.content_type.into();
// NetworkEndian::write_u16(
// &mut associated_data[1..3],
// repr.version.into()
// );
// NetworkEndian::write_u16(
// &mut associated_data[3..5],
// repr.length
// );
// {
// let mut session = self.session.borrow_mut();
// session.decrypt_in_place_detached(
// &associated_data,
// &mut app_data
// ).unwrap();
// session.increment_remote_sequence_number();
// }
// // Discard last 16 bytes (auth tag)
// let inner_plaintext = &app_data[..app_data.len()-16];
// let (inner_content_type, _) = get_content_type_inner_plaintext(
// inner_plaintext
// );
// if inner_content_type != TlsContentType::Handshake {
// // Silently ignore non-handshakes
// continue;
// }
// let (_, mut inner_handshakes) = complete(
// parse_inner_plaintext_for_handshake
// )(inner_plaintext).unwrap();
// // Sequentially process all handshakes
// let num_of_handshakes = inner_handshakes.len();
// for _ in 0..num_of_handshakes {
// let (handshake_slice, handshake_repr) = inner_handshakes.remove(0);
// if self.process(
// handshake_slice,
// TlsRepr {
// content_type: TlsContentType::Handshake,
// version: repr.version,
// length: u16::try_from(handshake_repr.length).unwrap() + 4,
// payload: None,
// handshake: Some(handshake_repr)
// }
// ).is_err() {
// return (buffer_size, ())
// }
// }
// }
// else {
// if self.process(repr_slice, repr).is_err() {
// return (buffer_size, ())
// }
// log::info!("Processed record");
// }
// }
// (buffer_size, ())
// }
// )?;
}
Ok(self.session.borrow().has_completed_handshake())
@ -1433,8 +1528,8 @@ impl<'b> TlsSocket<'b> {
// Input should be inner plaintext
// Note: Do not put this slice into the transcript hash. It is polluted.
// TODO: Rename this function. It is only good for client finished
fn send_application_slice(&self, sockets: &mut SocketSet, slice: &mut [u8]) -> Result<()> {
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
fn send_application_slice(&mut self, slice: &mut [u8]) -> Result<()> {
let mut tcp_socket = self.sockets.get::<TcpSocket>(self.tcp_handle);
if !tcp_socket.can_send() {
return Err(Error::Illegal);
}
@ -1471,8 +1566,8 @@ impl<'b> TlsSocket<'b> {
Ok(())
}
pub fn recv_slice(&self, sockets: &mut SocketSet, data: &mut [u8]) -> Result<usize> {
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<usize> {
let mut tcp_socket = self.sockets.get::<TcpSocket>(self.tcp_handle);
if !tcp_socket.can_recv() {
return Ok(0);
}
@ -1561,7 +1656,7 @@ impl<'b> TlsSocket<'b> {
Ok(actual_application_data_length)
}
pub fn send_slice(&self, sockets: &mut SocketSet, data: &[u8]) -> Result<()> {
pub fn send_slice(&mut self, data: &[u8]) -> Result<()> {
// If the handshake is not completed, do not push bytes onto the buffer
// through TlsSocket.send_slice()
// Handshake send should be through TCPSocket directly.
@ -1596,7 +1691,7 @@ impl<'b> TlsSocket<'b> {
).unwrap();
session.increment_local_sequence_number();
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
let mut tcp_socket = self.sockets.get::<TcpSocket>(self.tcp_handle);
if !tcp_socket.can_send() {
return Err(Error::Illegal);
}