diff --git a/src/session.rs b/src/session.rs index 362ad1c..f720555 100644 --- a/src/session.rs +++ b/src/session.rs @@ -83,6 +83,8 @@ pub(crate) struct Session<'a> { // Client must cent Certificate extension iff server requested it need_send_client_cert: bool, client_cert_verify_sig_alg: Option, + // Flag for the need of sending alert to terminate TLS session + need_send_alert: Option } impl<'a> Session<'a> { @@ -122,7 +124,8 @@ impl<'a> Session<'a> { cert_public_key: None, cert_private_key: certificate_with_key, need_send_client_cert: false, - client_cert_verify_sig_alg: None + client_cert_verify_sig_alg: None, + need_send_alert: None } } @@ -150,7 +153,13 @@ impl<'a> Session<'a> { received_slice: &[u8] ) { self.hash.update(received_slice); - self.state = TlsState::NEED_RESET(alert); + self.need_send_alert = Some(alert); + } + + pub(crate) fn reset_state(&mut self) { + // Clear alert + self.need_send_alert = None; + self.state = TlsState::DEFAULT; } // State transition from START to WAIT_SH @@ -1405,6 +1414,10 @@ impl<'a> Session<'a> { self.remote_endpoint } + pub(crate) fn get_need_send_alert(&self) -> Option { + self.need_send_alert + } + pub(crate) fn has_completed_handshake(&self) -> bool { self.state == TlsState::CLIENT_CONNECTED } diff --git a/src/tls.rs b/src/tls.rs index 0f9535d..0b64b24 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -57,9 +57,7 @@ pub(crate) enum TlsState { SERVER_WAIT_CERT, SERVER_WAIT_CV, SERVER_WAIT_FINISHED, - SERVER_CONNECTED, - // `Derailed` state should any exceptions occur - NEED_RESET(AlertType) + SERVER_CONNECTED } pub struct TlsSocket<'a, 'b, 'c> @@ -150,10 +148,20 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { self.session.borrow().get_tls_state() }; + let need_send_alert = { + self.session.borrow().get_need_send_alert() + }; + // Check TCP socket/ TLS session { let tcp_state = self.sockets.get::(self.tcp_handle).state(); + //Close TCP socket if necessary + if tcp_state == TcpState::Established && tls_state == TlsState::DEFAULT { + self.sockets.get::(self.tcp_handle).close(); + return Ok(false); + } + // Skip handshake processing if it is already completed // However, redo TCP handshake if TLS socket is trying to connect and // TCP socket is not connected @@ -173,6 +181,117 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { return Ok(false); } } + + // Send alert to start terminating TLS session if necessary + if let Some(alert) = need_send_alert { + match tls_state { + // Client side socket: + // States that expects plaintext payload + TlsState::WAIT_SH | TlsState::SERVER_START => { + // Send the cooresponding alert in plaintext + let mut tcp_socket = self.sockets.get::(self.tcp_handle); + tcp_socket.send( + |data| { + // Set up a TLS buffer on the internal buffer of TCP socket + let mut buffer = TlsBuffer::new(data); + // Instantiate a TLS bytes-representation with pre-determined alert + let tls_repr = TlsRepr::new().alert(alert); + if buffer.enqueue_tls_repr(tls_repr).is_err() { + return (0, ()) + } + + let slice: &[u8] = buffer.into(); + (slice.len(), ()) + } + )?; + }, + // States that expects enrypted payload using handshake secret + TlsState::WAIT_EE | + TlsState::WAIT_CERT_CR | + TlsState::CLIENT_WAIT_CERT | + TlsState::CLIENT_WAIT_CV | + TlsState::CLIENT_WAIT_FINISHED | + TlsState::SERVER_COMPLETED | + TlsState::NEGOTIATED | + TlsState::WAIT_FLIGHT | + TlsState::SERVER_WAIT_CERT | + TlsState::SERVER_WAIT_CV => { + // Send the corresponding alert in ciphertext using handshake secret + let severity: u8 = match alert { + AlertType::CloseNotify | AlertType::UserCanceled => { + 1 + }, + _ => 2 + }; + let mut alert_array: [u8; 3] = [ + severity, + u8::try_from(alert).unwrap(), + 21 // Alert content type + ]; + self.send_application_slice(&mut alert_array)?; + }, + // States that expects enrypted payload using application data secret + TlsState::CLIENT_CONNECTED | + TlsState::SERVER_WAIT_FINISHED | + TlsState::SERVER_CONNECTED => { + // Send the corresponding alert in ciphertext using application data secret + // Sending order: + // 1. Associated data/ TLS Record layer + // 2. Encrypted { Alert } + // 3. Authentication tag (16 bytes for all supported AEADs) + let mut associated_data: [u8; 5] = [ + 0x17, // Application data + 0x03, 0x03, // TLS 1.3 record disguised as TLS 1.2 + 0x00, 0x00 // Length of encrypted data, yet to be determined conveniently + ]; + + NetworkEndian::write_u16(&mut associated_data[3..5], + 2 // Payload length + + 1 // Content type length + + 16 // Auth tag length + ); + + // Alert: Warning (1) , Close notify (0) + let severity: u8 = match alert { + AlertType::CloseNotify | AlertType::UserCanceled => { + 1 + }, + _ => 2 + }; + let mut alert_array: [u8; 3] = [ + severity, + u8::try_from(alert).unwrap(), + 21 // Alert content type + ]; + + let mut session = self.session.borrow_mut(); + let tag = session.encrypt_application_data_in_place_detached( + &associated_data, + &mut alert_array + ).unwrap(); + session.increment_local_sequence_number(); + + let mut tcp_socket = self.sockets.get::(self.tcp_handle); + if !tcp_socket.can_send() { + return Err(Error::Illegal); + } + + tcp_socket.send_slice(&associated_data)?; + tcp_socket.send_slice(&alert_array)?; + tcp_socket.send_slice(&tag)?; + }, + // Other states, such as client_start and default should never send alert + // These stages are too early to raise exceptions + _ => unreachable!() + } + + // Finally, revert the FSM to DEFAULT to signal an invokation of + // `close()` to the TCP socket + self.session.borrow_mut().reset_state(); + + return Ok(false); + } + // Handle TLS handshake through TLS states match tls_state { // Do nothing on the default state @@ -654,13 +773,6 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { return Ok(true); } - // Terminate TLS connection with an alert - // Terminate TCP session by issuing `close()` - // Reset the socket - TlsState::NEED_RESET(alert) => { - todo!() - } - // Other states _ => {} } @@ -1716,6 +1828,34 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { Ok(()) } + + // Send `Close notify` alert to remote side + // Set state to `CLOSED` + // Leave TCP termination to polling method + pub fn close(&mut self) -> Result<()> { + let mut session = self.session.borrow_mut(); + match session.get_tls_state() { + // Send a `close notify` if handshake is established + TlsState::CLIENT_CONNECTED | TlsState::SERVER_CONNECTED => { + session.invalidate_session( + AlertType::CloseNotify, + &[] + ); + }, + // Do nothing if handshake hasn't even started + TlsState::DEFAULT => {}, + // Send `user cancaled` to cancel the handshake negotiation + // if it is currently in the middle of one + _ => { + session.invalidate_session( + AlertType::UserCanceled, + &[] + ); + } + } + + Ok(()) + } } use core::fmt;