diff --git a/src/main.rs b/src/main.rs index 2f46d37..61c3b83 100644 --- a/src/main.rs +++ b/src/main.rs @@ -57,30 +57,29 @@ impl TlsRng for CountingRng {} static mut RNG: CountingRng = CountingRng(0); fn main() { - // let mut socket_set_entries: [_; 8] = Default::default(); - // let mut sockets = SocketSet::new(&mut socket_set_entries[..]); + let mut socket_set_entries: [_; 8] = Default::default(); + let mut sockets = SocketSet::new(&mut socket_set_entries[..]); - // let mut tx_storage = [0; 4096]; - // let mut rx_storage = [0; 4096]; + let mut tx_storage = [0; 4096]; + let mut rx_storage = [0; 4096]; - // let mut tls_socket = unsafe { - // let tx_buffer = TcpSocketBuffer::new(&mut tx_storage[..]); - // let rx_buffer = TcpSocketBuffer::new(&mut rx_storage[..]); - // let tcp_socket = smoltcp::socket::TcpSocket::new(rx_buffer, tx_buffer); - // TlsSocket::new( - // tcp_socket, - // &mut RNG, - // None - // ) - // }; + let mut tls_socket = unsafe { + let tx_buffer = TcpSocketBuffer::new(&mut tx_storage[..]); + let rx_buffer = TcpSocketBuffer::new(&mut rx_storage[..]); + let tcp_socket = smoltcp::socket::TcpSocket::new(rx_buffer, tx_buffer); + TlsSocket::new( + tcp_socket, + &mut RNG, + None + ) + }; + + tls_socket.connect( + // &mut sockets, + (Ipv4Address::new(192, 168, 1, 125), 1883), + 49600 + ).unwrap(); - // tls_socket.connect( - // // &mut sockets, - // (Ipv4Address::new(192, 168, 1, 125), 1883), - // 49600 - // ).unwrap(); -} -/* // tls_socket.tls_connect(&mut sockets).unwrap(); simple_logger::SimpleLogger::new().init().unwrap(); @@ -361,4 +360,3 @@ const ED25519_SIGNATURE: [u8; 64] = hex_literal::hex!( "e9988fcc188fbe85a66929634badb47c5b765c3c6087a7e44b41efda1fdcd0baf67ded6159a5af6d396ca59439de8907160fc729a42ed50e69a3f54abe6dad0c" ); -*/ \ No newline at end of file diff --git a/src/session.rs b/src/session.rs index bafff38..b4d3637 100644 --- a/src/session.rs +++ b/src/session.rs @@ -84,7 +84,9 @@ pub(crate) struct Session<'a> { need_send_client_cert: bool, client_cert_verify_sig_alg: Option, // Flag for the need of sending alert to terminate TLS session - need_send_alert: Option + need_send_alert: Option, + // Flag for the need of sending certificate request to client from server + need_cert_req: bool } impl<'a> Session<'a> { @@ -125,7 +127,8 @@ impl<'a> Session<'a> { cert_private_key: certificate_with_key, need_send_client_cert: false, client_cert_verify_sig_alg: None, - need_send_alert: None + need_send_alert: None, + need_cert_req: false } } @@ -142,9 +145,10 @@ impl<'a> Session<'a> { self.remote_endpoint = remote_endpoint; } - pub(crate) fn listen(&mut self) { + pub(crate) fn listen(&mut self, require_cert_req: bool) { self.role = TlsRole::Server; self.state = TlsState::SERVER_START; + self.need_cert_req = require_cert_req; } pub(crate) fn invalidate_session( @@ -664,6 +668,13 @@ impl<'a> Session<'a> { self.hash.update(encryption_extension_slice); } + pub(crate) fn server_update_for_sent_certificate_request( + &mut self, + cert_request_slice: &[u8] + ) { + self.hash.update(cert_request_slice); + } + pub(crate) fn server_update_for_sent_certificate( &mut self, certificate_slice: &[u8] @@ -689,6 +700,189 @@ impl<'a> Session<'a> { self.find_application_keying_info(); // Change state + // It depends on the need to perform client authenication + if self.need_cert_req { + self.state = TlsState::SERVER_WAIT_CERT; + } else { + self.state = TlsState::SERVER_WAIT_FINISHED; + } + } + + pub(crate) fn server_update_for_wait_cert_cr( + &mut self, + cert_slice: &[u8], + cert_public_key: CertificatePublicKey + ) { + self.hash.update(cert_slice); + self.cert_public_key.replace(cert_public_key); + self.state = TlsState::SERVER_WAIT_CV; + } + + pub(crate) fn server_update_for_wait_cv( + &mut self, + cert_verify_slice: &[u8], + signature_algorithm: SignatureScheme, + signature: &[u8] + ) + { + // Clone the transcript hash from ClientHello all the way to Certificate + let transcript_hash: Vec = if let Ok(sha256) = self.hash.get_sha256_clone() { + Vec::from_slice(&sha256.finalize()).unwrap() + } else if let Ok(sha384) = self.hash.get_sha384_clone() { + Vec::from_slice(&sha384.finalize()).unwrap() + } else { + unreachable!() + }; + + // Handle Ed25519 and p256 separately + // These 2 algorithms have a mandated hash function + if signature_algorithm == SignatureScheme::ecdsa_secp256r1_sha256 { + let verify_hash = Sha256::new() + .chain(&[0x20; 64]) + .chain("TLS 1.3, client CertificateVerify") + .chain(&[0]) + .chain(&transcript_hash); + let ecdsa_signature = p256::ecdsa::Signature::from_asn1(signature).unwrap(); + self.cert_public_key + .take() + .unwrap() + .get_ecdsa_secp256r1_sha256_verify_key() + .unwrap() + .verify_digest( + verify_hash, &ecdsa_signature + ).unwrap(); + + // Usual procedures: update hash + self.hash.update(cert_verify_slice); + // At last, update client state + self.state = TlsState::SERVER_WAIT_FINISHED; + return; + } + + // ED25519 only accepts PureEdDSA implementation + if signature_algorithm == SignatureScheme::ed25519 { + // 64 bytes of 0x20 + // 33 bytes of text + // 1 byte of 0 + // potentially 48 bytes of transcript hash + // 146 bytes in total + let mut verify_message: Vec = Vec::new(); + verify_message.extend_from_slice(&[0x20; 64]).unwrap(); + verify_message.extend_from_slice(b"TLS 1.3, client CertificateVerify").unwrap(); + verify_message.extend_from_slice(&[0]).unwrap(); + verify_message.extend_from_slice(&transcript_hash).unwrap(); + let ed25519_signature = ed25519_dalek::Signature::try_from( + signature + ).unwrap(); + self.cert_public_key.take() + .unwrap() + .get_ed25519_public_key() + .unwrap() + .verify_strict(&verify_message, &ed25519_signature) + .unwrap(); + + // Usual procedures: update hash + self.hash.update(cert_verify_slice); + // At last, update client state + self.state = TlsState::SERVER_WAIT_FINISHED; + return; + } + + // Get verification hash, and verify the signature + use crate::tls_packet::SignatureScheme::*; + + let get_rsa_padding_scheme = |sig_alg: SignatureScheme| -> PaddingScheme { + match sig_alg { + rsa_pkcs1_sha256 => { + PaddingScheme::new_pkcs1v15_sign(Some(RSAHash::SHA2_256)) + }, + rsa_pkcs1_sha384 => { + PaddingScheme::new_pkcs1v15_sign(Some(RSAHash::SHA2_384)) + }, + rsa_pkcs1_sha512 => { + PaddingScheme::new_pkcs1v15_sign(Some(RSAHash::SHA2_512)) + }, + rsa_pss_rsae_sha256 | rsa_pss_pss_sha256 => { + PaddingScheme::new_pss::(FakeRandom{}) + }, + rsa_pss_rsae_sha384 | rsa_pss_pss_sha384 => { + PaddingScheme::new_pss::(FakeRandom{}) + }, + rsa_pss_rsae_sha512 | rsa_pss_pss_sha512 => { + PaddingScheme::new_pss::(FakeRandom{}) + }, + _ => unreachable!() + } + }; + + match signature_algorithm { + rsa_pkcs1_sha256 | rsa_pss_rsae_sha256 | rsa_pss_pss_sha256 => { + let verify_hash = Sha256::new() + .chain(&[0x20; 64]) + .chain("TLS 1.3, client CertificateVerify") + .chain(&[0]) + .chain(&transcript_hash) + .finalize(); + let padding = get_rsa_padding_scheme(signature_algorithm); + let verify_result = self.cert_public_key + .take() + .unwrap() + .get_rsa_public_key() + .unwrap() + .verify( + padding, &verify_hash, signature + ); + if verify_result.is_err() { + todo!() + } + }, + rsa_pkcs1_sha384 | rsa_pss_rsae_sha384 | rsa_pss_pss_sha384 => { + let verify_hash = Sha384::new() + .chain(&[0x20; 64]) + .chain("TLS 1.3, client CertificateVerify") + .chain(&[0]) + .chain(&transcript_hash) + .finalize(); + let padding = get_rsa_padding_scheme(signature_algorithm); + let verify_result = self.cert_public_key + .take() + .unwrap() + .get_rsa_public_key() + .unwrap() + .verify( + padding, &verify_hash, signature + ); + if verify_result.is_err() { + todo!() + } + }, + rsa_pkcs1_sha512 | rsa_pss_rsae_sha512 | rsa_pss_pss_sha512 => { + let verify_hash = Sha512::new() + .chain(&[0x20; 64]) + .chain("TLS 1.3, client CertificateVerify") + .chain(&[0]) + .chain(&transcript_hash) + .finalize(); + let padding = get_rsa_padding_scheme(signature_algorithm); + let verify_result = self.cert_public_key + .take() + .unwrap() + .get_rsa_public_key() + .unwrap() + .verify( + padding, &verify_hash, signature + ); + if verify_result.is_err() { + todo!() + } + }, + _ => unreachable!() + }; + + // Usual procedures: update hash + self.hash.update(cert_verify_slice); + + // At last, update client state self.state = TlsState::SERVER_WAIT_FINISHED; } @@ -1426,6 +1620,10 @@ impl<'a> Session<'a> { self.need_send_client_cert } + pub(crate) fn need_to_send_cert_request(&self) -> bool { + self.need_cert_req + } + pub(crate) fn get_private_certificate_slices(&self) -> Option<&alloc::vec::Vec<&[u8]>> { if let Some((_, cert_vec)) = &self.cert_private_key { Some(cert_vec) diff --git a/src/tls.rs b/src/tls.rs index ae0de93..94e9f96 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -127,6 +127,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { pub fn listen( &mut self, + issue_client_verification: bool, local_endpoint: T ) -> Result<()> where @@ -138,7 +139,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { // Update tls session to server_start let mut session = self.session.borrow_mut(); - session.listen(); + session.listen(issue_client_verification); Ok(()) } @@ -661,7 +662,72 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { log::info!("sent encrypted extension"); - // TODO: Option to allow a certificate request + // Send certificate request to client, on user's discretion + if self.session.borrow().need_to_send_cert_request() { + let mut inner_plaintext: Vec = Vec::new(); + inner_plaintext.extend_from_slice(&[ + 13, // Certificate request + 0, 0, 0, // Dummy length + 0, // No certificate request context + 0, 0, // Dummy extensions length + 0, 13, // Signature Algorithm extension type + 0, 0, // Dummy extension length + 0, 0 // Dummy signature scheme list length + ]); + let supported_sig_algs = [ + SignatureScheme::ecdsa_secp256r1_sha256, + SignatureScheme::ed25519, + SignatureScheme::rsa_pss_pss_sha256, + SignatureScheme::rsa_pkcs1_sha256, + SignatureScheme::rsa_pss_rsae_sha256, + SignatureScheme::rsa_pss_pss_sha384, + SignatureScheme::rsa_pkcs1_sha384, + SignatureScheme::rsa_pss_rsae_sha384, + SignatureScheme::rsa_pss_pss_sha512, + SignatureScheme::rsa_pkcs1_sha512, + SignatureScheme::rsa_pss_rsae_sha512 + ]; + + NetworkEndian::write_u16( + &mut inner_plaintext[11..13], + u16::try_from(supported_sig_algs.len() * 2).unwrap() + ); + + NetworkEndian::write_u16( + &mut inner_plaintext[9..11], + u16::try_from(supported_sig_algs.len() * 2 + 2).unwrap() + ); + + NetworkEndian::write_u16( + &mut inner_plaintext[5..7], + u16::try_from(supported_sig_algs.len() * 2 + 6).unwrap() + ); + + NetworkEndian::write_u24( + &mut inner_plaintext[1..4], + u32::try_from(supported_sig_algs.len() * 2 + 9).unwrap() + ); + + for sig_alg in supported_sig_algs.iter() { + inner_plaintext.extend_from_slice( + &u16::try_from(*sig_alg).unwrap().to_be_bytes() + ); + } + + // Push content type: Handshake + inner_plaintext.push(22); + + self.send_application_slice(&mut inner_plaintext.clone())?; + let inner_plaintext_length = inner_plaintext.len(); + { + let mut session = self.session.borrow_mut(); + session.server_update_for_sent_certificate_request( + &inner_plaintext[..(inner_plaintext_length-1)] + ); + } + + log::info!("sent certificate request"); + } // Construct and send server certificate handshake content let inner_plaintext = { @@ -1693,6 +1759,86 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> { } }, + // Receive client certificate as server + TlsState::SERVER_WAIT_CERT => { + // Verify that it is indeed an Certificate + let might_be_cert = repr.handshake.take().unwrap(); + + if might_be_cert.get_msg_type() == HandshakeType::Certificate { + // Process certificates + + // let all_certificates = might_be_cert.get_all_asn1_der_certificates().unwrap(); + // log::info!("Number of certificates: {:?}", all_certificates.len()); + // log::info!("All certificates: {:?}", all_certificates); + + // TODO: Process all certificates + // TODO: Conditionally allow client to send empty certificate + let cert = might_be_cert.get_asn1_der_certificate().unwrap(); + + // TODO: Replace this block after implementing a proper + // certificate verification procdeure + cert.validate_self_signed_signature().expect("Signature mismatched"); + + // Update session TLS state to WAIT_CV + // Length of handshake header is 4 + let (_handshake_slice, cert_slice) = + take::<_, _, (&[u8], ErrorKind)>( + might_be_cert.length + 4 + )(handshake_slice) + .map_err(|_| Error::Unrecognized)?; + + self.session.borrow_mut() + .server_update_for_wait_cert_cr( + &cert_slice, + cert.get_cert_public_key().unwrap() + ); + log::info!("Received WAIT_CERT"); + } else { + // Unexpected handshakes + // Throw alert + self.session.borrow_mut().invalidate_session( + AlertType::UnexpectedMessage, + handshake_slice + ); + return Ok(()); + } + }, + + // Receive client certificate verify as server + // Verify the signature to the hash + TlsState::SERVER_WAIT_CV => { + // Ensure that it is CertificateVerify + let might_be_cert_verify = repr.handshake.take().unwrap(); + if might_be_cert_verify.get_msg_type() != HandshakeType::CertificateVerify { + // Throw alert to terminate handshake if it is not certificate verify + self.session.borrow_mut().invalidate_session( + AlertType::UnexpectedMessage, + handshake_slice + ); + return Ok(()); + } + + // Take out the portion for CertificateVerify + // Length of handshake header is 4 + let (_handshake_slice, cert_verify_slice) = + take::<_, _, (&[u8], ErrorKind)>( + might_be_cert_verify.length + 4 + )(handshake_slice) + .map_err(|_| Error::Unrecognized)?; + + // Perform verification, update TLS state if successful + let (sig_alg, signature) = might_be_cert_verify.get_signature().unwrap(); + { + self.session.borrow_mut() + .server_update_for_wait_cv( + cert_verify_slice, + sig_alg, + signature + ); + } + log::info!("Received CV"); + }, + TlsState::SERVER_WAIT_FINISHED => { // Ensure that it is Finished let might_be_client_finished = repr.handshake.take().unwrap(); diff --git a/src/tls_packet.rs b/src/tls_packet.rs index cfbfb71..12787f7 100644 --- a/src/tls_packet.rs +++ b/src/tls_packet.rs @@ -346,6 +346,7 @@ impl<'a> HandshakeData<'a> { match self { HandshakeData::ClientHello(data) => data.get_length(), HandshakeData::ServerHello(data) => data.get_length(), + HandshakeData::CertificateRequest(cr) => cr.get_length(), _ => 0, } } @@ -525,7 +526,6 @@ impl ClientHello { pub(crate) fn finalise(mut self) -> Self { let mut sum = 0; for extension in self.extensions.iter() { - // TODO: Add up the extension length sum += extension.get_length(); } self.extension_length = sum.try_into().unwrap(); @@ -995,3 +995,10 @@ pub(crate) struct CertificateRequest<'a> { pub(crate) extensions_length: u16, pub(crate) extensions: Vec, } + +impl<'a> CertificateRequest<'a> { + fn get_length(&self) -> usize { + usize::try_from(self.certificate_request_context_length).unwrap() + + usize::try_from(self.extensions_length).unwrap() + 3 + } +}