diff --git a/src/session.rs b/src/session.rs index 8e1b483..362ad1c 100644 --- a/src/session.rs +++ b/src/session.rs @@ -14,6 +14,7 @@ use smoltcp::wire::IpEndpoint; use crate::tls::TlsState; use crate::tls_packet::CipherSuite; +use crate::tls_packet::AlertType; use crate::key::*; use crate::tls_packet::SignatureScheme; use crate::Error; @@ -143,6 +144,15 @@ impl<'a> Session<'a> { self.state = TlsState::SERVER_START; } + pub(crate) fn invalidate_session( + &mut self, + alert: AlertType, + received_slice: &[u8] + ) { + self.hash.update(received_slice); + self.state = TlsState::NEED_RESET(alert); + } + // State transition from START to WAIT_SH pub(crate) fn client_update_for_ch( &mut self, diff --git a/src/tls.rs b/src/tls.rs index 8897af1..0f9535d 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -58,6 +58,8 @@ pub(crate) enum TlsState { SERVER_WAIT_CV, SERVER_WAIT_FINISHED, SERVER_CONNECTED, + // `Derailed` state should any exceptions occur + NEED_RESET(AlertType) } pub struct TlsSocket<'a, 'b, 'c> @@ -142,7 +144,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { { // Poll the TCP socket, no matter what iface.poll(&mut self.sockets, now)?; - + // Handle TLS handshake through TLS states let tls_state = { self.session.borrow().get_tls_state() @@ -152,12 +154,6 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { { let tcp_state = self.sockets.get::(self.tcp_handle).state(); - // // Check if it should connect to client or not - // if tls_socket.get_session_role() != crate::session::TlsRole::Client { - // // Return true for no need to do anymore handshake - // return Ok(true); - // } - // Skip handshake processing if it is already completed // However, redo TCP handshake if TLS socket is trying to connect and // TCP socket is not connected @@ -658,6 +654,13 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { return Ok(true); } + // Terminate TLS connection with an alert + // Terminate TCP session by issuing `close()` + // Reset the socket + TlsState::NEED_RESET(alert) => { + todo!() + } + // Other states _ => {} } @@ -676,8 +679,6 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { return Ok(self.session.borrow().has_completed_handshake()) } - log::info!("Tls can recv"); - // Peak into the first 5 bytes (TLS record layer) // This tells the length of the entire record let length = match tcp_socket.peek(5) { @@ -685,8 +686,6 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { _ => return Ok(self.session.borrow().has_completed_handshake()) }; - log::info!("tls record length: {:?}", length); - // Recv the entire TLS record tcp_socket.recv( |buffer| ((length + 5).into(), Vec::from(&buffer[..(length + 5).into()])) @@ -761,119 +760,6 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { } log::info!("Processed record"); } - - // // Finally dequeue the record from buffer - // if tcp_socket.recv(|_| (length.into(), ())).is_err() { - // return Ok(self.session.borrow().has_completed_handshake()) - // } - - // tcp_socket.recv( - // |buffer| { - // // log::info!("Received Buffer: {:?}", buffer); - // let buffer_size = buffer.len(); - - // // Provide a way to end the process early - // if buffer_size == 0 { - // return (0, ()) - // } - - // log::info!("Received something"); - - // let mut tls_repr_vec: Vec<(&[u8], TlsRepr)> = Vec::new(); - // let mut bytes = &buffer[..buffer_size]; - - // // Sequentially push reprs into vec - // loop { - // match parse_tls_repr(bytes) { - // Ok((rest, (repr_slice, repr))) => { - // tls_repr_vec.push( - // (repr_slice, repr) - // ); - // if rest.len() == 0 { - // break; - // } else { - // bytes = rest; - // } - // }, - // // Dequeue everything and abort processing if it is malformed - // _ => return (buffer_size, ()) - // }; - // } - - // // Sequencially process the representations in vector - // // Decrypt and split the handshake if necessary - // let tls_repr_vec_size = tls_repr_vec.len(); - // for _index in 0..tls_repr_vec_size { - // let (repr_slice, mut repr) = tls_repr_vec.remove(0); - - // // Process record base on content type - // log::info!("Record type: {:?}", repr.content_type); - - // if repr.content_type == TlsContentType::ApplicationData { - // log::info!("Found application data"); - // // Take the payload out of TLS Record and decrypt - // let mut app_data = repr.payload.take().unwrap(); - // let mut associated_data = [0; 5]; - // associated_data[0] = repr.content_type.into(); - // NetworkEndian::write_u16( - // &mut associated_data[1..3], - // repr.version.into() - // ); - // NetworkEndian::write_u16( - // &mut associated_data[3..5], - // repr.length - // ); - // { - // let mut session = self.session.borrow_mut(); - // session.decrypt_in_place_detached( - // &associated_data, - // &mut app_data - // ).unwrap(); - // session.increment_remote_sequence_number(); - // } - - // // Discard last 16 bytes (auth tag) - // let inner_plaintext = &app_data[..app_data.len()-16]; - // let (inner_content_type, _) = get_content_type_inner_plaintext( - // inner_plaintext - // ); - // if inner_content_type != TlsContentType::Handshake { - // // Silently ignore non-handshakes - // continue; - // } - // let (_, mut inner_handshakes) = complete( - // parse_inner_plaintext_for_handshake - // )(inner_plaintext).unwrap(); - - // // Sequentially process all handshakes - // let num_of_handshakes = inner_handshakes.len(); - // for _ in 0..num_of_handshakes { - // let (handshake_slice, handshake_repr) = inner_handshakes.remove(0); - // if self.process( - // handshake_slice, - // TlsRepr { - // content_type: TlsContentType::Handshake, - // version: repr.version, - // length: u16::try_from(handshake_repr.length).unwrap() + 4, - // payload: None, - // handshake: Some(handshake_repr) - // } - // ).is_err() { - // return (buffer_size, ()) - // } - // } - // } - - // else { - // if self.process(repr_slice, repr).is_err() { - // return (buffer_size, ()) - // } - // log::info!("Processed record"); - // } - // } - // (buffer_size, ()) - // } - // )?; } Ok(self.session.borrow().has_completed_handshake()) @@ -926,15 +812,23 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { // Check random: Cannot be SHA-256 of "HelloRetryRequest" if server_hello.random == HRR_RANDOM { - // Abort communication - todo!() + // Abort communication with illegal parameter alert + self.session.borrow_mut().invalidate_session( + AlertType::IllegalParameter, + handshake_slice + ); + return Ok(()); } // Check session_id_echo // The socket should have a session_id after moving from START state if !self.session.borrow().verify_session_id_echo(server_hello.session_id_echo) { - // Abort communication - todo!() + // Abort communication with illegal parameter alert + self.session.borrow_mut().invalidate_session( + AlertType::IllegalParameter, + handshake_slice + ); + return Ok(()); } // Note the selected cipher suite @@ -942,8 +836,12 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { // TLSv13 forbidden key compression if server_hello.compression_method != 0 { - // Abort communciation - todo!() + // Abort communication with illegal parameter alert + self.session.borrow_mut().invalidate_session( + AlertType::IllegalParameter, + handshake_slice + ); + return Ok(()); } for extension in server_hello.extensions.iter() { @@ -954,12 +852,21 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { } ) = extension.extension_data { if selected_version != TlsVersion::Tls13 { - // Abort for choosing not offered TLS version - todo!() + // Abort for choosing not offered TLS version, + // with illegal parameter alert + self.session.borrow_mut().invalidate_session( + AlertType::IllegalParameter, + handshake_slice + ); + return Ok(()); } } else { - // Abort for illegal extension - todo!() + // Abort for malformatted extension, with decode error alert + self.session.borrow_mut().invalidate_session( + AlertType::DecodeError, + handshake_slice + ); + return Ok(()); } } @@ -985,8 +892,17 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { x25519_server_key ) ); + }, + // The client side implementation of this TLS socket only offers + // P-256 and x25519 as ECDHE key exchange algorithms + // Respond with illegal parameter alert and then terminate + _ => { + self.session.borrow_mut().invalidate_session( + AlertType::IllegalParameter, + handshake_slice + ); + return Ok(()); } - _ => todo!() } } } @@ -994,13 +910,24 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { } else { // Handle invalid TLS packet - todo!() + self.session.borrow_mut().invalidate_session( + AlertType::DecodeError, + handshake_slice + ); + return Ok(()); } - // Check that both selected_cipher and (p256_public XNOR x25519_public) were received + // Check that both selected_cipher and (p256_public OR x25519_public) were received + // The case where key_share extension exists but no appropriate keys are returned + // is considered in above. The only remaining case is that the `key share` entry extension + // is not sent at all. if selected_cipher.is_none() || (p256_public.is_none() && x25519_public.is_none()) { // Abort communication - todo!() + self.session.borrow_mut().invalidate_session( + AlertType::MissingExtension, + handshake_slice + ); + return Ok(()); } // Get slice without reserialization @@ -1022,14 +949,19 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { // Verify that it is indeed an EE let might_be_ee = repr.handshake.take().unwrap(); if might_be_ee.get_msg_type() != HandshakeType::EncryptedExtensions { - // Process the other handshakes in "handshake_vec" - todo!() + // Unexpected message types + self.session.borrow_mut().invalidate_session( + AlertType::UnexpectedMessage, + handshake_slice + ); + return Ok(()); } - // TODO: Process payload - + // Possiblity: Process payload // Practically, nothing will be done about cookies/server name + // These fields are typically not session related // Extension processing is therefore skipped + // Update hash of the session, get EE by taking appropriate length of data // Length of handshake header is 4 let (_handshake_slice, ee_slice) = @@ -1079,8 +1011,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { cert.get_cert_public_key().unwrap() ); log::info!("Received WAIT_CERT_CR"); - } - else if might_be_cert.get_msg_type() == HandshakeType::CertificateRequest { + } else if might_be_cert.get_msg_type() == HandshakeType::CertificateRequest { // Process signature algorithm extension // Signature algorithm for the private key of client cert must be included // within the list of signature algorithms @@ -1121,17 +1052,23 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { ); } - } else { // Reject connection, CertificateRequest must have // SignatureAlgorithm extension - todo!() + self.session.borrow_mut().invalidate_session( + AlertType::MissingExtension, + handshake_slice + ); + return Ok(()); } log::info!("Received WAIT_CERT_CR"); - } - else { - // Throw alert - todo!() + } else { + // Throw alert for not recving certificate/certificate request from server side + self.session.borrow_mut().invalidate_session( + AlertType::UnexpectedMessage, + handshake_slice + ); + return Ok(()); } }, @@ -1172,7 +1109,11 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { } else { // Unexpected handshakes // Throw alert - todo!() + self.session.borrow_mut().invalidate_session( + AlertType::UnexpectedMessage, + handshake_slice + ); + return Ok(()); } }, @@ -1183,7 +1124,11 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { let might_be_cert_verify = repr.handshake.take().unwrap(); if might_be_cert_verify.get_msg_type() != HandshakeType::CertificateVerify { // Process the other handshakes in "handshake_vec" - todo!() + self.session.borrow_mut().invalidate_session( + AlertType::UnexpectedMessage, + handshake_slice + ); + return Ok(()); } // Take out the portion for CertificateVerify @@ -1212,8 +1157,12 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { // Ensure that it is Finished let might_be_server_finished = repr.handshake.take().unwrap(); if might_be_server_finished.get_msg_type() != HandshakeType::Finished { - // Process the other handshakes in "handshake_vec" - todo!() + // Server Finished is expected. + self.session.borrow_mut().invalidate_session( + AlertType::UnexpectedMessage, + handshake_slice + ); + return Ok(()); } // Take out the portion for server Finished @@ -1239,8 +1188,12 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { // Ensure that is a Client Hello let might_be_client_hello = repr.handshake.take().unwrap(); if might_be_client_hello.get_msg_type() != HandshakeType::ClientHello { - // Throw alert - todo!() + // Throw alert. Client Hello is expected. + self.session.borrow_mut().invalidate_session( + AlertType::UnexpectedMessage, + handshake_slice + ); + return Ok(()); } // Process as Client Hello @@ -1265,9 +1218,14 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { if let Some(Some(nominated_cipher_suite)) = recognized_cipher_suite { nominated_cipher_suite } else { - // Not appropriate cipher found - // Send alert - todo!() + // No appropriate cipher found, + // the full set of security measures cannot be set up. + // Send alert for this + self.session.borrow_mut().invalidate_session( + AlertType::HandshakeFailure, + handshake_slice + ); + return Ok(()); } }; @@ -1300,20 +1258,36 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { { // TLS 1.3 was not offered by client // Reject connection immediately - todo!() + self.session.borrow_mut().invalidate_session( + AlertType::IllegalParameter, + handshake_slice + ); + return Ok(()); } } else { // Wrong variant appeared, probably malformed - todo!() + self.session.borrow_mut().invalidate_session( + AlertType::DecodeError, + handshake_slice + ); + return Ok(()); } } else { // Malformed TLS packet - todo!() + self.session.borrow_mut().invalidate_session( + AlertType::DecodeError, + handshake_slice + ); + return Ok(()); } } else { // No supported_version extension was found, // Terminate by sending alert - todo!() + self.session.borrow_mut().invalidate_session( + AlertType::MissingExtension, + handshake_slice + ); + return Ok(()); } // Check offered ECDHE algorithm @@ -1341,11 +1315,23 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { } } else { // Malformed TLS packet - todo!() + self.session.borrow_mut().invalidate_session( + AlertType::DecodeError, + handshake_slice + ); + return Ok(()); } } else { - // Client did not offer ECDHE algorithm - todo!() + // Client did not offer ECDHE algorithm within `supported version` extension + // While it is allowed, HRR is not handled as acceptable parameters + // should have been offered already initially. + // Possibility: Tolerate minor mismatch of client hello, and send HRR instead + + self.session.borrow_mut().invalidate_session( + AlertType::MissingExtension, + handshake_slice + ); + return Ok(()); } // Select usable key @@ -1395,20 +1381,32 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { } } - // If there are no applicable offered client key, - // consider sending a ClientHelloRetry + // There are no applicable offered client key, + // Proper way of handling: Send a HelloRetryRequest with key generated if ecdhe_public_key.is_none() { - todo!() + self.session.borrow_mut().invalidate_session( + AlertType::HandshakeFailure, + handshake_slice + ); + return Ok(()); } } else { // Malformed packet // Send alert to client - todo!() + self.session.borrow_mut().invalidate_session( + AlertType::DecodeError, + handshake_slice + ); + return Ok(()); } } else { // The key_share extension was not sent - // Consider sending a ClientHelloRequest - todo!() + // Proper way of handling: Send a HelloRetryRequest with key generated + self.session.borrow_mut().invalidate_session( + AlertType::MissingExtension, + handshake_slice + ); + return Ok(()); } // Select signature algorithm @@ -1456,8 +1454,12 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { signature_algorithm = Some(*server_signature_algorithm); } else { // Cannot find a suitable signature algorithm for the server side - // Terminate the negotiation with alert - todo!() + // Terminate the negotiation with an alert + self.session.borrow_mut().invalidate_session( + AlertType::HandshakeFailure, + handshake_slice + ); + return Ok(()); } } else { @@ -1467,12 +1469,20 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { } } else { // Malformed packet, type does not match content - todo!() + self.session.borrow_mut().invalidate_session( + AlertType::DecodeError, + handshake_slice + ); + return Ok(()); } } else { // Will only accept authentication through certificate // Send alert if there are no signature algorithms extension - todo!() + self.session.borrow_mut().invalidate_session( + AlertType::MissingExtension, + handshake_slice + ); + return Ok(()); } { @@ -1494,8 +1504,12 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { // Ensure that it is Finished let might_be_client_finished = repr.handshake.take().unwrap(); if might_be_client_finished.get_msg_type() != HandshakeType::Finished { - // Process the other handshakes in "handshake_vec" - todo!() + // Expected to recv client finished + self.session.borrow_mut().invalidate_session( + AlertType::UnexpectedMessage, + handshake_slice + ); + return Ok(()); } // Take out the portion for server Finished @@ -1702,9 +1716,16 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { Ok(()) } - - pub fn get_tcp_handle(&self) -> SocketHandle { - self.tcp_handle - } - +} + +use core::fmt; +impl<'a, 'b, 'c> fmt::Write for TlsSocket<'a, 'b, 'c> { + fn write_str(&mut self, slice: &str) -> fmt::Result { + let slice = slice.as_bytes(); + if self.send_slice(slice) == Ok(()) { + Ok(()) + } else { + Err(fmt::Error) + } + } } diff --git a/src/tls_packet.rs b/src/tls_packet.rs index c880bde..97415ec 100644 --- a/src/tls_packet.rs +++ b/src/tls_packet.rs @@ -29,6 +29,40 @@ pub(crate) enum TlsContentType { ApplicationData = 23 } +#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] +#[repr(u8)] +pub(crate) enum AlertType { + CloseNotify = 0, + UnexpectedMessage = 10, + BadRecordMac = 20, + RecordOverflow = 22, + HandshakeFailure = 40, + BadCertificate = 42, + UnsupportedCertificate = 43, + CertificateRevoked = 44, + CertificateExpired = 45, + CertificateUnknown = 46, + IllegalParameter = 47, + UnknownCA = 48, + AccessDenied = 49, + DecodeError = 50, + DecryptError = 51, + ProtocolVersion = 70, + InsufficientSecurity = 71, + InternalError = 80, + InappropriateFallback = 86, + UserCanceled = 90, + MissingExtension = 109, + UnsupportedExtension = 110, + UnrecognizedName = 112, + BadCertificateStatusResponse = 113, + UnknownPSKIdentity = 115, + CertificateRequired = 116, + NoApplicationProtcol = 120, + #[num_enum(default)] + UnknownAlert = 255 +} + #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] #[repr(u16)] pub(crate) enum TlsVersion { @@ -113,6 +147,24 @@ impl<'a> TlsRepr<'a> { self } + pub(crate) fn alert(mut self, alert: AlertType) -> Self { + self.content_type = TlsContentType::Alert; + self.version = TlsVersion::Tls12; + let mut application_data: Vec = Vec::new(); + match alert { + AlertType::CloseNotify | AlertType::UserCanceled => { + application_data.push(1) + }, + _ => { + application_data.push(2) + } + }; + application_data.push(alert.try_into().unwrap()); + self.length = 2; + self.payload = Some(application_data); + self + } + // TODO: Consider replace all these boolean function // into a single function that returns the HandshakeType. pub(crate) fn is_server_hello(&self) -> bool {