Stage machine: init

This commit is contained in:
occheung 2020-10-14 17:37:45 +08:00
parent 337edf7411
commit a03a511756
4 changed files with 328 additions and 199 deletions

View File

@ -10,7 +10,8 @@ chacha20poly1305 = "0.6.0"
sha2 = { version = "0.9.1", default-features = false } sha2 = { version = "0.9.1", default-features = false }
byteorder = { version = "1.3.4", default-features = false } byteorder = { version = "1.3.4", default-features = false }
num_enum = { version = "0.5.1", default-features = false } num_enum = { version = "0.5.1", default-features = false }
log = {version = "0.4.11"} log = "0.4.11"
generic-array = "0.14.4"
[dependencies.smoltcp] [dependencies.smoltcp]
version = "0.6.0" version = "0.6.0"

View File

@ -32,21 +32,23 @@ pub(crate) fn parse_tls_repr(bytes: &[u8]) -> IResult<&[u8], TlsRepr> {
payload: None, payload: None,
handshake: None, handshake: None,
}; };
let (rest, bytes) = take(repr.length)(rest)?;
{ {
use crate::tls_packet::TlsContentType::*; use crate::tls_packet::TlsContentType::*;
match repr.content_type { match repr.content_type {
Handshake => { Handshake => {
let (rest, handshake) = parse_handshake(rest)?; let (rest, handshake) = complete(
parse_handshake
)(bytes)?;
repr.handshake = Some(handshake); repr.handshake = Some(handshake);
Ok((rest, repr))
}, },
_ => { ChangeCipherSpec | ApplicationData => {
let (rest, payload) = take(repr.length)(rest)?; repr.payload = Some(bytes);
repr.payload = Some(payload); },
_ => todo!()
}
}
Ok((rest, repr)) Ok((rest, repr))
}
}
}
} }
fn parse_handshake(bytes: &[u8]) -> IResult<&[u8], HandshakeRepr> { fn parse_handshake(bytes: &[u8]) -> IResult<&[u8], HandshakeRepr> {
@ -65,7 +67,7 @@ fn parse_handshake(bytes: &[u8]) -> IResult<&[u8], HandshakeRepr> {
use crate::tls_packet::HandshakeType::*; use crate::tls_packet::HandshakeType::*;
match repr.msg_type { match repr.msg_type {
ServerHello => { ServerHello => {
let (rest, data) = parse_server_hello(bytes)?; let (rest, data) = parse_server_hello(rest)?;
repr.handshake_data = data; repr.handshake_data = data;
Ok((rest, repr)) Ok((rest, repr))
}, },
@ -105,8 +107,8 @@ fn parse_server_hello(bytes: &[u8]) -> IResult<&[u8], HandshakeData> {
let mut extension_vec: Vec<Extension> = Vec::new(); let mut extension_vec: Vec<Extension> = Vec::new();
let mut extension_length: i32 = server_hello.extension_length.into(); let mut extension_length: i32 = server_hello.extension_length.into();
while extension_length >= 0 { while extension_length > 0 {
let (rem, extension) = parse_extension(rest)?; let (rem, extension) = parse_extension(rest, HandshakeType::ServerHello)?;
rest = rem; rest = rem;
extension_length -= i32::try_from(extension.get_length()).unwrap(); extension_length -= i32::try_from(extension.get_length()).unwrap();
@ -114,27 +116,92 @@ fn parse_server_hello(bytes: &[u8]) -> IResult<&[u8], HandshakeData> {
if extension_length < 0 { if extension_length < 0 {
todo!() todo!()
} }
extension_vec.push(extension);
} }
server_hello.extensions = extension_vec; server_hello.extensions = extension_vec;
Ok((rest, HandshakeData::ServerHello(server_hello))) Ok((rest, HandshakeData::ServerHello(server_hello)))
} }
fn parse_extension(bytes: &[u8]) -> IResult<&[u8], Extension> { fn parse_extension(bytes: &[u8], handshake_type: HandshakeType) -> IResult<&[u8], Extension> {
let extension_type = take(2_usize); let extension_type = take(2_usize);
let length = take(2_usize); let length = take(2_usize);
let (rest, (extension_type, length)) = let (rest, (extension_type, length)) =
tuple((extension_type, length))(bytes)?; tuple((extension_type, length))(bytes)?;
let extension_type = ExtensionType::try_from(
NetworkEndian::read_u16(extension_type)
).unwrap();
let length = NetworkEndian::read_u16(length); let length = NetworkEndian::read_u16(length);
let (rest, extension_data) = take(length)(rest)?; // Process extension data according to extension_type
// TODO: Deal with HelloRetryRequest
let (rest, extension_data) = {
use ExtensionType::*;
match extension_type {
SupportedVersions => {
match handshake_type {
HandshakeType::ClientHello => {
todo!()
},
HandshakeType::ServerHello => {
let (rest, selected_version) = take(2_usize)(rest)?;
let selected_version = TlsVersion::try_from(
NetworkEndian::read_u16(selected_version)
).unwrap();
(
rest,
ExtensionData::SupportedVersions(
crate::tls_packet::SupportedVersions::ServerHello {
selected_version
}
)
)
},
_ => todo!()
}
},
KeyShare => {
match handshake_type {
HandshakeType::ClientHello => {
todo!()
},
HandshakeType::ServerHello => {
let group = take(2_usize);
let length = take(2_usize);
let (rest, (group, length)) =
tuple((group, length))(rest)?;
let group = NamedGroup::try_from(
NetworkEndian::read_u16(group)
).unwrap();
let length = NetworkEndian::read_u16(length);
let (rest, key_exchange_slice) = take(length)(rest)?;
let mut key_exchange = Vec::new();
key_exchange.extend_from_slice(key_exchange_slice);
let server_share = KeyShareEntry {
group,
length,
key_exchange,
};
let key_share_sh = crate::tls_packet::KeyShareEntryContent::KeyShareServerHello {
server_share
};
(rest, ExtensionData::KeyShareEntry(key_share_sh))
},
_ => todo!()
}
},
_ => todo!()
}
};
Ok(( Ok((
rest, rest,
Extension { Extension {
extension_type: ExtensionType::try_from(NetworkEndian::read_u16(extension_type)).unwrap(), extension_type,
length, length,
extension_data extension_data
} }

View File

@ -15,12 +15,13 @@ use smoltcp::time::Instant;
use smoltcp::phy::Device; use smoltcp::phy::Device;
use byteorder::{ByteOrder, NetworkEndian, BigEndian}; use byteorder::{ByteOrder, NetworkEndian, BigEndian};
use generic_array::GenericArray;
use core::convert::TryInto; use core::convert::TryInto;
use core::convert::TryFrom; use core::convert::TryFrom;
use rand_core::{RngCore, CryptoRng}; use rand_core::{RngCore, CryptoRng};
use p256::{EncodedPoint, AffinePoint, ecdh::EphemeralSecret}; use p256::{EncodedPoint, AffinePoint, ecdh::EphemeralSecret, ecdh::SharedSecret};
use alloc::vec::{ self, Vec }; use alloc::vec::{ self, Vec };
@ -45,6 +46,10 @@ pub struct TlsSocket<R: 'static + RngCore + CryptoRng>
state: TlsState, state: TlsState,
tcp_handle: SocketHandle, tcp_handle: SocketHandle,
rng: R, rng: R,
secret: Option<EphemeralSecret>, // Used enum Option to allow later init
session_id: Option<[u8; 32]>, // init session specific field later
cipher_suite: Option<CipherSuite>,
ecdhe_shared: Option<SharedSecret>,
} }
impl<R: RngCore + CryptoRng> TlsSocket<R> { impl<R: RngCore + CryptoRng> TlsSocket<R> {
@ -63,6 +68,10 @@ impl<R: RngCore + CryptoRng> TlsSocket<R> {
state: TlsState::START, state: TlsState::START,
tcp_handle, tcp_handle,
rng, rng,
secret: None,
session_id: None,
cipher_suite: None,
ecdhe_shared: None,
} }
} }
@ -102,146 +111,170 @@ impl<R: RngCore + CryptoRng> TlsSocket<R> {
} }
} }
if self.state == TlsState::START { // Handle TLS handshake through TLS states
// // Create TLS representation, length and payload not finalised match self.state {
// let mut random: [u8; 32] = [0; 32]; // Initiate TLS handshake
// self.rng.fill_bytes(&mut random); TlsState::START => {
// let mut session_id: [u8; 32] = [0; 32]; // Prepare field that is randomised,
// self.rng.fill_bytes(&mut session_id); // Supply it to the TLS repr builder.
// let ecdh_secret = EphemeralSecret::random(&mut self.rng);
// let cipher_suites_length = 6; let mut random: [u8; 32] = [0; 32];
// let cipher_suites = [ let mut session_id: [u8; 32] = [0; 32];
// CipherSuite::TLS_AES_128_GCM_SHA256, self.rng.fill_bytes(&mut random);
// CipherSuite::TLS_AES_256_GCM_SHA384, self.rng.fill_bytes(&mut session_id);
// CipherSuite::TLS_CHACHA20_POLY1305_SHA256,
// ];
//
// // Length: to be determined
// let supported_versions_extension = Extension {
// extension_type: ExtensionType::SupportedVersions,
// length: 5,
// extension_data: &[
// 4, // Number of supported versions * 2
// // Need 2 bytes to contain a version
// 0x03, 0x04, // 0x0304: TLS Version 1.3
// 0x03, 0x03, // 0x0303: TLS version 1.2
// ]
// };
//
// let signature_algorithms_extension = Extension {
// extension_type: ExtensionType::SignatureAlgorithms,
// length: 24,
// extension_data: &[
// 0x00, 22, // Length in bytes
// 0x04, 0x03, // ecdsa_secp256r1_sha256
// 0x08, 0x07, // ed25519
// 0x08, 0x09, // rsa_pss_pss_sha256
// 0x04, 0x01, // rsa_pkcs1_sha256
// 0x08, 0x04, // rsa_pss_rsae_sha256
// 0x08, 0x0a, // rsa_pss_pss_sha384
// 0x05, 0x01, // rsa_pkcs1_sha384
// 0x08, 0x05, // rsa_pss_rsae_sha384
// 0x08, 0x0b, // rsa_pss_pss_sha512
// 0x06, 0x01, // rsa_pkcs1_sha512
// 0x08, 0x06, // rsa_pss_rsae_sha512
// ]
// };
//
// let supported_groups_extension = Extension {
// extension_type: ExtensionType::SupportedGroups,
// length: 4,
// extension_data: &[
// 0x00, 0x02, // Length in bytes
// 0x00, 0x17, // secp256r1
// ]
// };
//
// let key_share_extension = Extension {
// extension_type: ExtensionType::KeyShare,
// length: 71,
// extension_data: &{
// let ecdh_secret = unsafe { EphemeralSecret::random(&mut self.rng) };
// let ecdh_public = EncodedPoint::from(&ecdh_secret);
// let x_coor = ecdh_public.x();
// let y_coor = ecdh_public.y().unwrap();
// let mut data: [u8; 71] = [0; 71];
// data[0..2].copy_from_slice(&[0x00, 69]); // Length in bytes
// data[2..4].copy_from_slice(&[0x00, 0x17]); // secp256r1
// data[4..6].copy_from_slice(&[0x00, 65]); // key exchange length
// data[6..7].copy_from_slice(&[0x04]); // Fixed legacy value
// data[7..39].copy_from_slice(&x_coor);
// data[39..71].copy_from_slice(&y_coor);
// data
// }
// };
//
// let psk_key_exchange_modes_extension = Extension {
// extension_type: ExtensionType::PSKKeyExchangeModes,
// length: 2,
// extension_data: &[
// 0x01, // Length in bytes
// 0x01, // psk_dhe_ke
// ]
// };
//
// let mut client_hello = ClientHello {
// version: TlsVersion::Tls12,
// random,
// session_id_length: 32,
// session_id,
// cipher_suites_length,
// cipher_suites: &cipher_suites,
// compression_method_length: 1,
// compression_methods: 0,
// extension_length: supported_versions_extension.get_length().try_into().unwrap(),
// extensions: vec![
// supported_versions_extension,
// signature_algorithms_extension,
// supported_groups_extension,
// psk_key_exchange_modes_extension,
// key_share_extension
// ]
// };
//
// client_hello.extension_length = {
// let mut sum = 0;
// for ext in client_hello.extensions.iter() {
// sum += ext.get_length();
// }
// sum.try_into().unwrap()
// };
//
// let handshake_repr = HandshakeRepr {
// msg_type: HandshakeType::ClientHello,
// length: client_hello.get_length(),
// handshake_data: HandshakeData::ClientHello(client_hello),
// };
//
// let repr = TlsRepr {
// content_type: TlsContentType::Handshake,
// version: TlsVersion::Tls10,
// length: handshake_repr.get_length(),
// payload: None,
// handshake: Some(handshake_repr),
// };
let repr = TlsRepr::new() let repr = TlsRepr::new()
.client_hello(&mut self.rng); .client_hello(&ecdh_secret, random, session_id);
log::info!("{:?}", repr);
self.send_tls_repr(sockets, repr)?; self.send_tls_repr(sockets, repr)?;
// Store session settings, i.e. secret, session_id
self.secret = Some(ecdh_secret);
self.session_id = Some(session_id);
// Update the TLS state
self.state = TlsState::WAIT_SH; self.state = TlsState::WAIT_SH;
Ok(true) },
} else if self.state == TlsState::WAIT_SH { // TLS Client wait for Server Hello
Ok(true) // No need to send anything
} else { TlsState::WAIT_SH => {},
Ok(true) // TLS Client wait for certificate from TLS server
// No need to send anything
// Note: TLS server should normall send SH alongside EE
// TLS client should jump from WAIT_SH directly to WAIT_CERT_CR directly.
TlsState::WAIT_EE => {},
_ => todo!()
} }
// Poll the network interface
iface.poll(sockets, now);
let mut array = [0; 2048];
let tls_repr_vec = self.recv_tls_repr(sockets, &mut array)?;
match self.state {
// During WAIT_SH for a TLS client, client should wait for ServerHello
TlsState::WAIT_SH => {
// "Cached" value.
// Loop forbids mutating the socket itself due to using a self-referenced vector
let mut cipher_suite: Option<CipherSuite> = None;
let mut ecdhe_shared: Option<SharedSecret> = None;
let mut state: TlsState = self.state;
// TLS Packets MUST be received in the same Ethernet frame in such order:
// 1. Server Hello
// 2. Change Cipher Spec
// 3. Encrypted Extensions
for (index, repr) in tls_repr_vec.iter().enumerate() {
// Legacy_protocol must be TLS 1.2
if repr.version != TlsVersion::Tls12 {
// Abort communication
todo!()
}
// TODO: Validate SH
if repr.is_server_hello() {
// Check SH content:
// random: Cannot represent HelloRequestRetry
// (TODO: Support other key shares, e.g. X25519)
// session_id_echo: should be same as the one sent by client
// cipher_suite: Store
// (TODO: Check if such suite was offered)
// compression_method: Must be null, not supported in TLS 1.3
//
// Check extensions:
// supported_version: Must be TLS 1.3
// key_share: Store key, must be in secp256r1
// (TODO: Support other key shares ^)
let handshake_data = &repr.handshake.as_ref().unwrap().handshake_data;
if let HandshakeData::ServerHello(server_hello) = handshake_data {
// Check random: Cannot be SHA-256 of "HelloRetryRequest"
if server_hello.random == HRR_RANDOM {
// Abort communication
todo!()
}
// Check session_id_echo
// The socket should have a session_id after moving from START state
if self.session_id.unwrap() != server_hello.session_id_echo {
// Abort communication
todo!()
}
// Store the cipher suite
cipher_suite = Some(server_hello.cipher_suite);
if server_hello.compression_method != 0 {
// Abort communciation
todo!()
}
for extension in server_hello.extensions.iter() {
if extension.extension_type == ExtensionType::SupportedVersions {
if let ExtensionData::SupportedVersions(
SupportedVersions::ServerHello {
selected_version
}
) = extension.extension_data {
if selected_version != TlsVersion::Tls13 {
// Abort for choosing not offered TLS version
todo!()
}
} else {
// Abort for illegal extension
todo!()
}
}
if extension.extension_type == ExtensionType::KeyShare {
if let ExtensionData::KeyShareEntry(
KeyShareEntryContent::KeyShareServerHello {
server_share
}
) = &extension.extension_data {
// TODO: Use legitimate checking to ensure the chosen
// group is indeed acceptable, when allowing more (EC)DHE
// key sharing
if server_share.group != NamedGroup::secp256r1 {
// Abort for wrong key sharing
todo!()
}
// Store key
// It is surely from secp256r1
// Convert untagged bytes into encoded point on p256 eliptic curve
// Slice the first byte out of the bytes
let server_public = EncodedPoint::from_untagged_bytes(
GenericArray::from_slice(&server_share.key_exchange[1..])
);
// TODO: Handle improper shared key
ecdhe_shared = Some(
self.secret.as_ref().unwrap()
.diffie_hellman(&server_public)
.expect("Unsupported key")
);
}
}
}
state = TlsState::WAIT_EE;
} else {
// Handle invalid TLS packet
todo!()
}
}
}
self.cipher_suite = cipher_suite;
self.ecdhe_shared = ecdhe_shared;
self.state = state;
}
_ => {},
}
Ok(self.state == TlsState::CONNECTED)
} }
// Generic inner send method, through TCP socket // Generic inner send method, through TCP socket
fn send_tls_repr(&mut self, sockets: &mut SocketSet, tls_repr: TlsRepr) -> Result<()> { fn send_tls_repr(&self, sockets: &mut SocketSet, tls_repr: TlsRepr) -> Result<()> {
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle); let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
if !tcp_socket.can_send() {
return Err(Error::Illegal);
}
let mut array = [0; 2048]; let mut array = [0; 2048];
let mut buffer = TlsBuffer::new(&mut array); let mut buffer = TlsBuffer::new(&mut array);
buffer.enqueue_tls_repr(tls_repr)?; buffer.enqueue_tls_repr(tls_repr)?;
@ -257,12 +290,17 @@ impl<R: RngCore + CryptoRng> TlsSocket<R> {
} }
// Generic inner recv method, through TCP socket // Generic inner recv method, through TCP socket
fn recv_tls_repr<'a>(&'a mut self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result<Vec::<TlsRepr>> { // A TCP packet can contain multiple TLS segments
fn recv_tls_repr<'a>(&'a self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result<Vec::<TlsRepr>> {
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle); let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
tcp_socket.recv_slice(byte_array)?; if !tcp_socket.can_recv() {
return Ok((Vec::new()));
}
let array_size = tcp_socket.recv_slice(byte_array)?;
let mut vec: Vec<TlsRepr> = Vec::new(); let mut vec: Vec<TlsRepr> = Vec::new();
let mut bytes: &[u8] = byte_array; let mut bytes: &[u8] = &byte_array[..array_size];
loop { loop {
match parse_tls_repr(bytes) { match parse_tls_repr(bytes) {
Ok((rest, repr)) => { Ok((rest, repr)) => {
@ -396,7 +434,6 @@ impl<'a> TlsBuffer<'a> {
for extension in extensions { for extension in extensions {
self.write_u16(extension.extension_type.into())?; self.write_u16(extension.extension_type.into())?;
self.write_u16(extension.length)?; self.write_u16(extension.length)?;
// self.write(extension.extension_data)?;
self.enqueue_extension_data(extension.extension_data)?; self.enqueue_extension_data(extension.extension_data)?;
} }
Ok(()) Ok(())
@ -409,7 +446,7 @@ impl<'a> TlsBuffer<'a> {
use crate::tls_packet::SupportedVersions::*; use crate::tls_packet::SupportedVersions::*;
match s { match s {
ClientHello { length, versions } => { ClientHello { length, versions } => {
self.write_u16(length)?; self.write_u8(length)?;
for version in versions.iter() { for version in versions.iter() {
self.write_u16((*version).into())?; self.write_u16((*version).into())?;
} }
@ -432,10 +469,10 @@ impl<'a> TlsBuffer<'a> {
} }
}, },
KeyShareEntry(k) => { KeyShareEntry(k) => {
let key_share_entry_into = |entry: crate::tls_packet::KeyShareEntry| { let mut key_share_entry_into = |buffer: &mut TlsBuffer, entry: crate::tls_packet::KeyShareEntry| {
self.write_u16(entry.group.into())?; buffer.write_u16(entry.group.into())?;
self.write_u16(entry.length)?; buffer.write_u16(entry.length)?;
self.write(entry.key_exchange.as_slice()) buffer.write(entry.key_exchange.as_slice())
}; };
use crate::tls_packet::KeyShareEntryContent::*; use crate::tls_packet::KeyShareEntryContent::*;
@ -443,14 +480,14 @@ impl<'a> TlsBuffer<'a> {
KeyShareClientHello { length, client_shares } => { KeyShareClientHello { length, client_shares } => {
self.write_u16(length)?; self.write_u16(length)?;
for share in client_shares.iter() { for share in client_shares.iter() {
key_share_entry_into(*share)?; self.enqueue_key_share_entry(share)?;
} }
} }
KeyShareHelloRetryRequest { selected_group } => { KeyShareHelloRetryRequest { selected_group } => {
self.write_u16(selected_group.into())?; self.write_u16(selected_group.into())?;
} }
KeyShareServerHello { server_share } => { KeyShareServerHello { server_share } => {
key_share_entry_into(server_share)?; self.enqueue_key_share_entry(&server_share)?;
} }
} }
}, },
@ -460,6 +497,12 @@ impl<'a> TlsBuffer<'a> {
}; };
Ok(()) Ok(())
} }
fn enqueue_key_share_entry(&mut self, entry: &crate::tls_packet::KeyShareEntry) -> Result<()> {
self.write_u16(entry.group.into())?;
self.write_u16(entry.length)?;
self.write(entry.key_exchange.as_slice())
}
} }
macro_rules! export_byte_order_fn { macro_rules! export_byte_order_fn {

View File

@ -12,6 +12,13 @@ use core::convert::TryInto;
use alloc::vec::Vec; use alloc::vec::Vec;
pub(crate) const HRR_RANDOM: [u8; 32] = [
0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11,
0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E,
0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C
];
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]
#[repr(u8)] #[repr(u8)]
pub(crate) enum TlsContentType { pub(crate) enum TlsContentType {
@ -54,21 +61,35 @@ impl<'a> TlsRepr<'a> {
} }
} }
pub(crate) fn client_hello<T>(mut self, rng: &mut T) -> Self pub(crate) fn client_hello(mut self, secret: &EphemeralSecret, random: [u8; 32], session_id: [u8; 32]) -> Self {
where
T: RngCore + CryptoRng
{
self.content_type = TlsContentType::Handshake; self.content_type = TlsContentType::Handshake;
self.version = TlsVersion::Tls10; self.version = TlsVersion::Tls10;
self.handshake = Some({ let handshake_repr = {
let mut repr = HandshakeRepr::new(); let mut repr = HandshakeRepr::new();
repr.msg_type = HandshakeType::ClientHello;
repr.handshake_data = HandshakeData::ClientHello({ repr.handshake_data = HandshakeData::ClientHello({
ClientHello::new(rng) ClientHello::new(secret, random, session_id)
}); });
repr.length = repr.handshake_data.get_length().try_into().unwrap();
repr repr
}); };
self.length = handshake_repr.get_length();
self.handshake = Some(handshake_repr);
self self
} }
pub(crate) fn is_server_hello(&self) -> bool {
self.content_type == TlsContentType::Handshake &&
self.payload.is_none() &&
self.handshake.is_some() &&
{
if let Some(repr) = &self.handshake {
repr.msg_type == HandshakeType::ServerHello
} else {
false
}
}
}
} }
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]
@ -146,7 +167,7 @@ pub(crate) enum HandshakeData<'a> {
} }
impl<'a> HandshakeData<'a> { impl<'a> HandshakeData<'a> {
pub(crate) fn get_length(&self) -> u32 { pub(crate) fn get_length(&self) -> usize {
match self { match self {
HandshakeData::ClientHello(data) => data.get_length(), HandshakeData::ClientHello(data) => data.get_length(),
HandshakeData::ServerHello(data) => todo!(), HandshakeData::ServerHello(data) => todo!(),
@ -156,16 +177,13 @@ impl<'a> HandshakeData<'a> {
} }
impl<'a> ClientHello<'a> { impl<'a> ClientHello<'a> {
pub(self) fn new<T>(rng: &mut T) -> Self pub(self) fn new(secret: &EphemeralSecret, random: [u8; 32], session_id: [u8; 32]) -> Self {
where
T: RngCore + CryptoRng
{
let mut client_hello = ClientHello { let mut client_hello = ClientHello {
version: TlsVersion::Tls12, version: TlsVersion::Tls12,
random: [0; 32], random,
session_id_length: 32, session_id_length: 32,
session_id: [0; 32], session_id,
cipher_suites_length: 0, cipher_suites_length: 6,
cipher_suites: &[ cipher_suites: &[
CipherSuite::TLS_AES_128_GCM_SHA256, CipherSuite::TLS_AES_128_GCM_SHA256,
CipherSuite::TLS_AES_256_GCM_SHA384, CipherSuite::TLS_AES_256_GCM_SHA384,
@ -177,14 +195,10 @@ impl<'a> ClientHello<'a> {
extensions: Vec::new(), extensions: Vec::new(),
}; };
rng.fill_bytes(&mut client_hello.random); client_hello.add_ch_supported_versions()
rng.fill_bytes(&mut client_hello.session_id); .add_sig_algs()
.add_client_groups_with_key_shares(secret)
client_hello.add_ch_supported_versions(); .finalise()
client_hello.add_sig_algs();
client_hello.add_client_groups_with_key_shares(&mut rng);
client_hello
} }
pub(crate) fn add_ch_supported_versions(mut self) -> Self { pub(crate) fn add_ch_supported_versions(mut self) -> Self {
@ -263,10 +277,7 @@ impl<'a> ClientHello<'a> {
self self
} }
pub(crate) fn add_client_groups_with_key_shares<T>(mut self, rng: &mut T) -> Self pub(crate) fn add_client_groups_with_key_shares(mut self, ecdh_secret: &EphemeralSecret) -> Self {
where
T: RngCore + CryptoRng
{
// List out all supported groups // List out all supported groups
let mut list = Vec::new(); let mut list = Vec::new();
list.push(NamedGroup::secp256r1); list.push(NamedGroup::secp256r1);
@ -280,9 +291,7 @@ impl<'a> ClientHello<'a> {
let mut key_exchange = Vec::new(); let mut key_exchange = Vec::new();
let key_share_entry = match named_group { let key_share_entry = match named_group {
NamedGroup::secp256r1 => { NamedGroup::secp256r1 => {
let ecdh_secret = EphemeralSecret::random(&mut rng); let ecdh_public = EncodedPoint::from(ecdh_secret);
let ecdh_public = EncodedPoint::from(&ecdh_secret);
let x_coor = ecdh_public.x(); let x_coor = ecdh_public.x();
let y_coor = ecdh_public.y().unwrap(); let y_coor = ecdh_public.y().unwrap();
@ -319,6 +328,7 @@ impl<'a> ClientHello<'a> {
extension_data, extension_data,
}; };
let length = list.len()*2;
let group_list = NamedGroupList { let group_list = NamedGroupList {
length: length.try_into().unwrap(), length: length.try_into().unwrap(),
named_group_list: list, named_group_list: list,
@ -336,19 +346,27 @@ impl<'a> ClientHello<'a> {
self self
} }
pub(crate) fn get_length(&self) -> u32 { pub(crate) fn finalise(mut self) -> Self {
let mut length :u32 = 2; // TlsVersion size let mut sum = 0;
for extension in self.extensions.iter() {
// TODO: Add up the extension length
sum += extension.get_length();
}
self.extension_length = sum.try_into().unwrap();
self
}
pub(crate) fn get_length(&self) -> usize {
let mut length: usize = 2; // TlsVersion size
length += 32; // Random size length += 32; // Random size
length += 1; // Legacy session_id length size length += 1; // Legacy session_id length size
length += 32; // Legacy session_id size length += 32; // Legacy session_id size
length += 2; // Cipher_suites_length size length += 2; // Cipher_suites_length size
length += (self.cipher_suites.len() as u32) * 2; length += self.cipher_suites.len() * 2;
length += 1; length += 1;
length += 1; length += 1;
length += 2; length += 2;
for extension in self.extensions.iter() { length += usize::try_from(self.extension_length).unwrap();
length += (extension.get_length() as u32);
}
length length
} }
} }
@ -439,7 +457,7 @@ impl ExtensionData {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) enum SupportedVersions { pub(crate) enum SupportedVersions {
ClientHello { ClientHello {
length: u16, length: u8,
versions: Vec<TlsVersion>, versions: Vec<TlsVersion>,
}, },
ServerHello { ServerHello {
@ -451,7 +469,7 @@ impl SupportedVersions {
pub(crate) fn get_length(&self) -> usize { pub(crate) fn get_length(&self) -> usize {
match self { match self {
Self::ClientHello { length, versions } => { Self::ClientHello { length, versions } => {
usize::try_from(*length).unwrap() + 2 usize::try_from(*length).unwrap() + 1
} }
Self::ServerHello { selected_version } => 2 Self::ServerHello { selected_version } => 2
} }