tls: reduce redundant array alloc

This commit is contained in:
occheung 2020-11-24 13:10:36 +08:00
parent ee7df70e6f
commit ca3f548727
4 changed files with 128 additions and 163 deletions

View File

@ -46,7 +46,6 @@ use net::iface::EthernetInterface;
use net::time::Instant; use net::time::Instant;
use net::phy::Device; use net::phy::Device;
use crate::tls::TlsSocket;
use crate::set::TlsSocketSet; use crate::set::TlsSocketSet;
// One-call function for polling all sockets within socket set // One-call function for polling all sockets within socket set

View File

@ -149,8 +149,6 @@ impl<'a> Session<'a> {
} }
// State transition from WAIT_SH to WAIT_EE // State transition from WAIT_SH to WAIT_EE
// TODO: Memory allocation
// It current dumps too much memory onto the stack on invocation
pub(crate) fn client_update_for_sh( pub(crate) fn client_update_for_sh(
&mut self, &mut self,
cipher_suite: CipherSuite, cipher_suite: CipherSuite,

View File

@ -2,21 +2,7 @@ use smoltcp as net;
use managed::ManagedSlice; use managed::ManagedSlice;
use crate::tls::TlsSocket; use crate::tls::TlsSocket;
use net::socket::SocketSetItem;
use net::socket::SocketSet; use net::socket::SocketSet;
use net::socket::SocketHandle;
use net::socket::Socket;
use net::socket::TcpSocket;
use net::socket::AnySocket;
use net::socket::SocketRef;
use net::iface::EthernetInterface;
use net::time::Instant;
use net::phy::Device;
use core::convert::From;
use core::cell::RefCell;
use alloc::vec::Vec;
pub struct TlsSocketSet<'a> { pub struct TlsSocketSet<'a> {
tls_sockets: ManagedSlice<'a, Option<TlsSocket<'a>>> tls_sockets: ManagedSlice<'a, Option<TlsSocket<'a>>>

View File

@ -3,13 +3,9 @@ use smoltcp::socket::TcpState;
use smoltcp::socket::SocketHandle; use smoltcp::socket::SocketHandle;
use smoltcp::socket::SocketSet; use smoltcp::socket::SocketSet;
use smoltcp::socket::TcpSocketBuffer; use smoltcp::socket::TcpSocketBuffer;
use smoltcp::socket::SocketRef;
use smoltcp::wire::IpEndpoint; use smoltcp::wire::IpEndpoint;
use smoltcp::Result; use smoltcp::Result;
use smoltcp::Error; use smoltcp::Error;
use smoltcp::iface::EthernetInterface;
use smoltcp::time::Instant;
use smoltcp::phy::Device;
use byteorder::{ByteOrder, NetworkEndian}; use byteorder::{ByteOrder, NetworkEndian};
use generic_array::GenericArray; use generic_array::GenericArray;
@ -18,7 +14,6 @@ use core::convert::TryFrom;
use core::convert::TryInto; use core::convert::TryInto;
use core::cell::RefCell; use core::cell::RefCell;
use rand_core::{RngCore, CryptoRng};
use p256::{EncodedPoint, ecdh::EphemeralSecret}; use p256::{EncodedPoint, ecdh::EphemeralSecret};
use ccm::consts::*; use ccm::consts::*;
@ -133,7 +128,7 @@ impl<'s> TlsSocket<'s> {
// Check TCP socket/ TLS session // Check TCP socket/ TLS session
{ {
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle); let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
let mut tls_socket = self.session.borrow(); let tls_socket = self.session.borrow();
// Check if it should connect to client or not // Check if it should connect to client or not
if tls_socket.get_session_role() != crate::session::TlsRole::Client { if tls_socket.get_session_role() != crate::session::TlsRole::Client {
@ -178,22 +173,34 @@ impl<'s> TlsSocket<'s> {
let repr = TlsRepr::new() let repr = TlsRepr::new()
.client_hello(&ecdh_secret, &x25519_secret, random, session_id.clone()); .client_hello(&ecdh_secret, &x25519_secret, random, session_id.clone());
// Update hash function with client hello handshake {
let mut array = [0; 512]; let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
let mut buffer = TlsBuffer::new(&mut array); tcp_socket.send(
buffer.enqueue_tls_repr(repr)?; |data| {
// Enqueue tls representation without extra allocation
let mut buffer = TlsBuffer::new(data);
if buffer.enqueue_tls_repr(repr).is_err() {
return (0, ())
}
let slice: &[u8] = buffer.into(); let slice: &[u8] = buffer.into();
// Send the packet // Update the session
self.send_tls_slice(sockets, slice)?; // No sequence number calculation in CH
// because there is no encryption
// Update TLS session // Still, data needs to be hashed
self.session.borrow_mut().client_update_for_ch( let mut session = self.session.borrow_mut();
session.client_update_for_ch(
ecdh_secret, ecdh_secret,
x25519_secret, x25519_secret,
session_id, session_id,
&slice[5..] &slice[5..]
); );
// Finally send the data
(slice.len(), ())
}
)?;
}
}, },
// TLS Client wait for Server Hello // TLS Client wait for Server Hello
@ -389,20 +396,48 @@ impl<'s> TlsSocket<'s> {
} }
} }
// Read for TLS packet // Read for TLS packet
// Proposition: Decouple all data from TLS record layer before processing // Proposition: Decouple all data from TLS record layer before processing
// Recouple a brand new TLS record wrapper // Recouple a brand new TLS record wrapper
let mut array: [u8; 2048] = [0; 2048]; // Use recv to avoid buffer allocation
let mut tls_repr_vec = self.recv_tls_repr(sockets, &mut array)?; {
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
tcp_socket.recv(
|buffer| {
let buffer_size = buffer.len();
// Take the TLS representation out of the vector, let mut tls_repr_vec: Vec<(&[u8], TlsRepr)> = Vec::new();
// Process as a queue let mut bytes = &buffer[..buffer_size];
// Sequentially push reprs into vec
loop {
match parse_tls_repr(bytes) {
Ok((rest, (repr_slice, repr))) => {
tls_repr_vec.push(
(repr_slice, repr)
);
if rest.len() == 0 {
break;
} else {
bytes = rest;
}
},
// Dequeue everything and abort processing if it is malformed
_ => return (buffer_size, ())
};
}
// Sequencially process the representations in vector
// Decrypt and split the handshake if necessary
let tls_repr_vec_size = tls_repr_vec.len(); let tls_repr_vec_size = tls_repr_vec.len();
for _index in 0..tls_repr_vec_size { for _index in 0..tls_repr_vec_size {
let (repr_slice, mut repr) = tls_repr_vec.remove(0); let (repr_slice, mut repr) = tls_repr_vec.remove(0);
// Process record base on content type // Process record base on content type
log::info!("Record type: {:?}", repr.content_type); log::info!("Record type: {:?}", repr.content_type);
if repr.content_type == TlsContentType::ApplicationData { if repr.content_type == TlsContentType::ApplicationData {
log::info!("Found application data"); log::info!("Found application data");
// Take the payload out of TLS Record and decrypt // Take the payload out of TLS Record and decrypt
@ -443,7 +478,7 @@ impl<'s> TlsSocket<'s> {
let num_of_handshakes = inner_handshakes.len(); let num_of_handshakes = inner_handshakes.len();
for _ in 0..num_of_handshakes { for _ in 0..num_of_handshakes {
let (handshake_slice, handshake_repr) = inner_handshakes.remove(0); let (handshake_slice, handshake_repr) = inner_handshakes.remove(0);
self.process( if self.process(
handshake_slice, handshake_slice,
TlsRepr { TlsRepr {
content_type: TlsContentType::Handshake, content_type: TlsContentType::Handshake,
@ -452,16 +487,23 @@ impl<'s> TlsSocket<'s> {
payload: None, payload: None,
handshake: Some(handshake_repr) handshake: Some(handshake_repr)
} }
)?; ).is_err() {
return (buffer_size, ())
}
} }
} }
else { else {
self.process(repr_slice, repr)?; if self.process(repr_slice, repr).is_err() {
return (buffer_size, ())
}
log::info!("Processed record"); log::info!("Processed record");
} }
} }
(buffer_size, ())
}
)?;
}
Ok(self.session.borrow().has_completed_handshake()) Ok(self.session.borrow().has_completed_handshake())
} }
@ -826,28 +868,6 @@ impl<'s> TlsSocket<'s> {
Ok(()) Ok(())
} }
// Generic inner send method for buffer IO, through TCP socket
// Usage: Push a slice representation of ONE TLS packet
// This function will only increment sequence number by 1
// Repeatedly call this function if sending multiple TLS packets is needed
fn send_tls_slice(&self, sockets: &mut SocketSet, slice: &[u8]) -> Result<()> {
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
if !tcp_socket.can_send() {
return Err(Error::Illegal);
}
let buffer_size = slice.len();
tcp_socket.send_slice(slice)
.and_then(
|size| if size == buffer_size {
Ok(())
} else {
Err(Error::Truncated)
}
)?;
self.session.borrow_mut().increment_client_sequence_number();
Ok(())
}
// Send method for TLS Handshake that needs to be encrypted. // Send method for TLS Handshake that needs to be encrypted.
// Does the following things: // Does the following things:
// 1. Encryption // 1. Encryption
@ -894,34 +914,6 @@ impl<'s> TlsSocket<'s> {
Ok(()) Ok(())
} }
// Generic inner recv method, through TCP socket
// A TCP packet can contain multiple TLS records (including 0)
// Therefore, sequence nubmer incrementation is not completed here
fn recv_tls_repr<'a>(&'a self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result<Vec<(&[u8], TlsRepr)>> {
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
if !tcp_socket.can_recv() {
return Ok(Vec::new());
}
let array_size = tcp_socket.recv_slice(byte_array)?;
let mut vec: Vec<(&[u8], TlsRepr)> = Vec::new();
let mut bytes: &[u8] = &byte_array[..array_size];
loop {
match parse_tls_repr(bytes) {
Ok((rest, (repr_slice, repr))) => {
vec.push(
(repr_slice, repr)
);
if rest.len() == 0 {
return Ok(vec);
} else {
bytes = rest;
}
},
_ => return Err(Error::Unrecognized),
};
}
}
pub fn recv_slice(&self, sockets: &mut SocketSet, data: &mut [u8]) -> Result<usize> { pub fn recv_slice(&self, sockets: &mut SocketSet, data: &mut [u8]) -> Result<usize> {
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle); let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
if !tcp_socket.can_recv() { if !tcp_socket.can_recv() {
@ -937,12 +929,6 @@ impl<'s> TlsSocket<'s> {
return Ok(0); return Ok(0);
} }
// TODO: Use `recv` to receive instead
// Issue with using recv slice:
// Encrypted application data can cramp together into a TCP Segment
// Dequeuing all bytes from the buffer immediately can cause
// 1. Incorrect decryption, hence throwing error, and
// 2. sequence number to go out of sync forever
let (recv_slice_size, acceptable) = tcp_socket.recv( let (recv_slice_size, acceptable) = tcp_socket.recv(
|buffer| { |buffer| {
// Read the size of the TLS record beforehand // Read the size of the TLS record beforehand
@ -967,8 +953,6 @@ impl<'s> TlsSocket<'s> {
return Ok(0); return Ok(0);
} }
// let recv_slice_size = tcp_socket.recv_slice(data)?;
// Encrypted data need a TLS record wrapper (5 bytes) // Encrypted data need a TLS record wrapper (5 bytes)
// Authentication tag (16 bytes, for all supported AEADs) // Authentication tag (16 bytes, for all supported AEADs)
// Content type byte (1 byte) // Content type byte (1 byte)
@ -980,7 +964,6 @@ impl<'s> TlsSocket<'s> {
// Get Associated Data // Get Associated Data
let mut associated_data: [u8; 5] = [0; 5]; let mut associated_data: [u8; 5] = [0; 5];
associated_data.clone_from_slice(&data[..5]); associated_data.clone_from_slice(&data[..5]);
// log::info!("Received encrypted appdata: {:?}", &data[..recv_slice_size]);
// Dump association data (TLS Record wrapper) // Dump association data (TLS Record wrapper)
// Only decrypt application data // Only decrypt application data
@ -998,7 +981,7 @@ impl<'s> TlsSocket<'s> {
// If it is not application data, handle it internally // If it is not application data, handle it internally
if content_type != TlsContentType::ApplicationData { if content_type != TlsContentType::ApplicationData {
// TODO:: Implement key update // TODO: Implement key update here, as it could be a key update
log::info!("Other decrypted: {:?}", &data[..(recv_slice_size-16)]); log::info!("Other decrypted: {:?}", &data[..(recv_slice_size-16)]);
return Ok(0); return Ok(0);
} }
@ -1021,12 +1004,12 @@ impl<'s> TlsSocket<'s> {
} }
pub fn send_slice(&self, sockets: &mut SocketSet, data: &[u8]) -> Result<()> { pub fn send_slice(&self, sockets: &mut SocketSet, data: &[u8]) -> Result<()> {
// If the handshake is not completed, do not push bytes onto the buffer // If the handshake is not completed, do not push bytes onto the buffer
// through TlsSocket.send_slice() // through TlsSocket.send_slice()
// Handshake send should be through TCPSocket directly. // Handshake send should be through TCPSocket directly.
let mut session = self.session.borrow_mut();
if session.get_tls_state() != TlsState::CONNECTED { if session.get_tls_state() != TlsState::CONNECTED {
return Ok(0); return Ok(());
} }
// Sending order: // Sending order:
@ -1048,7 +1031,6 @@ impl<'s> TlsSocket<'s> {
let mut vec: HeaplessVec<u8, U1024> = HeaplessVec::from_slice(data).unwrap(); let mut vec: HeaplessVec<u8, U1024> = HeaplessVec::from_slice(data).unwrap();
vec.push(0x17).unwrap(); // Content type vec.push(0x17).unwrap(); // Content type
let mut session = self.session.borrow_mut();
let tag = session.encrypt_application_data_in_place_detached( let tag = session.encrypt_application_data_in_place_detached(
&associated_data, &associated_data,
&mut vec &mut vec