diff --git a/src/lib.rs b/src/lib.rs index 2ff7a4d..8bece10 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,6 +58,6 @@ pub fn poll( where DeviceT: for<'d> Device<'d> { - tls_sockets.polled_by(sockets)?; + tls_sockets.polled_by(sockets, iface, now)?; iface.poll(sockets, now).map_err(Error::PropagatedError) } diff --git a/src/main.rs b/src/main.rs index 51edb47..ed38fec 100644 --- a/src/main.rs +++ b/src/main.rs @@ -66,10 +66,9 @@ fn main() { let mut tls_socket = unsafe { let tx_buffer = TcpSocketBuffer::new(&mut tx_storage[..]); let rx_buffer = TcpSocketBuffer::new(&mut rx_storage[..]); + let tcp_socket = smoltcp::socket::TcpSocket::new(rx_buffer, tx_buffer); TlsSocket::new( - &mut sockets, - rx_buffer, - tx_buffer, + tcp_socket, &mut RNG, None ) diff --git a/src/set.rs b/src/set.rs index 767754a..f572d22 100644 --- a/src/set.rs +++ b/src/set.rs @@ -3,25 +3,28 @@ use smoltcp as net; use managed::ManagedSlice; use crate::tls::TlsSocket; use net::socket::SocketSet; +use net::phy::Device; +use net::iface::EthernetInterface; +use net::time::Instant; -pub struct TlsSocketSet<'a> { - tls_sockets: ManagedSlice<'a, Option>> +pub struct TlsSocketSet<'a, 'b, 'c> { + tls_sockets: ManagedSlice<'a, Option>> } #[derive(Clone, Copy, Debug)] pub struct TlsSocketHandle(usize); -impl<'a> TlsSocketSet<'a> { +impl<'a, 'b, 'c> TlsSocketSet<'a, 'b, 'c> { pub fn new(tls_sockets: T) -> Self where - T: Into>>> + T: Into>>> { Self { tls_sockets: tls_sockets.into() } } - pub fn add(&mut self, socket: TlsSocket<'a>) -> TlsSocketHandle + pub fn add(&mut self, socket: TlsSocket<'a, 'b, 'c>) -> TlsSocketHandle { for (index, slot) in self.tls_sockets.iter_mut().enumerate() { if slot.is_none() { @@ -43,20 +46,24 @@ impl<'a> TlsSocketSet<'a> { } } - pub fn get(&mut self, handle: TlsSocketHandle) -> &mut TlsSocket<'a> { + pub fn get(&mut self, handle: TlsSocketHandle) -> &mut TlsSocket<'a, 'b, 'c> { self.tls_sockets[handle.0].as_mut().unwrap() } - pub(crate) fn polled_by( + pub(crate) fn polled_by( &mut self, - sockets: &mut SocketSet + sockets: &mut SocketSet, + iface: &mut EthernetInterface, + now: Instant ) -> smoltcp::Result + where + DeviceT: for<'d> Device<'d> { for socket in self.tls_sockets.iter_mut() { if socket.is_some() { socket.as_mut() .unwrap() - .update_handshake(sockets)?; + .update_handshake(iface, now)?; } } diff --git a/src/tls.rs b/src/tls.rs index e9f3a92..8897af1 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -6,6 +6,9 @@ use smoltcp::socket::TcpSocketBuffer; use smoltcp::wire::IpEndpoint; use smoltcp::Result; use smoltcp::Error; +use smoltcp::phy::Device; +use smoltcp::iface::EthernetInterface; +use smoltcp::time::Instant; use byteorder::{ByteOrder, NetworkEndian}; use generic_array::GenericArray; @@ -57,30 +60,30 @@ pub(crate) enum TlsState { SERVER_CONNECTED, } -pub struct TlsSocket<'b> +pub struct TlsSocket<'a, 'b, 'c> { + // Locally owned SocketSet, solely containing 1 TCP socket + sockets: SocketSet<'a, 'b, 'c>, tcp_handle: SocketHandle, rng: &'b mut dyn crate::TlsRng, session: RefCell>, } -impl<'b> TlsSocket<'b> { - pub fn new<'a, 'c>( - sockets: &mut SocketSet<'a, 'b, 'c>, - rx_buffer: TcpSocketBuffer<'b>, - tx_buffer: TcpSocketBuffer<'b>, +impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { + pub fn new( + tcp_socket: TcpSocket<'b>, rng: &'b mut dyn crate::TlsRng, certificate_with_key: Option<( crate::session::CertificatePrivateKey, Vec<&'b [u8]> )> ) -> Self - where - 'b: 'c, { - let tcp_socket = TcpSocket::new(rx_buffer, tx_buffer); + let socket_set_entries: [_; 1] = Default::default(); + let mut sockets = SocketSet::new(socket_set_entries); let tcp_handle = sockets.add(tcp_socket); TlsSocket { + sockets, tcp_handle, rng, session: RefCell::new( @@ -89,26 +92,8 @@ impl<'b> TlsSocket<'b> { } } - // pub fn from_tcp_handle( - // tcp_handle: SocketHandle, - // rng: &'s mut dyn crate::TlsRng, - // certificate_with_key: Option<( - // crate::session::CertificatePrivateKey, - // Vec<&'s [u8]> - // )> - // ) -> Self { - // TlsSocket { - // tcp_handle, - // rng, - // session: RefCell::new( - // Session::new(TlsRole::Client, certificate_with_key) - // ), - // } - // } - pub fn connect( &mut self, - sockets: &mut SocketSet, remote_endpoint: T, local_endpoint: U, ) -> Result<()> @@ -117,7 +102,7 @@ impl<'b> TlsSocket<'b> { U: Into, { // Start TCP handshake - let mut tcp_socket = sockets.get::(self.tcp_handle); + let mut tcp_socket = self.sockets.get::(self.tcp_handle); tcp_socket.connect(remote_endpoint, local_endpoint)?; // Permit TLS handshake as well @@ -131,14 +116,13 @@ impl<'b> TlsSocket<'b> { pub fn listen( &mut self, - sockets: &mut SocketSet, local_endpoint: T ) -> Result<()> where T: Into { // Listen from TCP socket - let mut tcp_socket = sockets.get::(self.tcp_handle); + let mut tcp_socket = self.sockets.get::(self.tcp_handle); tcp_socket.listen(local_endpoint)?; // Update tls session to server_start @@ -148,7 +132,17 @@ impl<'b> TlsSocket<'b> { Ok(()) } - pub fn update_handshake(&mut self, sockets: &mut SocketSet) -> Result { + pub fn update_handshake( + &mut self, + iface: &mut EthernetInterface, + now: Instant + ) -> Result + where + DeviceT: for<'d> Device<'d> + { + // Poll the TCP socket, no matter what + iface.poll(&mut self.sockets, now)?; + // Handle TLS handshake through TLS states let tls_state = { self.session.borrow().get_tls_state() @@ -156,8 +150,7 @@ impl<'b> TlsSocket<'b> { // Check TCP socket/ TLS session { - let mut tcp_socket = sockets.get::(self.tcp_handle); - let tls_socket = self.session.borrow(); + let tcp_state = self.sockets.get::(self.tcp_handle).state(); // // Check if it should connect to client or not // if tls_socket.get_session_role() != crate::session::TlsRole::Client { @@ -168,21 +161,20 @@ impl<'b> TlsSocket<'b> { // Skip handshake processing if it is already completed // However, redo TCP handshake if TLS socket is trying to connect and // TCP socket is not connected - if tcp_socket.state() != TcpState::Established { + if tcp_state != TcpState::Established { if tls_state == TlsState::CLIENT_START { // Restart TCP handshake is it is closed for some reason + let mut tcp_socket = self.sockets.get::(self.tcp_handle); + let session = self.session.borrow(); if !tcp_socket.is_open() { tcp_socket.connect( - tls_socket.get_remote_endpoint(), - tls_socket.get_local_endpoint() + session.get_remote_endpoint(), + session.get_local_endpoint() )?; } - return Ok(false); - } else { - // Do nothing, either handshake failed or the socket closed - // after finishing the handshake - return Ok(false); } + // Terminate the procedure, as no processing is necessary + return Ok(false); } } // Handle TLS handshake through TLS states @@ -204,7 +196,8 @@ impl<'b> TlsSocket<'b> { .client_hello(&ecdh_secret, &x25519_secret, random, session_id.clone()); { - let mut tcp_socket = sockets.get::(self.tcp_handle); + let mut tcp_socket = self.sockets.get::(self.tcp_handle); + let mut session = self.session.borrow_mut(); tcp_socket.send( |data| { // Enqueue tls representation without extra allocation @@ -218,7 +211,6 @@ impl<'b> TlsSocket<'b> { // No sequence number calculation in CH // because there is no encryption // Still, data needs to be hashed - let mut session = self.session.borrow_mut(); session.client_update_for_ch( ecdh_secret, x25519_secret, @@ -331,7 +323,7 @@ impl<'b> TlsSocket<'b> { (certificates_total_length, buffer_vec) }; - self.send_application_slice(sockets, &mut buffer_vec.clone())?; + self.send_application_slice(&mut buffer_vec.clone())?; // Update session let buffer_vec_length = buffer_vec.len(); @@ -388,7 +380,7 @@ impl<'b> TlsSocket<'b> { // Push content byte (handshake: 22) verify_buffer_vec.push(22); - self.send_application_slice(sockets, &mut verify_buffer_vec.clone())?; + self.send_application_slice(&mut verify_buffer_vec.clone())?; // Update session let cert_verify_len = verify_buffer_vec.len(); @@ -416,7 +408,7 @@ impl<'b> TlsSocket<'b> { buffer.push(22).unwrap(); buffer }; - self.send_application_slice(sockets, &mut inner_plaintext.clone())?; + self.send_application_slice(&mut inner_plaintext.clone())?; let inner_plaintext_length = inner_plaintext.len(); self.session.borrow_mut() @@ -476,7 +468,7 @@ impl<'b> TlsSocket<'b> { ecdhe_public_key ); { - let mut tcp_socket = sockets.get::(self.tcp_handle); + let mut tcp_socket = self.sockets.get::(self.tcp_handle); let mut session = self.session.borrow_mut(); tcp_socket.send( |data| { @@ -499,6 +491,8 @@ impl<'b> TlsSocket<'b> { )?; } + log::info!("sent server hello"); + // Construct and send minimalistic EE let inner_plaintext: [u8; 7] = [ 0x08, // EE type @@ -506,7 +500,7 @@ impl<'b> TlsSocket<'b> { 0x00, 0x00, // Length of extensions: 0 22 // Content type of InnerPlainText ]; - self.send_application_slice(sockets, &mut inner_plaintext.clone())?; + self.send_application_slice(&mut inner_plaintext.clone())?; let inner_plaintext_length = inner_plaintext.len(); { @@ -516,6 +510,8 @@ impl<'b> TlsSocket<'b> { ); } + log::info!("sent encrypted extension"); + // TODO: Option to allow a certificate request // Construct and send server certificate handshake content @@ -579,13 +575,14 @@ impl<'b> TlsSocket<'b> { inner_plaintext }; - self.send_application_slice(sockets, &mut inner_plaintext.clone())?; + self.send_application_slice(&mut inner_plaintext.clone())?; let inner_plaintext_length = inner_plaintext.len(); // Update session { self.session.borrow_mut() .server_update_for_sent_certificate(&inner_plaintext[..(inner_plaintext_length-1)]); } + log::info!("sent certificate"); // Construct and send certificate verify let mut inner_plaintext = { @@ -619,10 +616,7 @@ impl<'b> TlsSocket<'b> { inner_plaintext }; - self.send_application_slice( - sockets, - &mut inner_plaintext.clone() - )?; + self.send_application_slice(&mut inner_plaintext.clone())?; let inner_plaintext_length = inner_plaintext.len(); { @@ -631,6 +625,7 @@ impl<'b> TlsSocket<'b> { &inner_plaintext[..(inner_plaintext_length-1)] ); } + log::info!("sent certificate verify"); // Construct and send server finished let inner_plaintext: HeaplessVec = { @@ -647,13 +642,14 @@ impl<'b> TlsSocket<'b> { buffer.push(22).unwrap(); buffer }; - self.send_application_slice(sockets, &mut inner_plaintext.clone())?; + self.send_application_slice(&mut inner_plaintext.clone())?; let inner_plaintext_length = inner_plaintext.len(); { self.session.borrow_mut() .server_update_for_server_finished(&inner_plaintext[..(inner_plaintext_length-1)]); } + log::info!("sent client finished"); } // There is no need to care about handshake if it was completed @@ -669,116 +665,215 @@ impl<'b> TlsSocket<'b> { // Read for TLS packet // Proposition: Decouple all data from TLS record layer before processing // Recouple a brand new TLS record wrapper - // Use recv to avoid buffer allocation + // Use peek & recv to avoid buffer allocation { - let mut tcp_socket = sockets.get::(self.tcp_handle); - tcp_socket.recv( - |buffer| { - // log::info!("Received Buffer: {:?}", buffer); - let buffer_size = buffer.len(); + let tls_repr_vec = { + let mut tcp_socket = self.sockets.get::(self.tcp_handle); - // Provide a way to end the process early - if buffer_size == 0 { - return (0, ()) - } - - log::info!("Received something"); - - let mut tls_repr_vec: Vec<(&[u8], TlsRepr)> = Vec::new(); - let mut bytes = &buffer[..buffer_size]; - - // Sequentially push reprs into vec - loop { - match parse_tls_repr(bytes) { - Ok((rest, (repr_slice, repr))) => { - tls_repr_vec.push( - (repr_slice, repr) - ); - if rest.len() == 0 { - break; - } else { - bytes = rest; - } - }, - // Dequeue everything and abort processing if it is malformed - _ => return (buffer_size, ()) - }; - } - - // Sequencially process the representations in vector - // Decrypt and split the handshake if necessary - let tls_repr_vec_size = tls_repr_vec.len(); - for _index in 0..tls_repr_vec_size { - let (repr_slice, mut repr) = tls_repr_vec.remove(0); - - // Process record base on content type - log::info!("Record type: {:?}", repr.content_type); - - if repr.content_type == TlsContentType::ApplicationData { - log::info!("Found application data"); - // Take the payload out of TLS Record and decrypt - let mut app_data = repr.payload.take().unwrap(); - let mut associated_data = [0; 5]; - associated_data[0] = repr.content_type.into(); - NetworkEndian::write_u16( - &mut associated_data[1..3], - repr.version.into() - ); - NetworkEndian::write_u16( - &mut associated_data[3..5], - repr.length - ); - { - let mut session = self.session.borrow_mut(); - session.decrypt_in_place_detached( - &associated_data, - &mut app_data - ).unwrap(); - session.increment_remote_sequence_number(); - } - - // Discard last 16 bytes (auth tag) - let inner_plaintext = &app_data[..app_data.len()-16]; - let (inner_content_type, _) = get_content_type_inner_plaintext( - inner_plaintext - ); - if inner_content_type != TlsContentType::Handshake { - // Silently ignore non-handshakes - continue; - } - let (_, mut inner_handshakes) = complete( - parse_inner_plaintext_for_handshake - )(inner_plaintext).unwrap(); - - // Sequentially process all handshakes - let num_of_handshakes = inner_handshakes.len(); - for _ in 0..num_of_handshakes { - let (handshake_slice, handshake_repr) = inner_handshakes.remove(0); - if self.process( - handshake_slice, - TlsRepr { - content_type: TlsContentType::Handshake, - version: repr.version, - length: u16::try_from(handshake_repr.length).unwrap() + 4, - payload: None, - handshake: Some(handshake_repr) - } - ).is_err() { - return (buffer_size, ()) - } - } - } - - else { - if self.process(repr_slice, repr).is_err() { - return (buffer_size, ()) - } - log::info!("Processed record"); - } - } - (buffer_size, ()) + // Check if there are bytes enqueued in the recv buffer + // No need to do further dequeuing if there are no receivable bytes + if !tcp_socket.can_recv() { + return Ok(self.session.borrow().has_completed_handshake()) } - )?; + + log::info!("Tls can recv"); + + // Peak into the first 5 bytes (TLS record layer) + // This tells the length of the entire record + let length = match tcp_socket.peek(5) { + Ok(bytes) => NetworkEndian::read_u16(&bytes[3..5]), + _ => return Ok(self.session.borrow().has_completed_handshake()) + }; + + log::info!("tls record length: {:?}", length); + + // Recv the entire TLS record + tcp_socket.recv( + |buffer| ((length + 5).into(), Vec::from(&buffer[..(length + 5).into()])) + ).unwrap() + }; + + // Parse the bytes representation of a TLS record + let (repr_slice, mut repr) = match parse_tls_repr(&tls_repr_vec) { + Ok((_, (repr_slice, repr))) => (repr_slice, repr), + _ => return Ok(self.session.borrow().has_completed_handshake()) + }; + + // Process record base on content type + log::info!("Record type: {:?}", repr.content_type); + + if repr.content_type == TlsContentType::ApplicationData { + log::info!("Found application data"); + // Take the payload out of TLS Record and decrypt + let mut app_data = repr.payload.take().unwrap(); + let mut associated_data = [0; 5]; + associated_data[0] = repr.content_type.into(); + NetworkEndian::write_u16( + &mut associated_data[1..3], + repr.version.into() + ); + NetworkEndian::write_u16( + &mut associated_data[3..5], + repr.length + ); + { + let mut session = self.session.borrow_mut(); + session.decrypt_in_place_detached( + &associated_data, + &mut app_data + ).unwrap(); + session.increment_remote_sequence_number(); + } + + // Discard last 16 bytes (auth tag) + let inner_plaintext = &app_data[..app_data.len()-16]; + let (inner_content_type, _) = get_content_type_inner_plaintext( + inner_plaintext + ); + if inner_content_type != TlsContentType::Handshake { + // Silently ignore non-handshakes + return Ok(self.session.borrow().has_completed_handshake()) + } + let (_, mut inner_handshakes) = complete( + parse_inner_plaintext_for_handshake + )(inner_plaintext).unwrap(); + + // Sequentially process all handshakes + let num_of_handshakes = inner_handshakes.len(); + for _ in 0..num_of_handshakes { + let (handshake_slice, handshake_repr) = inner_handshakes.remove(0); + if self.process( + handshake_slice, + TlsRepr { + content_type: TlsContentType::Handshake, + version: repr.version, + length: u16::try_from(handshake_repr.length).unwrap() + 4, + payload: None, + handshake: Some(handshake_repr) + } + ).is_err() { + return Ok(self.session.borrow().has_completed_handshake()) + } + } + } else { + if self.process(repr_slice, repr).is_err() { + return Ok(self.session.borrow().has_completed_handshake()) + } + log::info!("Processed record"); + } + + // // Finally dequeue the record from buffer + // if tcp_socket.recv(|_| (length.into(), ())).is_err() { + // return Ok(self.session.borrow().has_completed_handshake()) + // } + + // tcp_socket.recv( + // |buffer| { + // // log::info!("Received Buffer: {:?}", buffer); + // let buffer_size = buffer.len(); + + // // Provide a way to end the process early + // if buffer_size == 0 { + // return (0, ()) + // } + + // log::info!("Received something"); + + // let mut tls_repr_vec: Vec<(&[u8], TlsRepr)> = Vec::new(); + // let mut bytes = &buffer[..buffer_size]; + + // // Sequentially push reprs into vec + // loop { + // match parse_tls_repr(bytes) { + // Ok((rest, (repr_slice, repr))) => { + // tls_repr_vec.push( + // (repr_slice, repr) + // ); + // if rest.len() == 0 { + // break; + // } else { + // bytes = rest; + // } + // }, + // // Dequeue everything and abort processing if it is malformed + // _ => return (buffer_size, ()) + // }; + // } + + // // Sequencially process the representations in vector + // // Decrypt and split the handshake if necessary + // let tls_repr_vec_size = tls_repr_vec.len(); + // for _index in 0..tls_repr_vec_size { + // let (repr_slice, mut repr) = tls_repr_vec.remove(0); + + // // Process record base on content type + // log::info!("Record type: {:?}", repr.content_type); + + // if repr.content_type == TlsContentType::ApplicationData { + // log::info!("Found application data"); + // // Take the payload out of TLS Record and decrypt + // let mut app_data = repr.payload.take().unwrap(); + // let mut associated_data = [0; 5]; + // associated_data[0] = repr.content_type.into(); + // NetworkEndian::write_u16( + // &mut associated_data[1..3], + // repr.version.into() + // ); + // NetworkEndian::write_u16( + // &mut associated_data[3..5], + // repr.length + // ); + // { + // let mut session = self.session.borrow_mut(); + // session.decrypt_in_place_detached( + // &associated_data, + // &mut app_data + // ).unwrap(); + // session.increment_remote_sequence_number(); + // } + + // // Discard last 16 bytes (auth tag) + // let inner_plaintext = &app_data[..app_data.len()-16]; + // let (inner_content_type, _) = get_content_type_inner_plaintext( + // inner_plaintext + // ); + // if inner_content_type != TlsContentType::Handshake { + // // Silently ignore non-handshakes + // continue; + // } + // let (_, mut inner_handshakes) = complete( + // parse_inner_plaintext_for_handshake + // )(inner_plaintext).unwrap(); + + // // Sequentially process all handshakes + // let num_of_handshakes = inner_handshakes.len(); + // for _ in 0..num_of_handshakes { + // let (handshake_slice, handshake_repr) = inner_handshakes.remove(0); + // if self.process( + // handshake_slice, + // TlsRepr { + // content_type: TlsContentType::Handshake, + // version: repr.version, + // length: u16::try_from(handshake_repr.length).unwrap() + 4, + // payload: None, + // handshake: Some(handshake_repr) + // } + // ).is_err() { + // return (buffer_size, ()) + // } + // } + // } + + // else { + // if self.process(repr_slice, repr).is_err() { + // return (buffer_size, ()) + // } + // log::info!("Processed record"); + // } + // } + // (buffer_size, ()) + // } + // )?; } Ok(self.session.borrow().has_completed_handshake()) @@ -1433,8 +1528,8 @@ impl<'b> TlsSocket<'b> { // Input should be inner plaintext // Note: Do not put this slice into the transcript hash. It is polluted. // TODO: Rename this function. It is only good for client finished - fn send_application_slice(&self, sockets: &mut SocketSet, slice: &mut [u8]) -> Result<()> { - let mut tcp_socket = sockets.get::(self.tcp_handle); + fn send_application_slice(&mut self, slice: &mut [u8]) -> Result<()> { + let mut tcp_socket = self.sockets.get::(self.tcp_handle); if !tcp_socket.can_send() { return Err(Error::Illegal); } @@ -1471,8 +1566,8 @@ impl<'b> TlsSocket<'b> { Ok(()) } - pub fn recv_slice(&self, sockets: &mut SocketSet, data: &mut [u8]) -> Result { - let mut tcp_socket = sockets.get::(self.tcp_handle); + pub fn recv_slice(&mut self, data: &mut [u8]) -> Result { + let mut tcp_socket = self.sockets.get::(self.tcp_handle); if !tcp_socket.can_recv() { return Ok(0); } @@ -1561,7 +1656,7 @@ impl<'b> TlsSocket<'b> { Ok(actual_application_data_length) } - pub fn send_slice(&self, sockets: &mut SocketSet, data: &[u8]) -> Result<()> { + pub fn send_slice(&mut self, data: &[u8]) -> Result<()> { // If the handshake is not completed, do not push bytes onto the buffer // through TlsSocket.send_slice() // Handshake send should be through TCPSocket directly. @@ -1596,7 +1691,7 @@ impl<'b> TlsSocket<'b> { ).unwrap(); session.increment_local_sequence_number(); - let mut tcp_socket = sockets.get::(self.tcp_handle); + let mut tcp_socket = self.sockets.get::(self.tcp_handle); if !tcp_socket.can_send() { return Err(Error::Illegal); }