diff --git a/src/lib.rs b/src/lib.rs index 8aa512c..2ff7a4d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -46,7 +46,6 @@ use net::iface::EthernetInterface; use net::time::Instant; use net::phy::Device; -use crate::tls::TlsSocket; use crate::set::TlsSocketSet; // One-call function for polling all sockets within socket set diff --git a/src/session.rs b/src/session.rs index 3e999e4..02cb704 100644 --- a/src/session.rs +++ b/src/session.rs @@ -149,8 +149,6 @@ impl<'a> Session<'a> { } // State transition from WAIT_SH to WAIT_EE - // TODO: Memory allocation - // It current dumps too much memory onto the stack on invocation pub(crate) fn client_update_for_sh( &mut self, cipher_suite: CipherSuite, diff --git a/src/set.rs b/src/set.rs index 89a0c80..767754a 100644 --- a/src/set.rs +++ b/src/set.rs @@ -2,21 +2,7 @@ use smoltcp as net; use managed::ManagedSlice; use crate::tls::TlsSocket; -use net::socket::SocketSetItem; use net::socket::SocketSet; -use net::socket::SocketHandle; -use net::socket::Socket; -use net::socket::TcpSocket; -use net::socket::AnySocket; -use net::socket::SocketRef; -use net::iface::EthernetInterface; -use net::time::Instant; -use net::phy::Device; - -use core::convert::From; -use core::cell::RefCell; - -use alloc::vec::Vec; pub struct TlsSocketSet<'a> { tls_sockets: ManagedSlice<'a, Option>> diff --git a/src/tls.rs b/src/tls.rs index a169bbc..4ce8add 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -3,13 +3,9 @@ use smoltcp::socket::TcpState; use smoltcp::socket::SocketHandle; use smoltcp::socket::SocketSet; use smoltcp::socket::TcpSocketBuffer; -use smoltcp::socket::SocketRef; use smoltcp::wire::IpEndpoint; use smoltcp::Result; use smoltcp::Error; -use smoltcp::iface::EthernetInterface; -use smoltcp::time::Instant; -use smoltcp::phy::Device; use byteorder::{ByteOrder, NetworkEndian}; use generic_array::GenericArray; @@ -18,7 +14,6 @@ use core::convert::TryFrom; use core::convert::TryInto; use core::cell::RefCell; -use rand_core::{RngCore, CryptoRng}; use p256::{EncodedPoint, ecdh::EphemeralSecret}; use ccm::consts::*; @@ -133,7 +128,7 @@ impl<'s> TlsSocket<'s> { // Check TCP socket/ TLS session { let mut tcp_socket = sockets.get::(self.tcp_handle); - let mut tls_socket = self.session.borrow(); + let tls_socket = self.session.borrow(); // Check if it should connect to client or not if tls_socket.get_session_role() != crate::session::TlsRole::Client { @@ -178,22 +173,34 @@ impl<'s> TlsSocket<'s> { let repr = TlsRepr::new() .client_hello(&ecdh_secret, &x25519_secret, random, session_id.clone()); - // Update hash function with client hello handshake - let mut array = [0; 512]; - let mut buffer = TlsBuffer::new(&mut array); - buffer.enqueue_tls_repr(repr)?; - let slice: &[u8] = buffer.into(); + { + let mut tcp_socket = sockets.get::(self.tcp_handle); + tcp_socket.send( + |data| { + // Enqueue tls representation without extra allocation + let mut buffer = TlsBuffer::new(data); + if buffer.enqueue_tls_repr(repr).is_err() { + return (0, ()) + } + let slice: &[u8] = buffer.into(); - // Send the packet - self.send_tls_slice(sockets, slice)?; + // Update the session + // No sequence number calculation in CH + // because there is no encryption + // Still, data needs to be hashed + let mut session = self.session.borrow_mut(); + session.client_update_for_ch( + ecdh_secret, + x25519_secret, + session_id, + &slice[5..] + ); - // Update TLS session - self.session.borrow_mut().client_update_for_ch( - ecdh_secret, - x25519_secret, - session_id, - &slice[5..] - ); + // Finally send the data + (slice.len(), ()) + } + )?; + } }, // TLS Client wait for Server Hello @@ -389,78 +396,113 @@ impl<'s> TlsSocket<'s> { } } + + // Read for TLS packet // Proposition: Decouple all data from TLS record layer before processing // Recouple a brand new TLS record wrapper - let mut array: [u8; 2048] = [0; 2048]; - let mut tls_repr_vec = self.recv_tls_repr(sockets, &mut array)?; + // Use recv to avoid buffer allocation + { + let mut tcp_socket = sockets.get::(self.tcp_handle); + tcp_socket.recv( + |buffer| { + let buffer_size = buffer.len(); + + let mut tls_repr_vec: Vec<(&[u8], TlsRepr)> = Vec::new(); + let mut bytes = &buffer[..buffer_size]; - // Take the TLS representation out of the vector, - // Process as a queue - let tls_repr_vec_size = tls_repr_vec.len(); - for _index in 0..tls_repr_vec_size { - let (repr_slice, mut repr) = tls_repr_vec.remove(0); + // Sequentially push reprs into vec + loop { + match parse_tls_repr(bytes) { + Ok((rest, (repr_slice, repr))) => { + tls_repr_vec.push( + (repr_slice, repr) + ); + if rest.len() == 0 { + break; + } else { + bytes = rest; + } + }, + // Dequeue everything and abort processing if it is malformed + _ => return (buffer_size, ()) + }; + } - // Process record base on content type - log::info!("Record type: {:?}", repr.content_type); - if repr.content_type == TlsContentType::ApplicationData { - log::info!("Found application data"); - // Take the payload out of TLS Record and decrypt - let mut app_data = repr.payload.take().unwrap(); - let mut associated_data = [0; 5]; - associated_data[0] = repr.content_type.into(); - NetworkEndian::write_u16( - &mut associated_data[1..3], - repr.version.into() - ); - NetworkEndian::write_u16( - &mut associated_data[3..5], - repr.length - ); - { - let mut session = self.session.borrow_mut(); - session.decrypt_in_place_detached( - &associated_data, - &mut app_data - ).unwrap(); - session.increment_server_sequence_number(); - } + // Sequencially process the representations in vector + // Decrypt and split the handshake if necessary + let tls_repr_vec_size = tls_repr_vec.len(); + for _index in 0..tls_repr_vec_size { + let (repr_slice, mut repr) = tls_repr_vec.remove(0); - // Discard last 16 bytes (auth tag) - let inner_plaintext = &app_data[..app_data.len()-16]; - let (inner_content_type, _) = get_content_type_inner_plaintext( - inner_plaintext - ); - if inner_content_type != TlsContentType::Handshake { - // Silently ignore non-handshakes - continue; - } - let (_, mut inner_handshakes) = complete( - parse_inner_plaintext_for_handshake - )(inner_plaintext).unwrap(); + // Process record base on content type + log::info!("Record type: {:?}", repr.content_type); - // Sequentially process all handshakes - let num_of_handshakes = inner_handshakes.len(); - for _ in 0..num_of_handshakes { - let (handshake_slice, handshake_repr) = inner_handshakes.remove(0); - self.process( - handshake_slice, - TlsRepr { - content_type: TlsContentType::Handshake, - version: repr.version, - length: u16::try_from(handshake_repr.length).unwrap() + 4, - payload: None, - handshake: Some(handshake_repr) + if repr.content_type == TlsContentType::ApplicationData { + log::info!("Found application data"); + // Take the payload out of TLS Record and decrypt + let mut app_data = repr.payload.take().unwrap(); + let mut associated_data = [0; 5]; + associated_data[0] = repr.content_type.into(); + NetworkEndian::write_u16( + &mut associated_data[1..3], + repr.version.into() + ); + NetworkEndian::write_u16( + &mut associated_data[3..5], + repr.length + ); + { + let mut session = self.session.borrow_mut(); + session.decrypt_in_place_detached( + &associated_data, + &mut app_data + ).unwrap(); + session.increment_server_sequence_number(); + } + + // Discard last 16 bytes (auth tag) + let inner_plaintext = &app_data[..app_data.len()-16]; + let (inner_content_type, _) = get_content_type_inner_plaintext( + inner_plaintext + ); + if inner_content_type != TlsContentType::Handshake { + // Silently ignore non-handshakes + continue; + } + let (_, mut inner_handshakes) = complete( + parse_inner_plaintext_for_handshake + )(inner_plaintext).unwrap(); + + // Sequentially process all handshakes + let num_of_handshakes = inner_handshakes.len(); + for _ in 0..num_of_handshakes { + let (handshake_slice, handshake_repr) = inner_handshakes.remove(0); + if self.process( + handshake_slice, + TlsRepr { + content_type: TlsContentType::Handshake, + version: repr.version, + length: u16::try_from(handshake_repr.length).unwrap() + 4, + payload: None, + handshake: Some(handshake_repr) + } + ).is_err() { + return (buffer_size, ()) + } + } } - )?; + + else { + if self.process(repr_slice, repr).is_err() { + return (buffer_size, ()) + } + log::info!("Processed record"); + } + } + (buffer_size, ()) } - } - - - else { - self.process(repr_slice, repr)?; - log::info!("Processed record"); - } + )?; } Ok(self.session.borrow().has_completed_handshake()) @@ -826,28 +868,6 @@ impl<'s> TlsSocket<'s> { Ok(()) } - // Generic inner send method for buffer IO, through TCP socket - // Usage: Push a slice representation of ONE TLS packet - // This function will only increment sequence number by 1 - // Repeatedly call this function if sending multiple TLS packets is needed - fn send_tls_slice(&self, sockets: &mut SocketSet, slice: &[u8]) -> Result<()> { - let mut tcp_socket = sockets.get::(self.tcp_handle); - if !tcp_socket.can_send() { - return Err(Error::Illegal); - } - let buffer_size = slice.len(); - tcp_socket.send_slice(slice) - .and_then( - |size| if size == buffer_size { - Ok(()) - } else { - Err(Error::Truncated) - } - )?; - self.session.borrow_mut().increment_client_sequence_number(); - Ok(()) - } - // Send method for TLS Handshake that needs to be encrypted. // Does the following things: // 1. Encryption @@ -894,34 +914,6 @@ impl<'s> TlsSocket<'s> { Ok(()) } - // Generic inner recv method, through TCP socket - // A TCP packet can contain multiple TLS records (including 0) - // Therefore, sequence nubmer incrementation is not completed here - fn recv_tls_repr<'a>(&'a self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result> { - let mut tcp_socket = sockets.get::(self.tcp_handle); - if !tcp_socket.can_recv() { - return Ok(Vec::new()); - } - let array_size = tcp_socket.recv_slice(byte_array)?; - let mut vec: Vec<(&[u8], TlsRepr)> = Vec::new(); - let mut bytes: &[u8] = &byte_array[..array_size]; - loop { - match parse_tls_repr(bytes) { - Ok((rest, (repr_slice, repr))) => { - vec.push( - (repr_slice, repr) - ); - if rest.len() == 0 { - return Ok(vec); - } else { - bytes = rest; - } - }, - _ => return Err(Error::Unrecognized), - }; - } - } - pub fn recv_slice(&self, sockets: &mut SocketSet, data: &mut [u8]) -> Result { let mut tcp_socket = sockets.get::(self.tcp_handle); if !tcp_socket.can_recv() { @@ -937,12 +929,6 @@ impl<'s> TlsSocket<'s> { return Ok(0); } - // TODO: Use `recv` to receive instead - // Issue with using recv slice: - // Encrypted application data can cramp together into a TCP Segment - // Dequeuing all bytes from the buffer immediately can cause - // 1. Incorrect decryption, hence throwing error, and - // 2. sequence number to go out of sync forever let (recv_slice_size, acceptable) = tcp_socket.recv( |buffer| { // Read the size of the TLS record beforehand @@ -967,8 +953,6 @@ impl<'s> TlsSocket<'s> { return Ok(0); } - // let recv_slice_size = tcp_socket.recv_slice(data)?; - // Encrypted data need a TLS record wrapper (5 bytes) // Authentication tag (16 bytes, for all supported AEADs) // Content type byte (1 byte) @@ -980,7 +964,6 @@ impl<'s> TlsSocket<'s> { // Get Associated Data let mut associated_data: [u8; 5] = [0; 5]; associated_data.clone_from_slice(&data[..5]); - // log::info!("Received encrypted appdata: {:?}", &data[..recv_slice_size]); // Dump association data (TLS Record wrapper) // Only decrypt application data @@ -998,7 +981,7 @@ impl<'s> TlsSocket<'s> { // If it is not application data, handle it internally if content_type != TlsContentType::ApplicationData { - // TODO:: Implement key update + // TODO: Implement key update here, as it could be a key update log::info!("Other decrypted: {:?}", &data[..(recv_slice_size-16)]); return Ok(0); } @@ -1021,12 +1004,12 @@ impl<'s> TlsSocket<'s> { } pub fn send_slice(&self, sockets: &mut SocketSet, data: &[u8]) -> Result<()> { - // If the handshake is not completed, do not push bytes onto the buffer // through TlsSocket.send_slice() // Handshake send should be through TCPSocket directly. + let mut session = self.session.borrow_mut(); if session.get_tls_state() != TlsState::CONNECTED { - return Ok(0); + return Ok(()); } // Sending order: @@ -1048,7 +1031,6 @@ impl<'s> TlsSocket<'s> { let mut vec: HeaplessVec = HeaplessVec::from_slice(data).unwrap(); vec.push(0x17).unwrap(); // Content type - let mut session = self.session.borrow_mut(); let tag = session.encrypt_application_data_in_place_detached( &associated_data, &mut vec