From 70527da8c424479fea2dbd595483230c508dfaf4 Mon Sep 17 00:00:00 2001 From: occheung Date: Sun, 11 Oct 2020 23:41:02 +0800 Subject: [PATCH] tls: insert iface polling --- Cargo.toml | 4 +-- src/lib.rs | 5 ++++ src/main.rs | 35 ++++++++++++++++++++-- src/parse.rs | 23 ++++++++------- src/tls.rs | 56 +++++++++++++++++++++++++---------- src/tls_packet.rs | 75 ++++++++++++++++++++++++++++++++++++++++++----- 6 files changed, 161 insertions(+), 37 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 847286e..30859b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "smoltcp-tls" version = "0.1.0" -authors = ["occheung "] +authors = ["occheung"] edition = "2018" [dependencies] @@ -15,7 +15,7 @@ log = {version = "0.4.11"} [dependencies.smoltcp] version = "0.6.0" default-features = false -features = ["proto-ipv4", "proto-ipv6", "socket-tcp"] +features = ["ethernet", "proto-ipv4", "proto-ipv6", "socket-tcp", "alloc"] [dependencies.rand_core] version = "0.5.1" diff --git a/src/lib.rs b/src/lib.rs index 053aa18..d244e5d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,14 @@ #![no_std] +#[macro_use] +extern crate alloc; + pub mod tls; pub mod tls_packet; pub mod parse; +// TODO: Implement errors +// Details: Encapsulate smoltcp & nom errors pub enum Error { PropagatedError(smoltcp::Error), ParsingError() diff --git a/src/main.rs b/src/main.rs index 06686bd..79ce5a6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,36 @@ use smoltcp::socket::TcpSocketBuffer; use smoltcp::socket::SocketSet; use smoltcp::wire::Ipv4Address; +use rand_core::RngCore; +use rand_core::CryptoRng; +use rand_core::impls; +use rand_core::Error; + +struct CountingRng(u64); + +impl RngCore for CountingRng { + fn next_u32(&mut self) -> u32 { + self.next_u64() as u32 + } + + fn next_u64(&mut self) -> u64 { + self.0 += 1; + self.0 + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + impls::fill_bytes_via_next(self, dest) + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { + Ok(self.fill_bytes(dest)) + } +} + +impl CryptoRng for CountingRng {} + +static mut RNG: CountingRng = CountingRng(0); + fn main() { let mut socket_set_entries: [_; 8] = Default::default(); let mut sockets = SocketSet::new(&mut socket_set_entries[..]); @@ -10,13 +40,14 @@ fn main() { let mut tx_storage = [0; 4096]; let mut rx_storage = [0; 4096]; - let mut tls_socket = { + let mut tls_socket = unsafe { let tx_buffer = TcpSocketBuffer::new(&mut tx_storage[..]); let rx_buffer = TcpSocketBuffer::new(&mut rx_storage[..]); TlsSocket::new( &mut sockets, rx_buffer, - tx_buffer + tx_buffer, + &mut RNG, ) }; diff --git a/src/parse.rs b/src/parse.rs index c8f9aa3..e6c8788 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -11,9 +11,9 @@ use byteorder::{ByteOrder, NetworkEndian, BigEndian}; use crate::tls_packet::*; use core::convert::TryFrom; -use heapless::{ Vec, consts::* }; +use alloc::vec::Vec; -fn parse_tls(bytes: &[u8]) -> IResult<&[u8], TlsRepr> { +pub(crate) fn parse_tls_repr(bytes: &[u8]) -> IResult<&[u8], TlsRepr> { let content_type = take(1_usize); let version = take(2_usize); let length = take(2_usize); @@ -65,7 +65,9 @@ fn parse_handshake(bytes: &[u8]) -> IResult<&[u8], HandshakeRepr> { use crate::tls_packet::HandshakeType::*; match repr.msg_type { ServerHello => { - todo!() + let (rest, data) = parse_server_hello(bytes)?; + repr.handshake_data = data; + Ok((rest, repr)) }, _ => todo!() } @@ -90,27 +92,26 @@ fn parse_server_hello(bytes: &[u8]) -> IResult<&[u8], HandshakeData> { let (mut rest, (cipher_suite, compression_method, extension_length)) = tuple((cipher_suite, compression_method, extension_length))(rest)?; - let mut extension_length = NetworkEndian::read_u16(extension_length); - let mut server_hello = ServerHello { version: TlsVersion::try_from(NetworkEndian::read_u16(version)).unwrap(), random, - session_id_echo_length: session_id_echo_length, + session_id_echo_length, session_id_echo, cipher_suite: CipherSuite::try_from(NetworkEndian::read_u16(cipher_suite)).unwrap(), compression_method: compression_method[0], - extension_length, - extensions: &[] + extension_length: NetworkEndian::read_u16(extension_length), + extensions: Vec::new(), }; - let mut extension_vec: Vec = Vec::new(); + let mut extension_vec: Vec = Vec::new(); + let mut extension_length: i32 = server_hello.extension_length.into(); while extension_length >= 0 { let (rem, extension) = parse_extension(rest)?; rest = rem; - extension_length -= extension.get_length(); + extension_length -= i32::try_from(extension.get_length()).unwrap(); // Todo:: Proper error - if extension_vec.push(extension).is_err() || extension_length < 0 { + if extension_length < 0 { todo!() } } diff --git a/src/tls.rs b/src/tls.rs index 15c0dc6..af47d38 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -10,19 +10,22 @@ use smoltcp::wire::Ipv4Address; use smoltcp::wire::IpEndpoint; use smoltcp::Result; use smoltcp::Error; +use smoltcp::iface::EthernetInterface; +use smoltcp::time::Instant; +use smoltcp::phy::Device; use byteorder::{ByteOrder, NetworkEndian, BigEndian}; -use heapless::Vec; -use heapless::consts::*; - use core::convert::TryInto; use core::convert::TryFrom; use rand_core::{RngCore, CryptoRng}; use p256::{EncodedPoint, AffinePoint, ecdh::EphemeralSecret}; +use alloc::vec::{ self, Vec }; + use crate::tls_packet::*; +use crate::parse::parse_tls_repr; #[derive(Debug, PartialEq, Eq, Clone, Copy)] #[allow(non_camel_case_types)] @@ -81,7 +84,15 @@ impl TlsSocket { } } - pub fn tls_connect(&mut self, sockets: &mut SocketSet) -> Result { + pub fn tls_connect( + &mut self, + iface: EthernetInterface, + sockets: &mut SocketSet, + now: Instant + ) -> Result + where + DeviceT: for<'d> Device<'d> + { // Check tcp_socket connectivity { let mut tcp_socket = sockets.get::(self.tcp_handle); @@ -183,13 +194,13 @@ impl TlsSocket { compression_method_length: 1, compression_methods: 0, extension_length: supported_versions_extension.get_length(), - extensions: &[ + extensions: vec![ supported_versions_extension, signature_algorithms_extension, supported_groups_extension, psk_key_exchange_modes_extension, - key_share_extension, - ], + key_share_extension + ] }; client_hello.extension_length = { @@ -244,10 +255,25 @@ impl TlsSocket { } // Generic inner recv method, through TCP socket - fn recv_tls_repr<'a>(&'a mut self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result> { + fn recv_tls_repr<'a>(&'a mut self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result> { let mut tcp_socket = sockets.get::(self.tcp_handle); - let size = tcp_socket.recv_slice(byte_array)?; - todo!() + tcp_socket.recv_slice(byte_array)?; + let mut vec: Vec = Vec::new(); + + let mut bytes: &[u8] = byte_array; + loop { + match parse_tls_repr(bytes) { + Ok((rest, repr)) => { + vec.push(repr); + if rest.len() == 0 { + return Ok(vec); + } else { + bytes = rest; + } + }, + _ => return Err(Error::Unrecognized), + }; + } } } @@ -317,7 +343,7 @@ impl<'a> TlsBuffer<'a> { Ok(slice) } - fn enqueue_tls_repr(&mut self, tls_repr: TlsRepr) -> Result<()> { + fn enqueue_tls_repr(&mut self, tls_repr: TlsRepr<'a>) -> Result<()> { self.write_u8(tls_repr.content_type.into())?; self.write_u16(tls_repr.version.into())?; self.write_u16(tls_repr.length)?; @@ -332,13 +358,13 @@ impl<'a> TlsBuffer<'a> { Ok(()) } - fn enqueue_handshake_repr(&mut self, handshake_repr: HandshakeRepr) -> Result<()> { + fn enqueue_handshake_repr(&mut self, handshake_repr: HandshakeRepr<'a>) -> Result<()> { self.write_u8(handshake_repr.msg_type.into())?; self.write_u24(handshake_repr.length)?; self.enqueue_handshake_data(handshake_repr.handshake_data) } - fn enqueue_handshake_data(&mut self, handshake_data: HandshakeData) -> Result<()> { + fn enqueue_handshake_data(&mut self, handshake_data: HandshakeData<'a>) -> Result<()> { match handshake_data { HandshakeData::ClientHello(client_hello) => { self.enqueue_client_hello(client_hello) @@ -349,7 +375,7 @@ impl<'a> TlsBuffer<'a> { } } - fn enqueue_client_hello(&mut self, client_hello: ClientHello) -> Result<()> { + fn enqueue_client_hello(&mut self, client_hello: ClientHello<'a>) -> Result<()> { self.write_u16(client_hello.version.into())?; self.write(&client_hello.random)?; self.write_u8(client_hello.session_id_length)?; @@ -364,7 +390,7 @@ impl<'a> TlsBuffer<'a> { self.enqueue_extensions(client_hello.extensions) } - fn enqueue_extensions(&mut self, extensions: &[Extension]) -> Result<()> { + fn enqueue_extensions(&mut self, extensions: Vec>) -> Result<()> { for extension in extensions { self.write_u16(extension.extension_type.into())?; self.write_u16(extension.length)?; diff --git a/src/tls_packet.rs b/src/tls_packet.rs index 4ec1bca..26d62ed 100644 --- a/src/tls_packet.rs +++ b/src/tls_packet.rs @@ -1,9 +1,15 @@ use byteorder::{ByteOrder, NetworkEndian, BigEndian}; use num_enum::IntoPrimitive; use num_enum::TryFromPrimitive; + +use rand_core::RngCore; +use rand_core::CryptoRng; + use core::convert::TryFrom; use core::convert::TryInto; +use alloc::vec::Vec; + #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] #[repr(u8)] pub(crate) enum TlsContentType { @@ -26,7 +32,7 @@ pub(crate) enum TlsVersion { Tls13 = 0x0304, } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] pub(crate) struct TlsRepr<'a> { pub(crate) content_type: TlsContentType, pub(crate) version: TlsVersion, @@ -35,6 +41,25 @@ pub(crate) struct TlsRepr<'a> { pub(crate) handshake: Option> } +impl<'a> TlsRepr<'a> { + pub(crate) fn new() -> Self { + TlsRepr { + content_type: TlsContentType::Invalid, + version: TlsVersion::Tls12, + length: 0, + payload: None, + handshake: None, + } + } + + pub(crate) fn client_hello(mut self) -> Self { + self.content_type = TlsContentType::Handshake; + self.version = TlsVersion::Tls10; + // TODO: Fill in handshake field + self + } +} + #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] #[repr(u8)] pub(crate) enum HandshakeType { @@ -53,7 +78,7 @@ pub(crate) enum HandshakeType { MessageHash = 254, } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] pub(crate) struct HandshakeRepr<'a> { pub(crate) msg_type: HandshakeType, pub(crate) length: u32, @@ -61,6 +86,14 @@ pub(crate) struct HandshakeRepr<'a> { } impl<'a, 'b> HandshakeRepr<'a> { + pub(self) fn new() -> Self { + HandshakeRepr { + msg_type: HandshakeType::Unknown, + length: 0, + handshake_data: HandshakeData::Uninitialized, + } + } + pub(crate) fn get_length(&self) -> u16 { let mut length :u16 = 1; // Handshake Type length += 3; // Length of Handshake data @@ -70,6 +103,7 @@ impl<'a, 'b> HandshakeRepr<'a> { } #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] +#[allow(non_camel_case)] #[repr(u16)] pub(crate) enum CipherSuite { TLS_AES_128_GCM_SHA256 = 0x1301, @@ -79,7 +113,7 @@ pub(crate) enum CipherSuite { TLS_AES_128_CCM_8_SHA256 = 0x1305, } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] pub(crate) struct ClientHello<'a> { pub(crate) version: TlsVersion, // Legacy: Must be Tls12 (0x0303) pub(crate) random: [u8; 32], @@ -90,10 +124,10 @@ pub(crate) struct ClientHello<'a> { pub(crate) compression_method_length: u8, // Legacy: Must be 1, to contain a byte pub(crate) compression_methods: u8, // Legacy: Must be 1 byte of 0 pub(crate) extension_length: u16, - pub(crate) extensions: &'a[Extension<'a>], + pub(crate) extensions: Vec>, } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] pub(crate) enum HandshakeData<'a> { Uninitialized, ClientHello(ClientHello<'a>), @@ -111,6 +145,33 @@ impl<'a> HandshakeData<'a> { } impl<'a> ClientHello<'a> { + pub(self) fn new(rng: &mut T) -> Self + where + T: RngCore + CryptoRng + { + let mut client_hello = ClientHello { + version: TlsVersion::Tls12, + random: [0; 32], + session_id_length: 32, + session_id: [0; 32], + cipher_suites_length: 0, + cipher_suites: &[ + CipherSuite::TLS_AES_128_GCM_SHA256, + CipherSuite::TLS_AES_256_GCM_SHA384, + CipherSuite::TLS_CHACHA20_POLY1305_SHA256, + ], + compression_method_length: 1, + compression_methods: 0, + extension_length: 0, + extensions: Vec::new(), + }; + + rng.fill_bytes(&mut client_hello.random); + rng.fill_bytes(&mut client_hello.session_id); + + client_hello + } + pub(crate) fn get_length(&self) -> u32 { let mut length :u32 = 2; // TlsVersion size length += 32; // Random size @@ -128,7 +189,7 @@ impl<'a> ClientHello<'a> { } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] pub(crate) struct ServerHello<'a> { pub(crate) version: TlsVersion, pub(crate) random: &'a[u8], @@ -137,7 +198,7 @@ pub(crate) struct ServerHello<'a> { pub(crate) cipher_suite: CipherSuite, pub(crate) compression_method: u8, // Always 0 pub(crate) extension_length: u16, - pub(crate) extensions: &'a[Extension<'a>], + pub(crate) extensions: Vec>, } #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]