diff --git a/Cargo.toml b/Cargo.toml index b364cff..bd1085b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,8 @@ chacha20poly1305 = "0.6.0" sha2 = { version = "0.9.1", default-features = false } byteorder = { version = "1.3.4", default-features = false } num_enum = { version = "0.5.1", default-features = false } -log = {version = "0.4.11"} +log = "0.4.11" +generic-array = "0.14.4" [dependencies.smoltcp] version = "0.6.0" diff --git a/src/parse.rs b/src/parse.rs index e6c8788..5fe603f 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -32,21 +32,23 @@ pub(crate) fn parse_tls_repr(bytes: &[u8]) -> IResult<&[u8], TlsRepr> { payload: None, handshake: None, }; + let (rest, bytes) = take(repr.length)(rest)?; { use crate::tls_packet::TlsContentType::*; match repr.content_type { Handshake => { - let (rest, handshake) = parse_handshake(rest)?; + let (rest, handshake) = complete( + parse_handshake + )(bytes)?; repr.handshake = Some(handshake); - Ok((rest, repr)) }, - _ => { - let (rest, payload) = take(repr.length)(rest)?; - repr.payload = Some(payload); - Ok((rest, repr)) - } + ChangeCipherSpec | ApplicationData => { + repr.payload = Some(bytes); + }, + _ => todo!() } } + Ok((rest, repr)) } fn parse_handshake(bytes: &[u8]) -> IResult<&[u8], HandshakeRepr> { @@ -65,7 +67,7 @@ fn parse_handshake(bytes: &[u8]) -> IResult<&[u8], HandshakeRepr> { use crate::tls_packet::HandshakeType::*; match repr.msg_type { ServerHello => { - let (rest, data) = parse_server_hello(bytes)?; + let (rest, data) = parse_server_hello(rest)?; repr.handshake_data = data; Ok((rest, repr)) }, @@ -81,7 +83,7 @@ fn parse_server_hello(bytes: &[u8]) -> IResult<&[u8], HandshakeData> { let (rest, (version, random, session_id_echo_length)) = tuple((version, random, session_id_echo_length))(bytes)?; - + let session_id_echo_length = session_id_echo_length[0]; let (rest, session_id_echo) = take(session_id_echo_length)(rest)?; @@ -105,8 +107,8 @@ fn parse_server_hello(bytes: &[u8]) -> IResult<&[u8], HandshakeData> { let mut extension_vec: Vec = Vec::new(); let mut extension_length: i32 = server_hello.extension_length.into(); - while extension_length >= 0 { - let (rem, extension) = parse_extension(rest)?; + while extension_length > 0 { + let (rem, extension) = parse_extension(rest, HandshakeType::ServerHello)?; rest = rem; extension_length -= i32::try_from(extension.get_length()).unwrap(); @@ -114,27 +116,92 @@ fn parse_server_hello(bytes: &[u8]) -> IResult<&[u8], HandshakeData> { if extension_length < 0 { todo!() } + + extension_vec.push(extension); } server_hello.extensions = extension_vec; Ok((rest, HandshakeData::ServerHello(server_hello))) } -fn parse_extension(bytes: &[u8]) -> IResult<&[u8], Extension> { +fn parse_extension(bytes: &[u8], handshake_type: HandshakeType) -> IResult<&[u8], Extension> { let extension_type = take(2_usize); let length = take(2_usize); let (rest, (extension_type, length)) = tuple((extension_type, length))(bytes)?; + let extension_type = ExtensionType::try_from( + NetworkEndian::read_u16(extension_type) + ).unwrap(); let length = NetworkEndian::read_u16(length); - - let (rest, extension_data) = take(length)(rest)?; + + // Process extension data according to extension_type + // TODO: Deal with HelloRetryRequest + let (rest, extension_data) = { + use ExtensionType::*; + match extension_type { + SupportedVersions => { + match handshake_type { + HandshakeType::ClientHello => { + todo!() + }, + HandshakeType::ServerHello => { + let (rest, selected_version) = take(2_usize)(rest)?; + let selected_version = TlsVersion::try_from( + NetworkEndian::read_u16(selected_version) + ).unwrap(); + ( + rest, + ExtensionData::SupportedVersions( + crate::tls_packet::SupportedVersions::ServerHello { + selected_version + } + ) + ) + }, + _ => todo!() + } + }, + KeyShare => { + match handshake_type { + HandshakeType::ClientHello => { + todo!() + }, + HandshakeType::ServerHello => { + let group = take(2_usize); + let length = take(2_usize); + let (rest, (group, length)) = + tuple((group, length))(rest)?; + let group = NamedGroup::try_from( + NetworkEndian::read_u16(group) + ).unwrap(); + let length = NetworkEndian::read_u16(length); + let (rest, key_exchange_slice) = take(length)(rest)?; + let mut key_exchange = Vec::new(); + key_exchange.extend_from_slice(key_exchange_slice); + + let server_share = KeyShareEntry { + group, + length, + key_exchange, + }; + let key_share_sh = crate::tls_packet::KeyShareEntryContent::KeyShareServerHello { + server_share + }; + (rest, ExtensionData::KeyShareEntry(key_share_sh)) + }, + _ => todo!() + } + }, + _ => todo!() + } + }; Ok(( rest, Extension { - extension_type: ExtensionType::try_from(NetworkEndian::read_u16(extension_type)).unwrap(), + extension_type, length, extension_data } diff --git a/src/tls.rs b/src/tls.rs index 1a2c56c..47f385c 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -15,12 +15,13 @@ use smoltcp::time::Instant; use smoltcp::phy::Device; use byteorder::{ByteOrder, NetworkEndian, BigEndian}; +use generic_array::GenericArray; use core::convert::TryInto; use core::convert::TryFrom; use rand_core::{RngCore, CryptoRng}; -use p256::{EncodedPoint, AffinePoint, ecdh::EphemeralSecret}; +use p256::{EncodedPoint, AffinePoint, ecdh::EphemeralSecret, ecdh::SharedSecret}; use alloc::vec::{ self, Vec }; @@ -45,6 +46,10 @@ pub struct TlsSocket state: TlsState, tcp_handle: SocketHandle, rng: R, + secret: Option, // Used enum Option to allow later init + session_id: Option<[u8; 32]>, // init session specific field later + cipher_suite: Option, + ecdhe_shared: Option, } impl TlsSocket { @@ -63,6 +68,10 @@ impl TlsSocket { state: TlsState::START, tcp_handle, rng, + secret: None, + session_id: None, + cipher_suite: None, + ecdhe_shared: None, } } @@ -102,146 +111,170 @@ impl TlsSocket { } } - if self.state == TlsState::START { -// // Create TLS representation, length and payload not finalised -// let mut random: [u8; 32] = [0; 32]; -// self.rng.fill_bytes(&mut random); -// let mut session_id: [u8; 32] = [0; 32]; -// self.rng.fill_bytes(&mut session_id); -// -// let cipher_suites_length = 6; -// let cipher_suites = [ -// CipherSuite::TLS_AES_128_GCM_SHA256, -// CipherSuite::TLS_AES_256_GCM_SHA384, -// CipherSuite::TLS_CHACHA20_POLY1305_SHA256, -// ]; -// -// // Length: to be determined -// let supported_versions_extension = Extension { -// extension_type: ExtensionType::SupportedVersions, -// length: 5, -// extension_data: &[ -// 4, // Number of supported versions * 2 -// // Need 2 bytes to contain a version -// 0x03, 0x04, // 0x0304: TLS Version 1.3 -// 0x03, 0x03, // 0x0303: TLS version 1.2 -// ] -// }; -// -// let signature_algorithms_extension = Extension { -// extension_type: ExtensionType::SignatureAlgorithms, -// length: 24, -// extension_data: &[ -// 0x00, 22, // Length in bytes -// 0x04, 0x03, // ecdsa_secp256r1_sha256 -// 0x08, 0x07, // ed25519 -// 0x08, 0x09, // rsa_pss_pss_sha256 -// 0x04, 0x01, // rsa_pkcs1_sha256 -// 0x08, 0x04, // rsa_pss_rsae_sha256 -// 0x08, 0x0a, // rsa_pss_pss_sha384 -// 0x05, 0x01, // rsa_pkcs1_sha384 -// 0x08, 0x05, // rsa_pss_rsae_sha384 -// 0x08, 0x0b, // rsa_pss_pss_sha512 -// 0x06, 0x01, // rsa_pkcs1_sha512 -// 0x08, 0x06, // rsa_pss_rsae_sha512 -// ] -// }; -// -// let supported_groups_extension = Extension { -// extension_type: ExtensionType::SupportedGroups, -// length: 4, -// extension_data: &[ -// 0x00, 0x02, // Length in bytes -// 0x00, 0x17, // secp256r1 -// ] -// }; -// -// let key_share_extension = Extension { -// extension_type: ExtensionType::KeyShare, -// length: 71, -// extension_data: &{ -// let ecdh_secret = unsafe { EphemeralSecret::random(&mut self.rng) }; -// let ecdh_public = EncodedPoint::from(&ecdh_secret); -// let x_coor = ecdh_public.x(); -// let y_coor = ecdh_public.y().unwrap(); -// let mut data: [u8; 71] = [0; 71]; -// data[0..2].copy_from_slice(&[0x00, 69]); // Length in bytes -// data[2..4].copy_from_slice(&[0x00, 0x17]); // secp256r1 -// data[4..6].copy_from_slice(&[0x00, 65]); // key exchange length -// data[6..7].copy_from_slice(&[0x04]); // Fixed legacy value -// data[7..39].copy_from_slice(&x_coor); -// data[39..71].copy_from_slice(&y_coor); -// data -// } -// }; -// -// let psk_key_exchange_modes_extension = Extension { -// extension_type: ExtensionType::PSKKeyExchangeModes, -// length: 2, -// extension_data: &[ -// 0x01, // Length in bytes -// 0x01, // psk_dhe_ke -// ] -// }; -// -// let mut client_hello = ClientHello { -// version: TlsVersion::Tls12, -// random, -// session_id_length: 32, -// session_id, -// cipher_suites_length, -// cipher_suites: &cipher_suites, -// compression_method_length: 1, -// compression_methods: 0, -// extension_length: supported_versions_extension.get_length().try_into().unwrap(), -// extensions: vec![ -// supported_versions_extension, -// signature_algorithms_extension, -// supported_groups_extension, -// psk_key_exchange_modes_extension, -// key_share_extension -// ] -// }; -// -// client_hello.extension_length = { -// let mut sum = 0; -// for ext in client_hello.extensions.iter() { -// sum += ext.get_length(); -// } -// sum.try_into().unwrap() -// }; -// -// let handshake_repr = HandshakeRepr { -// msg_type: HandshakeType::ClientHello, -// length: client_hello.get_length(), -// handshake_data: HandshakeData::ClientHello(client_hello), -// }; -// -// let repr = TlsRepr { -// content_type: TlsContentType::Handshake, -// version: TlsVersion::Tls10, -// length: handshake_repr.get_length(), -// payload: None, -// handshake: Some(handshake_repr), -// }; - let repr = TlsRepr::new() - .client_hello(&mut self.rng); + // Handle TLS handshake through TLS states + match self.state { + // Initiate TLS handshake + TlsState::START => { + // Prepare field that is randomised, + // Supply it to the TLS repr builder. + let ecdh_secret = EphemeralSecret::random(&mut self.rng); + let mut random: [u8; 32] = [0; 32]; + let mut session_id: [u8; 32] = [0; 32]; + self.rng.fill_bytes(&mut random); + self.rng.fill_bytes(&mut session_id); + let repr = TlsRepr::new() + .client_hello(&ecdh_secret, random, session_id); + self.send_tls_repr(sockets, repr)?; - log::info!("{:?}", repr); + // Store session settings, i.e. secret, session_id + self.secret = Some(ecdh_secret); + self.session_id = Some(session_id); - self.send_tls_repr(sockets, repr)?; - self.state = TlsState::WAIT_SH; - Ok(true) - } else if self.state == TlsState::WAIT_SH { - Ok(true) - } else { - Ok(true) + // Update the TLS state + self.state = TlsState::WAIT_SH; + }, + // TLS Client wait for Server Hello + // No need to send anything + TlsState::WAIT_SH => {}, + // TLS Client wait for certificate from TLS server + // No need to send anything + // Note: TLS server should normall send SH alongside EE + // TLS client should jump from WAIT_SH directly to WAIT_CERT_CR directly. + TlsState::WAIT_EE => {}, + _ => todo!() } + + // Poll the network interface + iface.poll(sockets, now); + + let mut array = [0; 2048]; + let tls_repr_vec = self.recv_tls_repr(sockets, &mut array)?; + + match self.state { + // During WAIT_SH for a TLS client, client should wait for ServerHello + TlsState::WAIT_SH => { + + // "Cached" value. + // Loop forbids mutating the socket itself due to using a self-referenced vector + let mut cipher_suite: Option = None; + let mut ecdhe_shared: Option = None; + let mut state: TlsState = self.state; + + // TLS Packets MUST be received in the same Ethernet frame in such order: + // 1. Server Hello + // 2. Change Cipher Spec + // 3. Encrypted Extensions + for (index, repr) in tls_repr_vec.iter().enumerate() { + // Legacy_protocol must be TLS 1.2 + if repr.version != TlsVersion::Tls12 { + // Abort communication + todo!() + } + + // TODO: Validate SH + if repr.is_server_hello() { + // Check SH content: + // random: Cannot represent HelloRequestRetry + // (TODO: Support other key shares, e.g. X25519) + // session_id_echo: should be same as the one sent by client + // cipher_suite: Store + // (TODO: Check if such suite was offered) + // compression_method: Must be null, not supported in TLS 1.3 + // + // Check extensions: + // supported_version: Must be TLS 1.3 + // key_share: Store key, must be in secp256r1 + // (TODO: Support other key shares ^) + let handshake_data = &repr.handshake.as_ref().unwrap().handshake_data; + if let HandshakeData::ServerHello(server_hello) = handshake_data { + // Check random: Cannot be SHA-256 of "HelloRetryRequest" + if server_hello.random == HRR_RANDOM { + // Abort communication + todo!() + } + // Check session_id_echo + // The socket should have a session_id after moving from START state + if self.session_id.unwrap() != server_hello.session_id_echo { + // Abort communication + todo!() + } + // Store the cipher suite + cipher_suite = Some(server_hello.cipher_suite); + if server_hello.compression_method != 0 { + // Abort communciation + todo!() + } + for extension in server_hello.extensions.iter() { + if extension.extension_type == ExtensionType::SupportedVersions { + if let ExtensionData::SupportedVersions( + SupportedVersions::ServerHello { + selected_version + } + ) = extension.extension_data { + if selected_version != TlsVersion::Tls13 { + // Abort for choosing not offered TLS version + todo!() + } + } else { + // Abort for illegal extension + todo!() + } + } + + if extension.extension_type == ExtensionType::KeyShare { + if let ExtensionData::KeyShareEntry( + KeyShareEntryContent::KeyShareServerHello { + server_share + } + ) = &extension.extension_data { + // TODO: Use legitimate checking to ensure the chosen + // group is indeed acceptable, when allowing more (EC)DHE + // key sharing + if server_share.group != NamedGroup::secp256r1 { + // Abort for wrong key sharing + todo!() + } + // Store key + // It is surely from secp256r1 + // Convert untagged bytes into encoded point on p256 eliptic curve + // Slice the first byte out of the bytes + let server_public = EncodedPoint::from_untagged_bytes( + GenericArray::from_slice(&server_share.key_exchange[1..]) + ); + // TODO: Handle improper shared key + ecdhe_shared = Some( + self.secret.as_ref().unwrap() + .diffie_hellman(&server_public) + .expect("Unsupported key") + ); + } + } + } + state = TlsState::WAIT_EE; + + } else { + // Handle invalid TLS packet + todo!() + } + + } + } + self.cipher_suite = cipher_suite; + self.ecdhe_shared = ecdhe_shared; + self.state = state; + } + _ => {}, + } + + Ok(self.state == TlsState::CONNECTED) } // Generic inner send method, through TCP socket - fn send_tls_repr(&mut self, sockets: &mut SocketSet, tls_repr: TlsRepr) -> Result<()> { + fn send_tls_repr(&self, sockets: &mut SocketSet, tls_repr: TlsRepr) -> Result<()> { let mut tcp_socket = sockets.get::(self.tcp_handle); + if !tcp_socket.can_send() { + return Err(Error::Illegal); + } let mut array = [0; 2048]; let mut buffer = TlsBuffer::new(&mut array); buffer.enqueue_tls_repr(tls_repr)?; @@ -257,12 +290,17 @@ impl TlsSocket { } // Generic inner recv method, through TCP socket - fn recv_tls_repr<'a>(&'a mut self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result> { + // A TCP packet can contain multiple TLS segments + fn recv_tls_repr<'a>(&'a self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result> { let mut tcp_socket = sockets.get::(self.tcp_handle); - tcp_socket.recv_slice(byte_array)?; + if !tcp_socket.can_recv() { + return Ok((Vec::new())); + } + + let array_size = tcp_socket.recv_slice(byte_array)?; let mut vec: Vec = Vec::new(); - let mut bytes: &[u8] = byte_array; + let mut bytes: &[u8] = &byte_array[..array_size]; loop { match parse_tls_repr(bytes) { Ok((rest, repr)) => { @@ -396,7 +434,6 @@ impl<'a> TlsBuffer<'a> { for extension in extensions { self.write_u16(extension.extension_type.into())?; self.write_u16(extension.length)?; -// self.write(extension.extension_data)?; self.enqueue_extension_data(extension.extension_data)?; } Ok(()) @@ -409,7 +446,7 @@ impl<'a> TlsBuffer<'a> { use crate::tls_packet::SupportedVersions::*; match s { ClientHello { length, versions } => { - self.write_u16(length)?; + self.write_u8(length)?; for version in versions.iter() { self.write_u16((*version).into())?; } @@ -432,10 +469,10 @@ impl<'a> TlsBuffer<'a> { } }, KeyShareEntry(k) => { - let key_share_entry_into = |entry: crate::tls_packet::KeyShareEntry| { - self.write_u16(entry.group.into())?; - self.write_u16(entry.length)?; - self.write(entry.key_exchange.as_slice()) + let mut key_share_entry_into = |buffer: &mut TlsBuffer, entry: crate::tls_packet::KeyShareEntry| { + buffer.write_u16(entry.group.into())?; + buffer.write_u16(entry.length)?; + buffer.write(entry.key_exchange.as_slice()) }; use crate::tls_packet::KeyShareEntryContent::*; @@ -443,14 +480,14 @@ impl<'a> TlsBuffer<'a> { KeyShareClientHello { length, client_shares } => { self.write_u16(length)?; for share in client_shares.iter() { - key_share_entry_into(*share)?; + self.enqueue_key_share_entry(share)?; } } KeyShareHelloRetryRequest { selected_group } => { self.write_u16(selected_group.into())?; } KeyShareServerHello { server_share } => { - key_share_entry_into(server_share)?; + self.enqueue_key_share_entry(&server_share)?; } } }, @@ -460,6 +497,12 @@ impl<'a> TlsBuffer<'a> { }; Ok(()) } + + fn enqueue_key_share_entry(&mut self, entry: &crate::tls_packet::KeyShareEntry) -> Result<()> { + self.write_u16(entry.group.into())?; + self.write_u16(entry.length)?; + self.write(entry.key_exchange.as_slice()) + } } macro_rules! export_byte_order_fn { diff --git a/src/tls_packet.rs b/src/tls_packet.rs index 0d367e7..82a90eb 100644 --- a/src/tls_packet.rs +++ b/src/tls_packet.rs @@ -12,6 +12,13 @@ use core::convert::TryInto; use alloc::vec::Vec; +pub(crate) const HRR_RANDOM: [u8; 32] = [ + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, + 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, + 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C +]; + #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] #[repr(u8)] pub(crate) enum TlsContentType { @@ -54,21 +61,35 @@ impl<'a> TlsRepr<'a> { } } - pub(crate) fn client_hello(mut self, rng: &mut T) -> Self - where - T: RngCore + CryptoRng - { + pub(crate) fn client_hello(mut self, secret: &EphemeralSecret, random: [u8; 32], session_id: [u8; 32]) -> Self { self.content_type = TlsContentType::Handshake; self.version = TlsVersion::Tls10; - self.handshake = Some({ + let handshake_repr = { let mut repr = HandshakeRepr::new(); + repr.msg_type = HandshakeType::ClientHello; repr.handshake_data = HandshakeData::ClientHello({ - ClientHello::new(rng) + ClientHello::new(secret, random, session_id) }); + repr.length = repr.handshake_data.get_length().try_into().unwrap(); repr - }); + }; + self.length = handshake_repr.get_length(); + self.handshake = Some(handshake_repr); self } + + pub(crate) fn is_server_hello(&self) -> bool { + self.content_type == TlsContentType::Handshake && + self.payload.is_none() && + self.handshake.is_some() && + { + if let Some(repr) = &self.handshake { + repr.msg_type == HandshakeType::ServerHello + } else { + false + } + } + } } #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] @@ -146,7 +167,7 @@ pub(crate) enum HandshakeData<'a> { } impl<'a> HandshakeData<'a> { - pub(crate) fn get_length(&self) -> u32 { + pub(crate) fn get_length(&self) -> usize { match self { HandshakeData::ClientHello(data) => data.get_length(), HandshakeData::ServerHello(data) => todo!(), @@ -156,16 +177,13 @@ impl<'a> HandshakeData<'a> { } impl<'a> ClientHello<'a> { - pub(self) fn new(rng: &mut T) -> Self - where - T: RngCore + CryptoRng - { + pub(self) fn new(secret: &EphemeralSecret, random: [u8; 32], session_id: [u8; 32]) -> Self { let mut client_hello = ClientHello { version: TlsVersion::Tls12, - random: [0; 32], + random, session_id_length: 32, - session_id: [0; 32], - cipher_suites_length: 0, + session_id, + cipher_suites_length: 6, cipher_suites: &[ CipherSuite::TLS_AES_128_GCM_SHA256, CipherSuite::TLS_AES_256_GCM_SHA384, @@ -177,14 +195,10 @@ impl<'a> ClientHello<'a> { extensions: Vec::new(), }; - rng.fill_bytes(&mut client_hello.random); - rng.fill_bytes(&mut client_hello.session_id); - - client_hello.add_ch_supported_versions(); - client_hello.add_sig_algs(); - client_hello.add_client_groups_with_key_shares(&mut rng); - - client_hello + client_hello.add_ch_supported_versions() + .add_sig_algs() + .add_client_groups_with_key_shares(secret) + .finalise() } pub(crate) fn add_ch_supported_versions(mut self) -> Self { @@ -263,10 +277,7 @@ impl<'a> ClientHello<'a> { self } - pub(crate) fn add_client_groups_with_key_shares(mut self, rng: &mut T) -> Self - where - T: RngCore + CryptoRng - { + pub(crate) fn add_client_groups_with_key_shares(mut self, ecdh_secret: &EphemeralSecret) -> Self { // List out all supported groups let mut list = Vec::new(); list.push(NamedGroup::secp256r1); @@ -280,9 +291,7 @@ impl<'a> ClientHello<'a> { let mut key_exchange = Vec::new(); let key_share_entry = match named_group { NamedGroup::secp256r1 => { - let ecdh_secret = EphemeralSecret::random(&mut rng); - let ecdh_public = EncodedPoint::from(&ecdh_secret); - + let ecdh_public = EncodedPoint::from(ecdh_secret); let x_coor = ecdh_public.x(); let y_coor = ecdh_public.y().unwrap(); @@ -319,6 +328,7 @@ impl<'a> ClientHello<'a> { extension_data, }; + let length = list.len()*2; let group_list = NamedGroupList { length: length.try_into().unwrap(), named_group_list: list, @@ -336,19 +346,27 @@ impl<'a> ClientHello<'a> { self } - pub(crate) fn get_length(&self) -> u32 { - let mut length :u32 = 2; // TlsVersion size + pub(crate) fn finalise(mut self) -> Self { + let mut sum = 0; + for extension in self.extensions.iter() { + // TODO: Add up the extension length + sum += extension.get_length(); + } + self.extension_length = sum.try_into().unwrap(); + self + } + + pub(crate) fn get_length(&self) -> usize { + let mut length: usize = 2; // TlsVersion size length += 32; // Random size length += 1; // Legacy session_id length size length += 32; // Legacy session_id size length += 2; // Cipher_suites_length size - length += (self.cipher_suites.len() as u32) * 2; + length += self.cipher_suites.len() * 2; length += 1; length += 1; length += 2; - for extension in self.extensions.iter() { - length += (extension.get_length() as u32); - } + length += usize::try_from(self.extension_length).unwrap(); length } } @@ -439,7 +457,7 @@ impl ExtensionData { #[derive(Debug, Clone)] pub(crate) enum SupportedVersions { ClientHello { - length: u16, + length: u8, versions: Vec, }, ServerHello { @@ -451,7 +469,7 @@ impl SupportedVersions { pub(crate) fn get_length(&self) -> usize { match self { Self::ClientHello { length, versions } => { - usize::try_from(*length).unwrap() + 2 + usize::try_from(*length).unwrap() + 1 } Self::ServerHello { selected_version } => 2 }