From 912feac263b121224102f38d10a598d6bf89de56 Mon Sep 17 00:00:00 2001 From: occheung Date: Thu, 15 Oct 2020 22:40:36 +0800 Subject: [PATCH] buffer: init, cipher: init --- src/buffer.rs | 240 ++++++++++++++++++++++++++++++++++++++++ src/cipher.rs | 76 +++++++++++++ src/lib.rs | 2 + src/tls.rs | 298 +------------------------------------------------- 4 files changed, 322 insertions(+), 294 deletions(-) create mode 100644 src/buffer.rs create mode 100644 src/cipher.rs diff --git a/src/buffer.rs b/src/buffer.rs new file mode 100644 index 0000000..e452d8d --- /dev/null +++ b/src/buffer.rs @@ -0,0 +1,240 @@ +use core::cell::RefCell; +use core::convert::{ TryInto, TryFrom }; + +use smoltcp::{ Result, Error }; + +use alloc::vec::Vec; + +use byteorder::{ByteOrder, NetworkEndian, BigEndian}; + +use crate::tls_packet::*; + +// Only designed to support read or write the entire buffer +pub(crate) struct TlsBuffer<'a> { + buffer: &'a mut [u8], + index: RefCell, +} + +impl<'a> Into<&'a [u8]> for TlsBuffer<'a> { + fn into(self) -> &'a [u8] { + &self.buffer[0..self.index.into_inner()] + } +} + +impl<'a> TlsBuffer<'a> { + pub(crate) fn new(buffer: &'a mut [u8]) -> Self { + Self { + buffer, + index: RefCell::new(0), + } + } + + pub(crate) fn get_size(&self) -> usize { + self.index.clone().into_inner() + } + + pub(crate) fn write(&mut self, data: &[u8]) -> Result<()> { + let mut index = self.index.borrow_mut(); + if (self.buffer.len() - *index) < data.len() { + return Err(Error::Exhausted); + } + let next_index = *index + data.len(); + self.buffer[*index..next_index].copy_from_slice(data); + *index = next_index; + Ok(()) + } + + pub(crate) fn write_u8(&mut self, data: u8) -> Result<()> { + let mut index = self.index.borrow_mut(); + if (self.buffer.len() - *index) < 1 { + return Err(Error::Exhausted); + } + self.buffer[*index] = data; + *index += 1; + Ok(()) + } + + pub(crate) fn read_u8(&mut self) -> Result { + let mut index = self.index.borrow_mut(); + if (self.buffer.len() - *index) < 1 { + return Err(Error::Exhausted); + } + let data = self.buffer[*index]; + *index += 1; + Ok(data) + } + + pub(crate) fn read_all(self) -> &'a [u8] { + &self.buffer[self.index.into_inner()..] + } + + pub(crate) fn read_slice(&self, length: usize) -> Result<&[u8]> { + let mut index = self.index.borrow_mut(); + if (self.buffer.len() - *index) < length { + return Err(Error::Exhausted); + } + let next_index = *index + length; + let slice = &self.buffer[*index..next_index]; + *index = next_index; + Ok(slice) + } + + pub(crate) fn enqueue_tls_repr(&mut self, tls_repr: TlsRepr<'a>) -> Result<()> { + self.write_u8(tls_repr.content_type.into())?; + self.write_u16(tls_repr.version.into())?; + self.write_u16(tls_repr.length)?; + if let Some(app_data) = tls_repr.payload { + self.write(app_data)?; + } else if let Some(handshake_repr) = tls_repr.handshake { + // Queue handshake_repr into buffer + self.enqueue_handshake_repr(handshake_repr)?; + } else { + return Err(Error::Malformed); + } + Ok(()) + } + + fn enqueue_handshake_repr(&mut self, handshake_repr: HandshakeRepr<'a>) -> Result<()> { + self.write_u8(handshake_repr.msg_type.into())?; + self.write_u24(handshake_repr.length)?; + self.enqueue_handshake_data(handshake_repr.handshake_data) + } + + fn enqueue_handshake_data(&mut self, handshake_data: HandshakeData<'a>) -> Result<()> { + match handshake_data { + HandshakeData::ClientHello(client_hello) => { + self.enqueue_client_hello(client_hello) + } + _ => { + Err(Error::Unrecognized) + } + } + } + + fn enqueue_client_hello(&mut self, client_hello: ClientHello<'a>) -> Result<()> { + self.write_u16(client_hello.version.into())?; + self.write(&client_hello.random)?; + self.write_u8(client_hello.session_id_length)?; + self.write(&client_hello.session_id)?; + self.write_u16(client_hello.cipher_suites_length)?; + for suite in client_hello.cipher_suites.iter() { + self.write_u16((*suite).into())?; + } + self.write_u8(client_hello.compression_method_length)?; + self.write_u8(client_hello.compression_methods)?; + self.write_u16(client_hello.extension_length)?; + self.enqueue_extensions(client_hello.extensions) + } + + fn enqueue_extensions(&mut self, extensions: Vec) -> Result<()> { + for extension in extensions { + self.write_u16(extension.extension_type.into())?; + self.write_u16(extension.length)?; + self.enqueue_extension_data(extension.extension_data)?; + } + Ok(()) + } + + fn enqueue_extension_data(&mut self, extension_data: ExtensionData) -> Result<()> { + use crate::tls_packet::ExtensionData::*; + match extension_data { + SupportedVersions(s) => { + use crate::tls_packet::SupportedVersions::*; + match s { + ClientHello { length, versions } => { + self.write_u8(length)?; + for version in versions.iter() { + self.write_u16((*version).into())?; + } + }, + ServerHello { selected_version } => { + self.write_u16(selected_version.into())?; + } + } + }, + SignatureAlgorithms(s) => { + self.write_u16(s.length)?; + for sig_alg in s.supported_signature_algorithms.iter() { + self.write_u16((*sig_alg).into())?; + } + }, + NegotiatedGroups(n) => { + self.write_u16(n.length)?; + for group in n.named_group_list.iter() { + self.write_u16((*group).into())?; + } + }, + KeyShareEntry(k) => { + let mut key_share_entry_into = |buffer: &mut TlsBuffer, entry: crate::tls_packet::KeyShareEntry| { + buffer.write_u16(entry.group.into())?; + buffer.write_u16(entry.length)?; + buffer.write(entry.key_exchange.as_slice()) + }; + + use crate::tls_packet::KeyShareEntryContent::*; + match k { + KeyShareClientHello { length, client_shares } => { + self.write_u16(length)?; + for share in client_shares.iter() { + self.enqueue_key_share_entry(share)?; + } + } + KeyShareHelloRetryRequest { selected_group } => { + self.write_u16(selected_group.into())?; + } + KeyShareServerHello { server_share } => { + self.enqueue_key_share_entry(&server_share)?; + } + } + }, + + // TODO: Implement buffer formatting for other extensions + _ => todo!() + }; + Ok(()) + } + + fn enqueue_key_share_entry(&mut self, entry: &crate::tls_packet::KeyShareEntry) -> Result<()> { + self.write_u16(entry.group.into())?; + self.write_u16(entry.length)?; + self.write(entry.key_exchange.as_slice()) + } +} + +macro_rules! export_byte_order_fn { + ($($write_fn_name: ident, $read_fn_name: ident, $data_type: ty, $data_size: literal),+) => { + impl<'a> TlsBuffer<'a> { + $( + pub(crate) fn $write_fn_name(&mut self, data: $data_type) -> Result<()> { + let mut index = self.index.borrow_mut(); + if (self.buffer.len() - *index) < $data_size { + return Err(Error::Exhausted); + } + let next_index = *index + $data_size; + NetworkEndian::$write_fn_name(&mut self.buffer[*index..next_index], data); + *index = next_index; + Ok(()) + } + + pub(crate) fn $read_fn_name(&self) -> Result<$data_type> { + let mut index = self.index.borrow_mut(); + if (self.buffer.len() - *index) < $data_size { + return Err(Error::Exhausted); + } + let next_index = *index + $data_size; + let data = NetworkEndian::$read_fn_name(&self.buffer[*index..next_index]); + *index = next_index; + Ok(data) + } + )+ + } + } +} + +export_byte_order_fn!( + write_u16, read_u16, u16, 2, + write_u24, read_u24, u32, 3, + write_u32, read_u32, u32, 4, + write_u48, read_u48, u64, 6, + write_u64, read_u64, u64, 8 +); diff --git a/src/cipher.rs b/src/cipher.rs new file mode 100644 index 0000000..c72a68b --- /dev/null +++ b/src/cipher.rs @@ -0,0 +1,76 @@ +use p256::{EncodedPoint, AffinePoint, ecdh::EphemeralSecret, ecdh::SharedSecret}; +use aes_gcm::{Aes128Gcm, Aes256Gcm}; +use chacha20poly1305::{ChaCha20Poly1305, Key}; +use ccm::{Ccm, consts::*}; +use aes_gcm::aes::Aes128; +use aes_gcm::{AeadInPlace, NewAead}; +use generic_array::GenericArray; +use rand_core::{ RngCore, CryptoRng }; +use alloc::vec::Vec; +use crate::Error as TlsError; + +pub(crate) enum Cipher { + TLS_AES_128_GCM_SHA256(Aes128Gcm), + TLS_AES_256_GCM_SHA384(Aes256Gcm), + TLS_CHACHA20_POLY1305_SHA256(ChaCha20Poly1305), + TLS_AES_128_CCM_SHA256(Ccm) +} + +macro_rules! impl_cipher { + ($($cipher_name: ident),+) => { + impl Cipher { + pub(crate) fn encrypt(&self, rng: &mut T, associated_data: &[u8], buffer: &mut Vec) -> core::result::Result<(), TlsError> + where + T: RngCore + CryptoRng + { + // All 4 supported Ciphers use a nonce of 12 bytes + let mut nonce_array: [u8; 12] = [0; 12]; + rng.fill_bytes(&mut nonce_array); + use Cipher::*; + match self { + $( + $cipher_name(cipher) => { + cipher.encrypt_in_place( + &GenericArray::from_slice(&nonce_array), + associated_data, + buffer + ).map_err( + |_| TlsError::EncryptionError + ) + } + )+ + } + } + + pub(crate) fn decrypt(&self, rng: &mut T, associated_data: &[u8], buffer: &mut Vec) -> core::result::Result<(), TlsError> + where + T: RngCore + CryptoRng + { + // All 4 supported Ciphers use a nonce of 12 bytes + let mut nonce_array: [u8; 12] = [0; 12]; + rng.fill_bytes(&mut nonce_array); + use Cipher::*; + match self { + $( + $cipher_name(cipher) => { + cipher.decrypt_in_place( + &GenericArray::from_slice(&nonce_array), + associated_data, + buffer + ).map_err( + |_| TlsError::EncryptionError + ) + } + )+ + } + } + } + } +} + +impl_cipher!( + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_128_CCM_SHA256 +); \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 94a3e72..2d37ced 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,8 @@ extern crate alloc; pub mod tls; pub mod tls_packet; pub mod parse; +pub mod cipher; +pub mod buffer; // TODO: Implement errors // Details: Encapsulate smoltcp & nom errors diff --git a/src/tls.rs b/src/tls.rs index abf33e8..a7249dd 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -34,6 +34,8 @@ use alloc::vec::{ self, Vec }; use crate::Error as TlsError; use crate::tls_packet::*; use crate::parse::parse_tls_repr; +use crate::cipher::Cipher; +use crate::buffer::TlsBuffer; #[derive(Debug, PartialEq, Eq, Clone, Copy)] #[allow(non_camel_case_types)] @@ -60,72 +62,6 @@ pub struct TlsSocket cipher: RefCell>, } -pub(crate) enum Cipher { - TLS_AES_128_GCM_SHA256(Aes128Gcm), - TLS_AES_256_GCM_SHA384(Aes256Gcm), - TLS_CHACHA20_POLY1305_SHA256(ChaCha20Poly1305), - TLS_AES_128_CCM_SHA256(Ccm) -} - -macro_rules! impl_cipher { - ($($cipher_name: ident),+) => { - impl Cipher { - pub(crate) fn encrypt(&self, rng: &mut T, associated_data: &[u8], buffer: &mut Vec) -> core::result::Result<(), TlsError> - where - T: RngCore + CryptoRng - { - // All 4 supported Ciphers use a nonce of 12 bytes - let mut nonce_array: [u8; 12] = [0; 12]; - rng.fill_bytes(&mut nonce_array); - use Cipher::*; - match self { - $( - $cipher_name(cipher) => { - cipher.encrypt_in_place( - &GenericArray::from_slice(&nonce_array), - associated_data, - buffer - ).map_err( - |_| TlsError::EncryptionError - ) - } - )+ - } - } - - pub(crate) fn decrypt(&self, rng: &mut T, associated_data: &[u8], buffer: &mut Vec) -> core::result::Result<(), TlsError> - where - T: RngCore + CryptoRng - { - // All 4 supported Ciphers use a nonce of 12 bytes - let mut nonce_array: [u8; 12] = [0; 12]; - rng.fill_bytes(&mut nonce_array); - use Cipher::*; - match self { - $( - $cipher_name(cipher) => { - cipher.decrypt_in_place( - &GenericArray::from_slice(&nonce_array), - associated_data, - buffer - ).map_err( - |_| TlsError::EncryptionError - ) - } - )+ - } - } - } - } -} - -impl_cipher!( - TLS_AES_128_GCM_SHA256, - TLS_AES_256_GCM_SHA384, - TLS_CHACHA20_POLY1305_SHA256, - TLS_AES_128_CCM_SHA256 -); - impl TlsSocket { pub fn new<'a, 'b, 'c>( sockets: &mut SocketSet<'a, 'b, 'c>, @@ -396,10 +332,10 @@ impl TlsSocket { let mut array = [0; 2048]; let mut buffer = TlsBuffer::new(&mut array); buffer.enqueue_tls_repr(tls_repr)?; - let buffer_size = buffer.index.clone(); + let buffer_size = buffer.get_size(); tcp_socket.send_slice(buffer.into()) .and_then( - |size| if size == buffer_size.into_inner() { + |size| if size == buffer_size { Ok(()) } else { Err(Error::Truncated) @@ -434,229 +370,3 @@ impl TlsSocket { } } } - -// Only designed to support read or write the entire buffer -pub(crate) struct TlsBuffer<'a> { - buffer: &'a mut [u8], - index: core::cell::RefCell, -} - -impl<'a> Into<&'a [u8]> for TlsBuffer<'a> { - fn into(self) -> &'a [u8] { - &self.buffer[0..self.index.into_inner()] - } -} - -impl<'a> TlsBuffer<'a> { - pub(crate) fn new(buffer: &'a mut [u8]) -> Self { - Self { - buffer, - index: core::cell::RefCell::new(0), - } - } - - pub(crate) fn write(&mut self, data: &[u8]) -> Result<()> { - let mut index = self.index.borrow_mut(); - if (self.buffer.len() - *index) < data.len() { - return Err(Error::Exhausted); - } - let next_index = *index + data.len(); - self.buffer[*index..next_index].copy_from_slice(data); - *index = next_index; - Ok(()) - } - - pub(crate) fn write_u8(&mut self, data: u8) -> Result<()> { - let mut index = self.index.borrow_mut(); - if (self.buffer.len() - *index) < 1 { - return Err(Error::Exhausted); - } - self.buffer[*index] = data; - *index += 1; - Ok(()) - } - - pub(crate) fn read_u8(&mut self) -> Result { - let mut index = self.index.borrow_mut(); - if (self.buffer.len() - *index) < 1 { - return Err(Error::Exhausted); - } - let data = self.buffer[*index]; - *index += 1; - Ok(data) - } - - pub(crate) fn read_all(self) -> &'a [u8] { - &self.buffer[self.index.into_inner()..] - } - - pub(crate) fn read_slice(&self, length: usize) -> Result<&[u8]> { - let mut index = self.index.borrow_mut(); - if (self.buffer.len() - *index) < length { - return Err(Error::Exhausted); - } - let next_index = *index + length; - let slice = &self.buffer[*index..next_index]; - *index = next_index; - Ok(slice) - } - - fn enqueue_tls_repr(&mut self, tls_repr: TlsRepr<'a>) -> Result<()> { - self.write_u8(tls_repr.content_type.into())?; - self.write_u16(tls_repr.version.into())?; - self.write_u16(tls_repr.length)?; - if let Some(app_data) = tls_repr.payload { - self.write(app_data)?; - } else if let Some(handshake_repr) = tls_repr.handshake { - // Queue handshake_repr into buffer - self.enqueue_handshake_repr(handshake_repr)?; - } else { - return Err(Error::Malformed); - } - Ok(()) - } - - fn enqueue_handshake_repr(&mut self, handshake_repr: HandshakeRepr<'a>) -> Result<()> { - self.write_u8(handshake_repr.msg_type.into())?; - self.write_u24(handshake_repr.length)?; - self.enqueue_handshake_data(handshake_repr.handshake_data) - } - - fn enqueue_handshake_data(&mut self, handshake_data: HandshakeData<'a>) -> Result<()> { - match handshake_data { - HandshakeData::ClientHello(client_hello) => { - self.enqueue_client_hello(client_hello) - } - _ => { - Err(Error::Unrecognized) - } - } - } - - fn enqueue_client_hello(&mut self, client_hello: ClientHello<'a>) -> Result<()> { - self.write_u16(client_hello.version.into())?; - self.write(&client_hello.random)?; - self.write_u8(client_hello.session_id_length)?; - self.write(&client_hello.session_id)?; - self.write_u16(client_hello.cipher_suites_length)?; - for suite in client_hello.cipher_suites.iter() { - self.write_u16((*suite).into())?; - } - self.write_u8(client_hello.compression_method_length)?; - self.write_u8(client_hello.compression_methods)?; - self.write_u16(client_hello.extension_length)?; - self.enqueue_extensions(client_hello.extensions) - } - - fn enqueue_extensions(&mut self, extensions: Vec) -> Result<()> { - for extension in extensions { - self.write_u16(extension.extension_type.into())?; - self.write_u16(extension.length)?; - self.enqueue_extension_data(extension.extension_data)?; - } - Ok(()) - } - - fn enqueue_extension_data(&mut self, extension_data: ExtensionData) -> Result<()> { - use crate::tls_packet::ExtensionData::*; - match extension_data { - SupportedVersions(s) => { - use crate::tls_packet::SupportedVersions::*; - match s { - ClientHello { length, versions } => { - self.write_u8(length)?; - for version in versions.iter() { - self.write_u16((*version).into())?; - } - }, - ServerHello { selected_version } => { - self.write_u16(selected_version.into())?; - } - } - }, - SignatureAlgorithms(s) => { - self.write_u16(s.length)?; - for sig_alg in s.supported_signature_algorithms.iter() { - self.write_u16((*sig_alg).into())?; - } - }, - NegotiatedGroups(n) => { - self.write_u16(n.length)?; - for group in n.named_group_list.iter() { - self.write_u16((*group).into())?; - } - }, - KeyShareEntry(k) => { - let mut key_share_entry_into = |buffer: &mut TlsBuffer, entry: crate::tls_packet::KeyShareEntry| { - buffer.write_u16(entry.group.into())?; - buffer.write_u16(entry.length)?; - buffer.write(entry.key_exchange.as_slice()) - }; - - use crate::tls_packet::KeyShareEntryContent::*; - match k { - KeyShareClientHello { length, client_shares } => { - self.write_u16(length)?; - for share in client_shares.iter() { - self.enqueue_key_share_entry(share)?; - } - } - KeyShareHelloRetryRequest { selected_group } => { - self.write_u16(selected_group.into())?; - } - KeyShareServerHello { server_share } => { - self.enqueue_key_share_entry(&server_share)?; - } - } - }, - - // TODO: Implement buffer formatting for other extensions - _ => todo!() - }; - Ok(()) - } - - fn enqueue_key_share_entry(&mut self, entry: &crate::tls_packet::KeyShareEntry) -> Result<()> { - self.write_u16(entry.group.into())?; - self.write_u16(entry.length)?; - self.write(entry.key_exchange.as_slice()) - } -} - -macro_rules! export_byte_order_fn { - ($($write_fn_name: ident, $read_fn_name: ident, $data_type: ty, $data_size: literal),+) => { - impl<'a> TlsBuffer<'a> { - $( - pub(crate) fn $write_fn_name(&mut self, data: $data_type) -> Result<()> { - let mut index = self.index.borrow_mut(); - if (self.buffer.len() - *index) < $data_size { - return Err(Error::Exhausted); - } - let next_index = *index + $data_size; - NetworkEndian::$write_fn_name(&mut self.buffer[*index..next_index], data); - *index = next_index; - Ok(()) - } - - pub(crate) fn $read_fn_name(&self) -> Result<$data_type> { - let mut index = self.index.borrow_mut(); - if (self.buffer.len() - *index) < $data_size { - return Err(Error::Exhausted); - } - let next_index = *index + $data_size; - let data = NetworkEndian::$read_fn_name(&self.buffer[*index..next_index]); - *index = next_index; - Ok(data) - } - )+ - } - } -} - -export_byte_order_fn!( - write_u16, read_u16, u16, 2, - write_u24, read_u24, u32, 3, - write_u32, read_u32, u32, 4, - write_u48, read_u48, u64, 6, - write_u64, read_u64, u64, 8 -);