tls: add request client auth option for listen

This commit is contained in:
occheung 2020-12-08 11:09:41 +08:00
parent 0c6807f593
commit b6d8428cb1
4 changed files with 377 additions and 28 deletions

View File

@ -57,30 +57,29 @@ impl TlsRng for CountingRng {}
static mut RNG: CountingRng = CountingRng(0); static mut RNG: CountingRng = CountingRng(0);
fn main() { fn main() {
// let mut socket_set_entries: [_; 8] = Default::default(); let mut socket_set_entries: [_; 8] = Default::default();
// let mut sockets = SocketSet::new(&mut socket_set_entries[..]); let mut sockets = SocketSet::new(&mut socket_set_entries[..]);
// let mut tx_storage = [0; 4096]; let mut tx_storage = [0; 4096];
// let mut rx_storage = [0; 4096]; let mut rx_storage = [0; 4096];
// let mut tls_socket = unsafe { let mut tls_socket = unsafe {
// let tx_buffer = TcpSocketBuffer::new(&mut tx_storage[..]); let tx_buffer = TcpSocketBuffer::new(&mut tx_storage[..]);
// let rx_buffer = TcpSocketBuffer::new(&mut rx_storage[..]); let rx_buffer = TcpSocketBuffer::new(&mut rx_storage[..]);
// let tcp_socket = smoltcp::socket::TcpSocket::new(rx_buffer, tx_buffer); let tcp_socket = smoltcp::socket::TcpSocket::new(rx_buffer, tx_buffer);
// TlsSocket::new( TlsSocket::new(
// tcp_socket, tcp_socket,
// &mut RNG, &mut RNG,
// None None
// ) )
// }; };
tls_socket.connect(
// &mut sockets,
(Ipv4Address::new(192, 168, 1, 125), 1883),
49600
).unwrap();
// tls_socket.connect(
// // &mut sockets,
// (Ipv4Address::new(192, 168, 1, 125), 1883),
// 49600
// ).unwrap();
}
/*
// tls_socket.tls_connect(&mut sockets).unwrap(); // tls_socket.tls_connect(&mut sockets).unwrap();
simple_logger::SimpleLogger::new().init().unwrap(); simple_logger::SimpleLogger::new().init().unwrap();
@ -361,4 +360,3 @@ const ED25519_SIGNATURE: [u8; 64] =
hex_literal::hex!( hex_literal::hex!(
"e9988fcc188fbe85a66929634badb47c5b765c3c6087a7e44b41efda1fdcd0baf67ded6159a5af6d396ca59439de8907160fc729a42ed50e69a3f54abe6dad0c" "e9988fcc188fbe85a66929634badb47c5b765c3c6087a7e44b41efda1fdcd0baf67ded6159a5af6d396ca59439de8907160fc729a42ed50e69a3f54abe6dad0c"
); );
*/

View File

@ -84,7 +84,9 @@ pub(crate) struct Session<'a> {
need_send_client_cert: bool, need_send_client_cert: bool,
client_cert_verify_sig_alg: Option<crate::tls_packet::SignatureScheme>, client_cert_verify_sig_alg: Option<crate::tls_packet::SignatureScheme>,
// Flag for the need of sending alert to terminate TLS session // Flag for the need of sending alert to terminate TLS session
need_send_alert: Option<AlertType> need_send_alert: Option<AlertType>,
// Flag for the need of sending certificate request to client from server
need_cert_req: bool
} }
impl<'a> Session<'a> { impl<'a> Session<'a> {
@ -125,7 +127,8 @@ impl<'a> Session<'a> {
cert_private_key: certificate_with_key, cert_private_key: certificate_with_key,
need_send_client_cert: false, need_send_client_cert: false,
client_cert_verify_sig_alg: None, client_cert_verify_sig_alg: None,
need_send_alert: None need_send_alert: None,
need_cert_req: false
} }
} }
@ -142,9 +145,10 @@ impl<'a> Session<'a> {
self.remote_endpoint = remote_endpoint; self.remote_endpoint = remote_endpoint;
} }
pub(crate) fn listen(&mut self) { pub(crate) fn listen(&mut self, require_cert_req: bool) {
self.role = TlsRole::Server; self.role = TlsRole::Server;
self.state = TlsState::SERVER_START; self.state = TlsState::SERVER_START;
self.need_cert_req = require_cert_req;
} }
pub(crate) fn invalidate_session( pub(crate) fn invalidate_session(
@ -664,6 +668,13 @@ impl<'a> Session<'a> {
self.hash.update(encryption_extension_slice); self.hash.update(encryption_extension_slice);
} }
pub(crate) fn server_update_for_sent_certificate_request(
&mut self,
cert_request_slice: &[u8]
) {
self.hash.update(cert_request_slice);
}
pub(crate) fn server_update_for_sent_certificate( pub(crate) fn server_update_for_sent_certificate(
&mut self, &mut self,
certificate_slice: &[u8] certificate_slice: &[u8]
@ -689,6 +700,189 @@ impl<'a> Session<'a> {
self.find_application_keying_info(); self.find_application_keying_info();
// Change state // Change state
// It depends on the need to perform client authenication
if self.need_cert_req {
self.state = TlsState::SERVER_WAIT_CERT;
} else {
self.state = TlsState::SERVER_WAIT_FINISHED;
}
}
pub(crate) fn server_update_for_wait_cert_cr(
&mut self,
cert_slice: &[u8],
cert_public_key: CertificatePublicKey
) {
self.hash.update(cert_slice);
self.cert_public_key.replace(cert_public_key);
self.state = TlsState::SERVER_WAIT_CV;
}
pub(crate) fn server_update_for_wait_cv(
&mut self,
cert_verify_slice: &[u8],
signature_algorithm: SignatureScheme,
signature: &[u8]
)
{
// Clone the transcript hash from ClientHello all the way to Certificate
let transcript_hash: Vec<u8, U64> = if let Ok(sha256) = self.hash.get_sha256_clone() {
Vec::from_slice(&sha256.finalize()).unwrap()
} else if let Ok(sha384) = self.hash.get_sha384_clone() {
Vec::from_slice(&sha384.finalize()).unwrap()
} else {
unreachable!()
};
// Handle Ed25519 and p256 separately
// These 2 algorithms have a mandated hash function
if signature_algorithm == SignatureScheme::ecdsa_secp256r1_sha256 {
let verify_hash = Sha256::new()
.chain(&[0x20; 64])
.chain("TLS 1.3, client CertificateVerify")
.chain(&[0])
.chain(&transcript_hash);
let ecdsa_signature = p256::ecdsa::Signature::from_asn1(signature).unwrap();
self.cert_public_key
.take()
.unwrap()
.get_ecdsa_secp256r1_sha256_verify_key()
.unwrap()
.verify_digest(
verify_hash, &ecdsa_signature
).unwrap();
// Usual procedures: update hash
self.hash.update(cert_verify_slice);
// At last, update client state
self.state = TlsState::SERVER_WAIT_FINISHED;
return;
}
// ED25519 only accepts PureEdDSA implementation
if signature_algorithm == SignatureScheme::ed25519 {
// 64 bytes of 0x20
// 33 bytes of text
// 1 byte of 0
// potentially 48 bytes of transcript hash
// 146 bytes in total
let mut verify_message: Vec<u8, U146> = Vec::new();
verify_message.extend_from_slice(&[0x20; 64]).unwrap();
verify_message.extend_from_slice(b"TLS 1.3, client CertificateVerify").unwrap();
verify_message.extend_from_slice(&[0]).unwrap();
verify_message.extend_from_slice(&transcript_hash).unwrap();
let ed25519_signature = ed25519_dalek::Signature::try_from(
signature
).unwrap();
self.cert_public_key.take()
.unwrap()
.get_ed25519_public_key()
.unwrap()
.verify_strict(&verify_message, &ed25519_signature)
.unwrap();
// Usual procedures: update hash
self.hash.update(cert_verify_slice);
// At last, update client state
self.state = TlsState::SERVER_WAIT_FINISHED;
return;
}
// Get verification hash, and verify the signature
use crate::tls_packet::SignatureScheme::*;
let get_rsa_padding_scheme = |sig_alg: SignatureScheme| -> PaddingScheme {
match sig_alg {
rsa_pkcs1_sha256 => {
PaddingScheme::new_pkcs1v15_sign(Some(RSAHash::SHA2_256))
},
rsa_pkcs1_sha384 => {
PaddingScheme::new_pkcs1v15_sign(Some(RSAHash::SHA2_384))
},
rsa_pkcs1_sha512 => {
PaddingScheme::new_pkcs1v15_sign(Some(RSAHash::SHA2_512))
},
rsa_pss_rsae_sha256 | rsa_pss_pss_sha256 => {
PaddingScheme::new_pss::<Sha256, FakeRandom>(FakeRandom{})
},
rsa_pss_rsae_sha384 | rsa_pss_pss_sha384 => {
PaddingScheme::new_pss::<Sha384, FakeRandom>(FakeRandom{})
},
rsa_pss_rsae_sha512 | rsa_pss_pss_sha512 => {
PaddingScheme::new_pss::<Sha512, FakeRandom>(FakeRandom{})
},
_ => unreachable!()
}
};
match signature_algorithm {
rsa_pkcs1_sha256 | rsa_pss_rsae_sha256 | rsa_pss_pss_sha256 => {
let verify_hash = Sha256::new()
.chain(&[0x20; 64])
.chain("TLS 1.3, client CertificateVerify")
.chain(&[0])
.chain(&transcript_hash)
.finalize();
let padding = get_rsa_padding_scheme(signature_algorithm);
let verify_result = self.cert_public_key
.take()
.unwrap()
.get_rsa_public_key()
.unwrap()
.verify(
padding, &verify_hash, signature
);
if verify_result.is_err() {
todo!()
}
},
rsa_pkcs1_sha384 | rsa_pss_rsae_sha384 | rsa_pss_pss_sha384 => {
let verify_hash = Sha384::new()
.chain(&[0x20; 64])
.chain("TLS 1.3, client CertificateVerify")
.chain(&[0])
.chain(&transcript_hash)
.finalize();
let padding = get_rsa_padding_scheme(signature_algorithm);
let verify_result = self.cert_public_key
.take()
.unwrap()
.get_rsa_public_key()
.unwrap()
.verify(
padding, &verify_hash, signature
);
if verify_result.is_err() {
todo!()
}
},
rsa_pkcs1_sha512 | rsa_pss_rsae_sha512 | rsa_pss_pss_sha512 => {
let verify_hash = Sha512::new()
.chain(&[0x20; 64])
.chain("TLS 1.3, client CertificateVerify")
.chain(&[0])
.chain(&transcript_hash)
.finalize();
let padding = get_rsa_padding_scheme(signature_algorithm);
let verify_result = self.cert_public_key
.take()
.unwrap()
.get_rsa_public_key()
.unwrap()
.verify(
padding, &verify_hash, signature
);
if verify_result.is_err() {
todo!()
}
},
_ => unreachable!()
};
// Usual procedures: update hash
self.hash.update(cert_verify_slice);
// At last, update client state
self.state = TlsState::SERVER_WAIT_FINISHED; self.state = TlsState::SERVER_WAIT_FINISHED;
} }
@ -1426,6 +1620,10 @@ impl<'a> Session<'a> {
self.need_send_client_cert self.need_send_client_cert
} }
pub(crate) fn need_to_send_cert_request(&self) -> bool {
self.need_cert_req
}
pub(crate) fn get_private_certificate_slices(&self) -> Option<&alloc::vec::Vec<&[u8]>> { pub(crate) fn get_private_certificate_slices(&self) -> Option<&alloc::vec::Vec<&[u8]>> {
if let Some((_, cert_vec)) = &self.cert_private_key { if let Some((_, cert_vec)) = &self.cert_private_key {
Some(cert_vec) Some(cert_vec)

View File

@ -127,6 +127,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
pub fn listen<T>( pub fn listen<T>(
&mut self, &mut self,
issue_client_verification: bool,
local_endpoint: T local_endpoint: T
) -> Result<()> ) -> Result<()>
where where
@ -138,7 +139,7 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
// Update tls session to server_start // Update tls session to server_start
let mut session = self.session.borrow_mut(); let mut session = self.session.borrow_mut();
session.listen(); session.listen(issue_client_verification);
Ok(()) Ok(())
} }
@ -661,7 +662,72 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
log::info!("sent encrypted extension"); log::info!("sent encrypted extension");
// TODO: Option to allow a certificate request // Send certificate request to client, on user's discretion
if self.session.borrow().need_to_send_cert_request() {
let mut inner_plaintext: Vec<u8> = Vec::new();
inner_plaintext.extend_from_slice(&[
13, // Certificate request
0, 0, 0, // Dummy length
0, // No certificate request context
0, 0, // Dummy extensions length
0, 13, // Signature Algorithm extension type
0, 0, // Dummy extension length
0, 0 // Dummy signature scheme list length
]);
let supported_sig_algs = [
SignatureScheme::ecdsa_secp256r1_sha256,
SignatureScheme::ed25519,
SignatureScheme::rsa_pss_pss_sha256,
SignatureScheme::rsa_pkcs1_sha256,
SignatureScheme::rsa_pss_rsae_sha256,
SignatureScheme::rsa_pss_pss_sha384,
SignatureScheme::rsa_pkcs1_sha384,
SignatureScheme::rsa_pss_rsae_sha384,
SignatureScheme::rsa_pss_pss_sha512,
SignatureScheme::rsa_pkcs1_sha512,
SignatureScheme::rsa_pss_rsae_sha512
];
NetworkEndian::write_u16(
&mut inner_plaintext[11..13],
u16::try_from(supported_sig_algs.len() * 2).unwrap()
);
NetworkEndian::write_u16(
&mut inner_plaintext[9..11],
u16::try_from(supported_sig_algs.len() * 2 + 2).unwrap()
);
NetworkEndian::write_u16(
&mut inner_plaintext[5..7],
u16::try_from(supported_sig_algs.len() * 2 + 6).unwrap()
);
NetworkEndian::write_u24(
&mut inner_plaintext[1..4],
u32::try_from(supported_sig_algs.len() * 2 + 9).unwrap()
);
for sig_alg in supported_sig_algs.iter() {
inner_plaintext.extend_from_slice(
&u16::try_from(*sig_alg).unwrap().to_be_bytes()
);
}
// Push content type: Handshake
inner_plaintext.push(22);
self.send_application_slice(&mut inner_plaintext.clone())?;
let inner_plaintext_length = inner_plaintext.len();
{
let mut session = self.session.borrow_mut();
session.server_update_for_sent_certificate_request(
&inner_plaintext[..(inner_plaintext_length-1)]
);
}
log::info!("sent certificate request");
}
// Construct and send server certificate handshake content // Construct and send server certificate handshake content
let inner_plaintext = { let inner_plaintext = {
@ -1693,6 +1759,86 @@ impl<'a, 'b, 'c> TlsSocket<'a, 'b, 'c> {
} }
}, },
// Receive client certificate as server
TlsState::SERVER_WAIT_CERT => {
// Verify that it is indeed an Certificate
let might_be_cert = repr.handshake.take().unwrap();
if might_be_cert.get_msg_type() == HandshakeType::Certificate {
// Process certificates
// let all_certificates = might_be_cert.get_all_asn1_der_certificates().unwrap();
// log::info!("Number of certificates: {:?}", all_certificates.len());
// log::info!("All certificates: {:?}", all_certificates);
// TODO: Process all certificates
// TODO: Conditionally allow client to send empty certificate
let cert = might_be_cert.get_asn1_der_certificate().unwrap();
// TODO: Replace this block after implementing a proper
// certificate verification procdeure
cert.validate_self_signed_signature().expect("Signature mismatched");
// Update session TLS state to WAIT_CV
// Length of handshake header is 4
let (_handshake_slice, cert_slice) =
take::<_, _, (&[u8], ErrorKind)>(
might_be_cert.length + 4
)(handshake_slice)
.map_err(|_| Error::Unrecognized)?;
self.session.borrow_mut()
.server_update_for_wait_cert_cr(
&cert_slice,
cert.get_cert_public_key().unwrap()
);
log::info!("Received WAIT_CERT");
} else {
// Unexpected handshakes
// Throw alert
self.session.borrow_mut().invalidate_session(
AlertType::UnexpectedMessage,
handshake_slice
);
return Ok(());
}
},
// Receive client certificate verify as server
// Verify the signature to the hash
TlsState::SERVER_WAIT_CV => {
// Ensure that it is CertificateVerify
let might_be_cert_verify = repr.handshake.take().unwrap();
if might_be_cert_verify.get_msg_type() != HandshakeType::CertificateVerify {
// Throw alert to terminate handshake if it is not certificate verify
self.session.borrow_mut().invalidate_session(
AlertType::UnexpectedMessage,
handshake_slice
);
return Ok(());
}
// Take out the portion for CertificateVerify
// Length of handshake header is 4
let (_handshake_slice, cert_verify_slice) =
take::<_, _, (&[u8], ErrorKind)>(
might_be_cert_verify.length + 4
)(handshake_slice)
.map_err(|_| Error::Unrecognized)?;
// Perform verification, update TLS state if successful
let (sig_alg, signature) = might_be_cert_verify.get_signature().unwrap();
{
self.session.borrow_mut()
.server_update_for_wait_cv(
cert_verify_slice,
sig_alg,
signature
);
}
log::info!("Received CV");
},
TlsState::SERVER_WAIT_FINISHED => { TlsState::SERVER_WAIT_FINISHED => {
// Ensure that it is Finished // Ensure that it is Finished
let might_be_client_finished = repr.handshake.take().unwrap(); let might_be_client_finished = repr.handshake.take().unwrap();

View File

@ -346,6 +346,7 @@ impl<'a> HandshakeData<'a> {
match self { match self {
HandshakeData::ClientHello(data) => data.get_length(), HandshakeData::ClientHello(data) => data.get_length(),
HandshakeData::ServerHello(data) => data.get_length(), HandshakeData::ServerHello(data) => data.get_length(),
HandshakeData::CertificateRequest(cr) => cr.get_length(),
_ => 0, _ => 0,
} }
} }
@ -525,7 +526,6 @@ impl ClientHello {
pub(crate) fn finalise(mut self) -> Self { pub(crate) fn finalise(mut self) -> Self {
let mut sum = 0; let mut sum = 0;
for extension in self.extensions.iter() { for extension in self.extensions.iter() {
// TODO: Add up the extension length
sum += extension.get_length(); sum += extension.get_length();
} }
self.extension_length = sum.try_into().unwrap(); self.extension_length = sum.try_into().unwrap();
@ -995,3 +995,10 @@ pub(crate) struct CertificateRequest<'a> {
pub(crate) extensions_length: u16, pub(crate) extensions_length: u16,
pub(crate) extensions: Vec<Extension>, pub(crate) extensions: Vec<Extension>,
} }
impl<'a> CertificateRequest<'a> {
fn get_length(&self) -> usize {
usize::try_from(self.certificate_request_context_length).unwrap() +
usize::try_from(self.extensions_length).unwrap() + 3
}
}