diff --git a/Cargo.toml b/Cargo.toml index 67e33b3..847286e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,30 +4,40 @@ version = "0.1.0" authors = ["occheung "] edition = "2018" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] aes-gcm = "0.7.0" chacha20poly1305 = "0.6.0" -byteorder = "1.3.4" -num_enum = "0.5.1" +sha2 = { version = "0.9.1", default-features = false } +byteorder = { version = "1.3.4", default-features = false } +num_enum = { version = "0.5.1", default-features = false } +log = {version = "0.4.11"} [dependencies.smoltcp] version = "0.6.0" default-features = false features = ["proto-ipv4", "proto-ipv6", "socket-tcp"] -[dependencies.rand_chacha] -version = "0.2.2" +[dependencies.rand_core] +version = "0.5.1" default-features = false features = [] -[dependencies.rand] -version = "0.7.3" +[dependencies.p256] +version = "0.5.0" default-features = false -features = ["getrandom"] +features = [ "ecdh", "ecdsa", "arithmetic" ] + +[dependencies.rsa] +git = "https://github.com/RustCrypto/RSA.git" +default-features = false +features = [ "alloc" ] [dependencies.heapless] version = "0.5.6" default-features = false -features = [] \ No newline at end of file +features = [] + +[dependencies.nom] +version = "5.1.2" +default-features = false +features= [ "regex", "lexical" ] diff --git a/src/lib.rs b/src/lib.rs index 41521e7..053aa18 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,10 @@ #![no_std] pub mod tls; -pub mod tls_packet; \ No newline at end of file +pub mod tls_packet; +pub mod parse; + +pub enum Error { + PropagatedError(smoltcp::Error), + ParsingError() +} \ No newline at end of file diff --git a/src/parse.rs b/src/parse.rs new file mode 100644 index 0000000..c8f9aa3 --- /dev/null +++ b/src/parse.rs @@ -0,0 +1,141 @@ +use nom::IResult; +use nom::bytes::complete::take; +use nom::combinator::complete; +use nom::sequence::tuple; +use nom::error::ErrorKind; +use smoltcp::Error; +use smoltcp::Result; + +use byteorder::{ByteOrder, NetworkEndian, BigEndian}; + +use crate::tls_packet::*; +use core::convert::TryFrom; + +use heapless::{ Vec, consts::* }; + +fn parse_tls(bytes: &[u8]) -> IResult<&[u8], TlsRepr> { + let content_type = take(1_usize); + let version = take(2_usize); + let length = take(2_usize); + + let (rest, (content_type, version, length)) = + tuple((content_type, version, length))(bytes)?; + + let mut repr = TlsRepr { + content_type: TlsContentType::try_from(content_type[0]) + .unwrap(), + + version: TlsVersion::try_from(NetworkEndian::read_u16(version)) + .unwrap(), + + length: NetworkEndian::read_u16(length), + payload: None, + handshake: None, + }; + { + use crate::tls_packet::TlsContentType::*; + match repr.content_type { + Handshake => { + let (rest, handshake) = parse_handshake(rest)?; + repr.handshake = Some(handshake); + Ok((rest, repr)) + }, + _ => { + let (rest, payload) = take(repr.length)(rest)?; + repr.payload = Some(payload); + Ok((rest, repr)) + } + } + } +} + +fn parse_handshake(bytes: &[u8]) -> IResult<&[u8], HandshakeRepr> { + let handshake_type = take(1_usize); + let length = take(3_usize); + + let (rest, (handshake_type, length)) = + tuple((handshake_type, length))(bytes)?; + + let mut repr = HandshakeRepr { + msg_type: HandshakeType::try_from(handshake_type[0]).unwrap(), + length: NetworkEndian::read_u24(length), + handshake_data: HandshakeData::Uninitialized, + }; + { + use crate::tls_packet::HandshakeType::*; + match repr.msg_type { + ServerHello => { + todo!() + }, + _ => todo!() + } + } +} + +fn parse_server_hello(bytes: &[u8]) -> IResult<&[u8], HandshakeData> { + let version = take(2_usize); + let random = take(32_usize); + let session_id_echo_length = take(1_usize); + + let (rest, (version, random, session_id_echo_length)) = + tuple((version, random, session_id_echo_length))(bytes)?; + + let session_id_echo_length = session_id_echo_length[0]; + let (rest, session_id_echo) = take(session_id_echo_length)(rest)?; + + let cipher_suite = take(2_usize); + let compression_method = take(1_usize); + let extension_length = take(2_usize); + + let (mut rest, (cipher_suite, compression_method, extension_length)) = + tuple((cipher_suite, compression_method, extension_length))(rest)?; + + let mut extension_length = NetworkEndian::read_u16(extension_length); + + let mut server_hello = ServerHello { + version: TlsVersion::try_from(NetworkEndian::read_u16(version)).unwrap(), + random, + session_id_echo_length: session_id_echo_length, + session_id_echo, + cipher_suite: CipherSuite::try_from(NetworkEndian::read_u16(cipher_suite)).unwrap(), + compression_method: compression_method[0], + extension_length, + extensions: &[] + }; + + let mut extension_vec: Vec = Vec::new(); + while extension_length >= 0 { + let (rem, extension) = parse_extension(rest)?; + rest = rem; + extension_length -= extension.get_length(); + + // Todo:: Proper error + if extension_vec.push(extension).is_err() || extension_length < 0 { + todo!() + } + } + + server_hello.extensions = extension_vec; + Ok((rest, HandshakeData::ServerHello(server_hello))) +} + +fn parse_extension(bytes: &[u8]) -> IResult<&[u8], Extension> { + let extension_type = take(2_usize); + let length = take(2_usize); + + let (rest, (extension_type, length)) = + tuple((extension_type, length))(bytes)?; + + let length = NetworkEndian::read_u16(length); + + let (rest, extension_data) = take(length)(rest)?; + + Ok(( + rest, + Extension { + extension_type: ExtensionType::try_from(NetworkEndian::read_u16(extension_type)).unwrap(), + length, + extension_data + } + )) +} diff --git a/src/tls.rs b/src/tls.rs index 5b197b7..15c0dc6 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -11,10 +11,7 @@ use smoltcp::wire::IpEndpoint; use smoltcp::Result; use smoltcp::Error; -use byteorder::{ByteOrder, NetworkEndian, BigEndian, WriteBytesExt}; - -use rand::prelude::*; -use rand_chacha::ChaCha20Rng; +use byteorder::{ByteOrder, NetworkEndian, BigEndian}; use heapless::Vec; use heapless::consts::*; @@ -22,7 +19,11 @@ use heapless::consts::*; use core::convert::TryInto; use core::convert::TryFrom; +use rand_core::{RngCore, CryptoRng}; +use p256::{EncodedPoint, AffinePoint, ecdh::EphemeralSecret}; + use crate::tls_packet::*; + #[derive(Debug, PartialEq, Eq, Clone, Copy)] #[allow(non_camel_case_types)] enum TlsState { @@ -36,18 +37,19 @@ enum TlsState { CONNECTED, } -pub struct TlsSocket +pub struct TlsSocket { state: TlsState, tcp_handle: SocketHandle, - random: ChaCha20Rng, + rng: R, } -impl TlsSocket { +impl TlsSocket { pub fn new<'a, 'b, 'c>( sockets: &mut SocketSet<'a, 'b, 'c>, rx_buffer: TcpSocketBuffer<'b>, tx_buffer: TcpSocketBuffer<'b>, + rng: R, ) -> Self where 'b: 'c, @@ -57,7 +59,7 @@ impl TlsSocket { TlsSocket { state: TlsState::START, tcp_handle, - random: ChaCha20Rng::from_entropy(), + rng, } } @@ -72,13 +74,18 @@ impl TlsSocket { U: Into, { let mut tcp_socket = sockets.get::(self.tcp_handle); - tcp_socket.connect(remote_endpoint, local_endpoint) + if tcp_socket.state() == TcpState::Established { + Ok(()) + } else { + tcp_socket.connect(remote_endpoint, local_endpoint) + } } pub fn tls_connect(&mut self, sockets: &mut SocketSet) -> Result { // Check tcp_socket connectivity { - let tcp_socket = sockets.get::(self.tcp_handle); + let mut tcp_socket = sockets.get::(self.tcp_handle); + tcp_socket.set_keep_alive(Some(smoltcp::time::Duration::from_millis(1000))); if tcp_socket.state() != TcpState::Established { return Ok(false); } @@ -87,11 +94,11 @@ impl TlsSocket { if self.state == TlsState::START { // Create TLS representation, length and payload not finalised let mut random: [u8; 32] = [0; 32]; + self.rng.fill_bytes(&mut random); let mut session_id: [u8; 32] = [0; 32]; - self.random.fill_bytes(&mut random); - self.random.fill_bytes(&mut session_id); + self.rng.fill_bytes(&mut session_id); - let cipher_suites_length = 3; + let cipher_suites_length = 6; let cipher_suites = [ CipherSuite::TLS_AES_128_GCM_SHA256, CipherSuite::TLS_AES_256_GCM_SHA384, @@ -101,15 +108,72 @@ impl TlsSocket { // Length: to be determined let supported_versions_extension = Extension { extension_type: ExtensionType::SupportedVersions, - length: 3, + length: 5, extension_data: &[ - 2, // Number of supported versions * 2 + 4, // Number of supported versions * 2 // Need 2 bytes to contain a version - 0x03, 0x04 // 0x0303: TLS Version 1.3 + 0x03, 0x04, // 0x0304: TLS Version 1.3 + 0x03, 0x03, // 0x0303: TLS version 1.2 ] }; - let client_hello = ClientHello { + let signature_algorithms_extension = Extension { + extension_type: ExtensionType::SignatureAlgorithms, + length: 24, + extension_data: &[ + 0x00, 22, // Length in bytes + 0x04, 0x03, // ecdsa_secp256r1_sha256 + 0x08, 0x07, // ed25519 + 0x08, 0x09, // rsa_pss_pss_sha256 + 0x04, 0x01, // rsa_pkcs1_sha256 + 0x08, 0x04, // rsa_pss_rsae_sha256 + 0x08, 0x0a, // rsa_pss_pss_sha384 + 0x05, 0x01, // rsa_pkcs1_sha384 + 0x08, 0x05, // rsa_pss_rsae_sha384 + 0x08, 0x0b, // rsa_pss_pss_sha512 + 0x06, 0x01, // rsa_pkcs1_sha512 + 0x08, 0x06, // rsa_pss_rsae_sha512 + ] + }; + + let supported_groups_extension = Extension { + extension_type: ExtensionType::SupportedGroups, + length: 4, + extension_data: &[ + 0x00, 0x02, // Length in bytes + 0x00, 0x17, // secp256r1 + ] + }; + + let key_share_extension = Extension { + extension_type: ExtensionType::KeyShare, + length: 71, + extension_data: &{ + let ecdh_secret = unsafe { EphemeralSecret::random(&mut self.rng) }; + let ecdh_public = EncodedPoint::from(&ecdh_secret); + let x_coor = ecdh_public.x(); + let y_coor = ecdh_public.y().unwrap(); + let mut data: [u8; 71] = [0; 71]; + data[0..2].copy_from_slice(&[0x00, 69]); // Length in bytes + data[2..4].copy_from_slice(&[0x00, 0x17]); // secp256r1 + data[4..6].copy_from_slice(&[0x00, 65]); // key exchange length + data[6..7].copy_from_slice(&[0x04]); // Fixed legacy value + data[7..39].copy_from_slice(&x_coor); + data[39..71].copy_from_slice(&y_coor); + data + } + }; + + let psk_key_exchange_modes_extension = Extension { + extension_type: ExtensionType::PSKKeyExchangeModes, + length: 2, + extension_data: &[ + 0x01, // Length in bytes + 0x01, // psk_dhe_ke + ] + }; + + let mut client_hello = ClientHello { version: TlsVersion::Tls12, random, session_id_length: 32, @@ -119,7 +183,21 @@ impl TlsSocket { compression_method_length: 1, compression_methods: 0, extension_length: supported_versions_extension.get_length(), - extensions: &[supported_versions_extension], + extensions: &[ + supported_versions_extension, + signature_algorithms_extension, + supported_groups_extension, + psk_key_exchange_modes_extension, + key_share_extension, + ], + }; + + client_hello.extension_length = { + let mut sum = 0; + for ext in client_hello.extensions.iter() { + sum += ext.get_length(); + } + sum }; let handshake_repr = HandshakeRepr { @@ -130,12 +208,14 @@ impl TlsSocket { let repr = TlsRepr { content_type: TlsContentType::Handshake, - version: TlsVersion::Tls13, - length: 0, + version: TlsVersion::Tls10, + length: handshake_repr.get_length(), payload: None, handshake: Some(handshake_repr), }; + log::info!("{:?}", repr); + self.send_tls_repr(sockets, repr)?; self.state = TlsState::WAIT_SH; Ok(true) @@ -164,11 +244,10 @@ impl TlsSocket { } // Generic inner recv method, through TCP socket - fn recv_tls_repr<'a>(&'a mut self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result> { + fn recv_tls_repr<'a>(&'a mut self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result> { let mut tcp_socket = sockets.get::(self.tcp_handle); let size = tcp_socket.recv_slice(byte_array)?; - let buffer = TlsBuffer::new(&mut byte_array[..size]); - buffer.dequeue_tls_repr() + todo!() } } @@ -293,121 +372,6 @@ impl<'a> TlsBuffer<'a> { } Ok(()) } - - fn dequeue_tls_repr<'b>(mut self) -> Result> { - // Create a TLS Representation layer - // Modify the representation along the way - let mut repr = TlsRepr { - content_type: TlsContentType::Invalid, - version: TlsVersion::Tls10, - length: 0, - payload: None, - handshake: None, - }; - - repr.content_type = TlsContentType::try_from(self.read_u8()?) - .map_err(|_| Error::Unrecognized)?; - repr.version = TlsVersion::try_from(self.read_u16()?) - .map_err(|_| Error::Unrecognized)?; - repr.length = self.read_u16()?; - - use TlsContentType::*; - match repr.content_type { - Invalid => Err(Error::Unrecognized), - ChangeCipherSpec | Alert => unimplemented!(), - Handshake => todo!(), - ApplicationData => { - repr.payload = Some(self.read_all()); - Ok(repr) - } - } - } - - fn dequeue_handshake<'b>(mut self) -> Result> { - // Create a Handshake header representation - // Fill in proper value afterwards - let mut repr = HandshakeRepr { - msg_type: HandshakeType::ClientHello, - length: 0, - handshake_data: HandshakeData::Uninitialized, - }; - - repr.msg_type = HandshakeType::try_from(self.read_u8()?) - .map_err(|_| Error::Unrecognized)?; - repr.length = self.read_u24()?; - - use HandshakeType::*; - match repr.msg_type { - ClientHello => unimplemented!(), - ServerHello => todo!(), - _ => unimplemented!(), - } - } - - fn dequeue_server_hello(mut self) -> Result> { - // Create a Server Hello representation - // Fill in proper value afterwards - let mut server_hello = ServerHello { - version: TlsVersion::Tls10, - random: [0; 32], - session_id_echo_length: 0, - session_id_echo: [0; 32], - cipher_suite: CipherSuite::TLS_CHACHA20_POLY1305_SHA256, - compression_method: 0, - extension_length: 0, - extensions: &[], - }; - - server_hello.version = TlsVersion::try_from(self.read_u16()?) - .map_err(|_| Error::Unrecognized)?; - for random_byte in &mut server_hello.random[..] { - *random_byte = self.read_u8()?; - } - server_hello.session_id_echo_length = self.read_u8()?; - for id_byte in &mut server_hello.session_id_echo[ - ..usize::try_from(server_hello.session_id_echo_length) - .map_err(|_| Error::Exhausted)? - ] { - *id_byte = self.read_u8()?; - } - server_hello.cipher_suite = CipherSuite::try_from(self.read_u16()?) - .map_err(|_| Error::Unrecognized)?; - server_hello.compression_method = self.read_u8()?; - server_hello.extension_length = self.read_u16()?; - - let mut remaining_length = server_hello.extension_length; - let mut extension_counter = 0; - let mut extension_vec: Vec = Vec::new(); - while remaining_length != 0 { - extension_vec.push(self.dequeue_extension()?.clone()) - .map_err(|_| Error::Exhausted)?; - // Deduct base length of an extension (ext_type, len) - remaining_length -= 4; - remaining_length -= extension_vec[extension_counter].length; - extension_counter += 1; - } - - Ok(server_hello) - } - - fn dequeue_extension(&self) -> Result> { - // Create an Extension representation - // Fill in proper value afterwards - let mut extension = Extension { - extension_type: ExtensionType::ServerName, - length: 0, - extension_data: &[], - }; - - extension.extension_type = ExtensionType::try_from(self.read_u16()?) - .map_err(|_| Error::Unrecognized)?; - extension.length = self.read_u16()?; - extension.extension_data = self.read_slice( - usize::try_from(extension.length) - .map_err(|_| Error::Exhausted)? - )?; - Ok(extension) - } } macro_rules! export_byte_order_fn { diff --git a/src/tls_packet.rs b/src/tls_packet.rs index fc68776..4ec1bca 100644 --- a/src/tls_packet.rs +++ b/src/tls_packet.rs @@ -2,10 +2,12 @@ use byteorder::{ByteOrder, NetworkEndian, BigEndian}; use num_enum::IntoPrimitive; use num_enum::TryFromPrimitive; use core::convert::TryFrom; +use core::convert::TryInto; #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] #[repr(u8)] pub(crate) enum TlsContentType { + #[num_enum(default)] Invalid = 0, ChangeCipherSpec = 20, Alert = 21, @@ -16,24 +18,28 @@ pub(crate) enum TlsContentType { #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] #[repr(u16)] pub(crate) enum TlsVersion { + #[num_enum(default)] + Unknown = 0x0000, Tls10 = 0x0301, Tls11 = 0x0302, Tls12 = 0x0303, Tls13 = 0x0304, } -#[derive(Clone, Copy)] -pub(crate) struct TlsRepr<'a, 'b> { +#[derive(Debug, Clone, Copy)] +pub(crate) struct TlsRepr<'a> { pub(crate) content_type: TlsContentType, pub(crate) version: TlsVersion, pub(crate) length: u16, pub(crate) payload: Option<&'a[u8]>, - pub(crate) handshake: Option> + pub(crate) handshake: Option> } #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] #[repr(u8)] pub(crate) enum HandshakeType { + #[num_enum(default)] + Unknown = 0, ClientHello = 1, ServerHello = 2, NewSessionTicket = 4, @@ -47,11 +53,20 @@ pub(crate) enum HandshakeType { MessageHash = 254, } -#[derive(Clone, Copy)] -pub(crate) struct HandshakeRepr<'a, 'b> { +#[derive(Debug, Clone, Copy)] +pub(crate) struct HandshakeRepr<'a> { pub(crate) msg_type: HandshakeType, pub(crate) length: u32, - pub(crate) handshake_data: HandshakeData<'a, 'b>, + pub(crate) handshake_data: HandshakeData<'a>, +} + +impl<'a, 'b> HandshakeRepr<'a> { + pub(crate) fn get_length(&self) -> u16 { + let mut length :u16 = 1; // Handshake Type + length += 3; // Length of Handshake data + length += u16::try_from(self.handshake_data.get_length()).unwrap(); + length + } } #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] @@ -64,8 +79,8 @@ pub(crate) enum CipherSuite { TLS_AES_128_CCM_8_SHA256 = 0x1305, } -#[derive(Clone, Copy)] -pub(crate) struct ClientHello<'a, 'b> { +#[derive(Debug, Clone, Copy)] +pub(crate) struct ClientHello<'a> { pub(crate) version: TlsVersion, // Legacy: Must be Tls12 (0x0303) pub(crate) random: [u8; 32], pub(crate) session_id_length: u8, // Legacy: Keep it 32 @@ -75,17 +90,27 @@ pub(crate) struct ClientHello<'a, 'b> { pub(crate) compression_method_length: u8, // Legacy: Must be 1, to contain a byte pub(crate) compression_methods: u8, // Legacy: Must be 1 byte of 0 pub(crate) extension_length: u16, - pub(crate) extensions: &'a[Extension<'b>], + pub(crate) extensions: &'a[Extension<'a>], } -#[derive(Clone, Copy)] -pub(crate) enum HandshakeData<'a, 'b> { +#[derive(Debug, Clone, Copy)] +pub(crate) enum HandshakeData<'a> { Uninitialized, - ClientHello(ClientHello<'a, 'b>), - ServerHello(ServerHello<'a, 'b>), + ClientHello(ClientHello<'a>), + ServerHello(ServerHello<'a>), } -impl<'a, 'b> ClientHello<'a, 'b> { +impl<'a> HandshakeData<'a> { + pub(crate) fn get_length(&self) -> u32 { + match self { + HandshakeData::ClientHello(data) => data.get_length(), + HandshakeData::ServerHello(data) => todo!(), + _ => 0, + } + } +} + +impl<'a> ClientHello<'a> { pub(crate) fn get_length(&self) -> u32 { let mut length :u32 = 2; // TlsVersion size length += 32; // Random size @@ -103,16 +128,16 @@ impl<'a, 'b> ClientHello<'a, 'b> { } } -#[derive(Clone, Copy)] -pub(crate) struct ServerHello<'a, 'b> { +#[derive(Debug, Clone, Copy)] +pub(crate) struct ServerHello<'a> { pub(crate) version: TlsVersion, - pub(crate) random: [u8; 32], + pub(crate) random: &'a[u8], pub(crate) session_id_echo_length: u8, - pub(crate) session_id_echo: [u8; 32], + pub(crate) session_id_echo: &'a[u8], pub(crate) cipher_suite: CipherSuite, pub(crate) compression_method: u8, // Always 0 pub(crate) extension_length: u16, - pub(crate) extensions: &'a[Extension<'b>], + pub(crate) extensions: &'a[Extension<'a>], } #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] @@ -148,7 +173,7 @@ impl ExtensionType { } } -#[derive(Clone, Copy)] +#[derive(Debug, Clone, Copy)] pub(crate) struct Extension<'a> { pub(crate) extension_type: ExtensionType, pub(crate) length: u16,