parser: init

This commit is contained in:
occheung 2020-10-11 13:46:24 +08:00
parent bb70038c7c
commit ba40c0780c
5 changed files with 315 additions and 169 deletions

View File

@ -4,30 +4,40 @@ version = "0.1.0"
authors = ["occheung <occheung@connect.ust.hk>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
aes-gcm = "0.7.0"
chacha20poly1305 = "0.6.0"
byteorder = "1.3.4"
num_enum = "0.5.1"
sha2 = { version = "0.9.1", default-features = false }
byteorder = { version = "1.3.4", default-features = false }
num_enum = { version = "0.5.1", default-features = false }
log = {version = "0.4.11"}
[dependencies.smoltcp]
version = "0.6.0"
default-features = false
features = ["proto-ipv4", "proto-ipv6", "socket-tcp"]
[dependencies.rand_chacha]
version = "0.2.2"
[dependencies.rand_core]
version = "0.5.1"
default-features = false
features = []
[dependencies.rand]
version = "0.7.3"
[dependencies.p256]
version = "0.5.0"
default-features = false
features = ["getrandom"]
features = [ "ecdh", "ecdsa", "arithmetic" ]
[dependencies.rsa]
git = "https://github.com/RustCrypto/RSA.git"
default-features = false
features = [ "alloc" ]
[dependencies.heapless]
version = "0.5.6"
default-features = false
features = []
features = []
[dependencies.nom]
version = "5.1.2"
default-features = false
features= [ "regex", "lexical" ]

View File

@ -1,4 +1,10 @@
#![no_std]
pub mod tls;
pub mod tls_packet;
pub mod tls_packet;
pub mod parse;
pub enum Error {
PropagatedError(smoltcp::Error),
ParsingError()
}

141
src/parse.rs Normal file
View File

@ -0,0 +1,141 @@
use nom::IResult;
use nom::bytes::complete::take;
use nom::combinator::complete;
use nom::sequence::tuple;
use nom::error::ErrorKind;
use smoltcp::Error;
use smoltcp::Result;
use byteorder::{ByteOrder, NetworkEndian, BigEndian};
use crate::tls_packet::*;
use core::convert::TryFrom;
use heapless::{ Vec, consts::* };
fn parse_tls(bytes: &[u8]) -> IResult<&[u8], TlsRepr> {
let content_type = take(1_usize);
let version = take(2_usize);
let length = take(2_usize);
let (rest, (content_type, version, length)) =
tuple((content_type, version, length))(bytes)?;
let mut repr = TlsRepr {
content_type: TlsContentType::try_from(content_type[0])
.unwrap(),
version: TlsVersion::try_from(NetworkEndian::read_u16(version))
.unwrap(),
length: NetworkEndian::read_u16(length),
payload: None,
handshake: None,
};
{
use crate::tls_packet::TlsContentType::*;
match repr.content_type {
Handshake => {
let (rest, handshake) = parse_handshake(rest)?;
repr.handshake = Some(handshake);
Ok((rest, repr))
},
_ => {
let (rest, payload) = take(repr.length)(rest)?;
repr.payload = Some(payload);
Ok((rest, repr))
}
}
}
}
fn parse_handshake(bytes: &[u8]) -> IResult<&[u8], HandshakeRepr> {
let handshake_type = take(1_usize);
let length = take(3_usize);
let (rest, (handshake_type, length)) =
tuple((handshake_type, length))(bytes)?;
let mut repr = HandshakeRepr {
msg_type: HandshakeType::try_from(handshake_type[0]).unwrap(),
length: NetworkEndian::read_u24(length),
handshake_data: HandshakeData::Uninitialized,
};
{
use crate::tls_packet::HandshakeType::*;
match repr.msg_type {
ServerHello => {
todo!()
},
_ => todo!()
}
}
}
fn parse_server_hello(bytes: &[u8]) -> IResult<&[u8], HandshakeData> {
let version = take(2_usize);
let random = take(32_usize);
let session_id_echo_length = take(1_usize);
let (rest, (version, random, session_id_echo_length)) =
tuple((version, random, session_id_echo_length))(bytes)?;
let session_id_echo_length = session_id_echo_length[0];
let (rest, session_id_echo) = take(session_id_echo_length)(rest)?;
let cipher_suite = take(2_usize);
let compression_method = take(1_usize);
let extension_length = take(2_usize);
let (mut rest, (cipher_suite, compression_method, extension_length)) =
tuple((cipher_suite, compression_method, extension_length))(rest)?;
let mut extension_length = NetworkEndian::read_u16(extension_length);
let mut server_hello = ServerHello {
version: TlsVersion::try_from(NetworkEndian::read_u16(version)).unwrap(),
random,
session_id_echo_length: session_id_echo_length,
session_id_echo,
cipher_suite: CipherSuite::try_from(NetworkEndian::read_u16(cipher_suite)).unwrap(),
compression_method: compression_method[0],
extension_length,
extensions: &[]
};
let mut extension_vec: Vec<Extension, U32> = Vec::new();
while extension_length >= 0 {
let (rem, extension) = parse_extension(rest)?;
rest = rem;
extension_length -= extension.get_length();
// Todo:: Proper error
if extension_vec.push(extension).is_err() || extension_length < 0 {
todo!()
}
}
server_hello.extensions = extension_vec;
Ok((rest, HandshakeData::ServerHello(server_hello)))
}
fn parse_extension(bytes: &[u8]) -> IResult<&[u8], Extension> {
let extension_type = take(2_usize);
let length = take(2_usize);
let (rest, (extension_type, length)) =
tuple((extension_type, length))(bytes)?;
let length = NetworkEndian::read_u16(length);
let (rest, extension_data) = take(length)(rest)?;
Ok((
rest,
Extension {
extension_type: ExtensionType::try_from(NetworkEndian::read_u16(extension_type)).unwrap(),
length,
extension_data
}
))
}

View File

@ -11,10 +11,7 @@ use smoltcp::wire::IpEndpoint;
use smoltcp::Result;
use smoltcp::Error;
use byteorder::{ByteOrder, NetworkEndian, BigEndian, WriteBytesExt};
use rand::prelude::*;
use rand_chacha::ChaCha20Rng;
use byteorder::{ByteOrder, NetworkEndian, BigEndian};
use heapless::Vec;
use heapless::consts::*;
@ -22,7 +19,11 @@ use heapless::consts::*;
use core::convert::TryInto;
use core::convert::TryFrom;
use rand_core::{RngCore, CryptoRng};
use p256::{EncodedPoint, AffinePoint, ecdh::EphemeralSecret};
use crate::tls_packet::*;
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[allow(non_camel_case_types)]
enum TlsState {
@ -36,18 +37,19 @@ enum TlsState {
CONNECTED,
}
pub struct TlsSocket
pub struct TlsSocket<R: 'static + RngCore + CryptoRng>
{
state: TlsState,
tcp_handle: SocketHandle,
random: ChaCha20Rng,
rng: R,
}
impl TlsSocket {
impl<R: RngCore + CryptoRng> TlsSocket<R> {
pub fn new<'a, 'b, 'c>(
sockets: &mut SocketSet<'a, 'b, 'c>,
rx_buffer: TcpSocketBuffer<'b>,
tx_buffer: TcpSocketBuffer<'b>,
rng: R,
) -> Self
where
'b: 'c,
@ -57,7 +59,7 @@ impl TlsSocket {
TlsSocket {
state: TlsState::START,
tcp_handle,
random: ChaCha20Rng::from_entropy(),
rng,
}
}
@ -72,13 +74,18 @@ impl TlsSocket {
U: Into<IpEndpoint>,
{
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
tcp_socket.connect(remote_endpoint, local_endpoint)
if tcp_socket.state() == TcpState::Established {
Ok(())
} else {
tcp_socket.connect(remote_endpoint, local_endpoint)
}
}
pub fn tls_connect(&mut self, sockets: &mut SocketSet) -> Result<bool> {
// Check tcp_socket connectivity
{
let tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
tcp_socket.set_keep_alive(Some(smoltcp::time::Duration::from_millis(1000)));
if tcp_socket.state() != TcpState::Established {
return Ok(false);
}
@ -87,11 +94,11 @@ impl TlsSocket {
if self.state == TlsState::START {
// Create TLS representation, length and payload not finalised
let mut random: [u8; 32] = [0; 32];
self.rng.fill_bytes(&mut random);
let mut session_id: [u8; 32] = [0; 32];
self.random.fill_bytes(&mut random);
self.random.fill_bytes(&mut session_id);
self.rng.fill_bytes(&mut session_id);
let cipher_suites_length = 3;
let cipher_suites_length = 6;
let cipher_suites = [
CipherSuite::TLS_AES_128_GCM_SHA256,
CipherSuite::TLS_AES_256_GCM_SHA384,
@ -101,15 +108,72 @@ impl TlsSocket {
// Length: to be determined
let supported_versions_extension = Extension {
extension_type: ExtensionType::SupportedVersions,
length: 3,
length: 5,
extension_data: &[
2, // Number of supported versions * 2
4, // Number of supported versions * 2
// Need 2 bytes to contain a version
0x03, 0x04 // 0x0303: TLS Version 1.3
0x03, 0x04, // 0x0304: TLS Version 1.3
0x03, 0x03, // 0x0303: TLS version 1.2
]
};
let client_hello = ClientHello {
let signature_algorithms_extension = Extension {
extension_type: ExtensionType::SignatureAlgorithms,
length: 24,
extension_data: &[
0x00, 22, // Length in bytes
0x04, 0x03, // ecdsa_secp256r1_sha256
0x08, 0x07, // ed25519
0x08, 0x09, // rsa_pss_pss_sha256
0x04, 0x01, // rsa_pkcs1_sha256
0x08, 0x04, // rsa_pss_rsae_sha256
0x08, 0x0a, // rsa_pss_pss_sha384
0x05, 0x01, // rsa_pkcs1_sha384
0x08, 0x05, // rsa_pss_rsae_sha384
0x08, 0x0b, // rsa_pss_pss_sha512
0x06, 0x01, // rsa_pkcs1_sha512
0x08, 0x06, // rsa_pss_rsae_sha512
]
};
let supported_groups_extension = Extension {
extension_type: ExtensionType::SupportedGroups,
length: 4,
extension_data: &[
0x00, 0x02, // Length in bytes
0x00, 0x17, // secp256r1
]
};
let key_share_extension = Extension {
extension_type: ExtensionType::KeyShare,
length: 71,
extension_data: &{
let ecdh_secret = unsafe { EphemeralSecret::random(&mut self.rng) };
let ecdh_public = EncodedPoint::from(&ecdh_secret);
let x_coor = ecdh_public.x();
let y_coor = ecdh_public.y().unwrap();
let mut data: [u8; 71] = [0; 71];
data[0..2].copy_from_slice(&[0x00, 69]); // Length in bytes
data[2..4].copy_from_slice(&[0x00, 0x17]); // secp256r1
data[4..6].copy_from_slice(&[0x00, 65]); // key exchange length
data[6..7].copy_from_slice(&[0x04]); // Fixed legacy value
data[7..39].copy_from_slice(&x_coor);
data[39..71].copy_from_slice(&y_coor);
data
}
};
let psk_key_exchange_modes_extension = Extension {
extension_type: ExtensionType::PSKKeyExchangeModes,
length: 2,
extension_data: &[
0x01, // Length in bytes
0x01, // psk_dhe_ke
]
};
let mut client_hello = ClientHello {
version: TlsVersion::Tls12,
random,
session_id_length: 32,
@ -119,7 +183,21 @@ impl TlsSocket {
compression_method_length: 1,
compression_methods: 0,
extension_length: supported_versions_extension.get_length(),
extensions: &[supported_versions_extension],
extensions: &[
supported_versions_extension,
signature_algorithms_extension,
supported_groups_extension,
psk_key_exchange_modes_extension,
key_share_extension,
],
};
client_hello.extension_length = {
let mut sum = 0;
for ext in client_hello.extensions.iter() {
sum += ext.get_length();
}
sum
};
let handshake_repr = HandshakeRepr {
@ -130,12 +208,14 @@ impl TlsSocket {
let repr = TlsRepr {
content_type: TlsContentType::Handshake,
version: TlsVersion::Tls13,
length: 0,
version: TlsVersion::Tls10,
length: handshake_repr.get_length(),
payload: None,
handshake: Some(handshake_repr),
};
log::info!("{:?}", repr);
self.send_tls_repr(sockets, repr)?;
self.state = TlsState::WAIT_SH;
Ok(true)
@ -164,11 +244,10 @@ impl TlsSocket {
}
// Generic inner recv method, through TCP socket
fn recv_tls_repr<'a>(&'a mut self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result<TlsRepr<'a, '_>> {
fn recv_tls_repr<'a>(&'a mut self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result<TlsRepr<'a>> {
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
let size = tcp_socket.recv_slice(byte_array)?;
let buffer = TlsBuffer::new(&mut byte_array[..size]);
buffer.dequeue_tls_repr()
todo!()
}
}
@ -293,121 +372,6 @@ impl<'a> TlsBuffer<'a> {
}
Ok(())
}
fn dequeue_tls_repr<'b>(mut self) -> Result<TlsRepr<'a, 'b>> {
// Create a TLS Representation layer
// Modify the representation along the way
let mut repr = TlsRepr {
content_type: TlsContentType::Invalid,
version: TlsVersion::Tls10,
length: 0,
payload: None,
handshake: None,
};
repr.content_type = TlsContentType::try_from(self.read_u8()?)
.map_err(|_| Error::Unrecognized)?;
repr.version = TlsVersion::try_from(self.read_u16()?)
.map_err(|_| Error::Unrecognized)?;
repr.length = self.read_u16()?;
use TlsContentType::*;
match repr.content_type {
Invalid => Err(Error::Unrecognized),
ChangeCipherSpec | Alert => unimplemented!(),
Handshake => todo!(),
ApplicationData => {
repr.payload = Some(self.read_all());
Ok(repr)
}
}
}
fn dequeue_handshake<'b>(mut self) -> Result<HandshakeRepr<'a, 'b>> {
// Create a Handshake header representation
// Fill in proper value afterwards
let mut repr = HandshakeRepr {
msg_type: HandshakeType::ClientHello,
length: 0,
handshake_data: HandshakeData::Uninitialized,
};
repr.msg_type = HandshakeType::try_from(self.read_u8()?)
.map_err(|_| Error::Unrecognized)?;
repr.length = self.read_u24()?;
use HandshakeType::*;
match repr.msg_type {
ClientHello => unimplemented!(),
ServerHello => todo!(),
_ => unimplemented!(),
}
}
fn dequeue_server_hello(mut self) -> Result<ServerHello<'static, 'static>> {
// Create a Server Hello representation
// Fill in proper value afterwards
let mut server_hello = ServerHello {
version: TlsVersion::Tls10,
random: [0; 32],
session_id_echo_length: 0,
session_id_echo: [0; 32],
cipher_suite: CipherSuite::TLS_CHACHA20_POLY1305_SHA256,
compression_method: 0,
extension_length: 0,
extensions: &[],
};
server_hello.version = TlsVersion::try_from(self.read_u16()?)
.map_err(|_| Error::Unrecognized)?;
for random_byte in &mut server_hello.random[..] {
*random_byte = self.read_u8()?;
}
server_hello.session_id_echo_length = self.read_u8()?;
for id_byte in &mut server_hello.session_id_echo[
..usize::try_from(server_hello.session_id_echo_length)
.map_err(|_| Error::Exhausted)?
] {
*id_byte = self.read_u8()?;
}
server_hello.cipher_suite = CipherSuite::try_from(self.read_u16()?)
.map_err(|_| Error::Unrecognized)?;
server_hello.compression_method = self.read_u8()?;
server_hello.extension_length = self.read_u16()?;
let mut remaining_length = server_hello.extension_length;
let mut extension_counter = 0;
let mut extension_vec: Vec<Extension, U32> = Vec::new();
while remaining_length != 0 {
extension_vec.push(self.dequeue_extension()?.clone())
.map_err(|_| Error::Exhausted)?;
// Deduct base length of an extension (ext_type, len)
remaining_length -= 4;
remaining_length -= extension_vec[extension_counter].length;
extension_counter += 1;
}
Ok(server_hello)
}
fn dequeue_extension(&self) -> Result<Extension<'_>> {
// Create an Extension representation
// Fill in proper value afterwards
let mut extension = Extension {
extension_type: ExtensionType::ServerName,
length: 0,
extension_data: &[],
};
extension.extension_type = ExtensionType::try_from(self.read_u16()?)
.map_err(|_| Error::Unrecognized)?;
extension.length = self.read_u16()?;
extension.extension_data = self.read_slice(
usize::try_from(extension.length)
.map_err(|_| Error::Exhausted)?
)?;
Ok(extension)
}
}
macro_rules! export_byte_order_fn {

View File

@ -2,10 +2,12 @@ use byteorder::{ByteOrder, NetworkEndian, BigEndian};
use num_enum::IntoPrimitive;
use num_enum::TryFromPrimitive;
use core::convert::TryFrom;
use core::convert::TryInto;
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]
#[repr(u8)]
pub(crate) enum TlsContentType {
#[num_enum(default)]
Invalid = 0,
ChangeCipherSpec = 20,
Alert = 21,
@ -16,24 +18,28 @@ pub(crate) enum TlsContentType {
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]
#[repr(u16)]
pub(crate) enum TlsVersion {
#[num_enum(default)]
Unknown = 0x0000,
Tls10 = 0x0301,
Tls11 = 0x0302,
Tls12 = 0x0303,
Tls13 = 0x0304,
}
#[derive(Clone, Copy)]
pub(crate) struct TlsRepr<'a, 'b> {
#[derive(Debug, Clone, Copy)]
pub(crate) struct TlsRepr<'a> {
pub(crate) content_type: TlsContentType,
pub(crate) version: TlsVersion,
pub(crate) length: u16,
pub(crate) payload: Option<&'a[u8]>,
pub(crate) handshake: Option<HandshakeRepr<'a, 'b>>
pub(crate) handshake: Option<HandshakeRepr<'a>>
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]
#[repr(u8)]
pub(crate) enum HandshakeType {
#[num_enum(default)]
Unknown = 0,
ClientHello = 1,
ServerHello = 2,
NewSessionTicket = 4,
@ -47,11 +53,20 @@ pub(crate) enum HandshakeType {
MessageHash = 254,
}
#[derive(Clone, Copy)]
pub(crate) struct HandshakeRepr<'a, 'b> {
#[derive(Debug, Clone, Copy)]
pub(crate) struct HandshakeRepr<'a> {
pub(crate) msg_type: HandshakeType,
pub(crate) length: u32,
pub(crate) handshake_data: HandshakeData<'a, 'b>,
pub(crate) handshake_data: HandshakeData<'a>,
}
impl<'a, 'b> HandshakeRepr<'a> {
pub(crate) fn get_length(&self) -> u16 {
let mut length :u16 = 1; // Handshake Type
length += 3; // Length of Handshake data
length += u16::try_from(self.handshake_data.get_length()).unwrap();
length
}
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]
@ -64,8 +79,8 @@ pub(crate) enum CipherSuite {
TLS_AES_128_CCM_8_SHA256 = 0x1305,
}
#[derive(Clone, Copy)]
pub(crate) struct ClientHello<'a, 'b> {
#[derive(Debug, Clone, Copy)]
pub(crate) struct ClientHello<'a> {
pub(crate) version: TlsVersion, // Legacy: Must be Tls12 (0x0303)
pub(crate) random: [u8; 32],
pub(crate) session_id_length: u8, // Legacy: Keep it 32
@ -75,17 +90,27 @@ pub(crate) struct ClientHello<'a, 'b> {
pub(crate) compression_method_length: u8, // Legacy: Must be 1, to contain a byte
pub(crate) compression_methods: u8, // Legacy: Must be 1 byte of 0
pub(crate) extension_length: u16,
pub(crate) extensions: &'a[Extension<'b>],
pub(crate) extensions: &'a[Extension<'a>],
}
#[derive(Clone, Copy)]
pub(crate) enum HandshakeData<'a, 'b> {
#[derive(Debug, Clone, Copy)]
pub(crate) enum HandshakeData<'a> {
Uninitialized,
ClientHello(ClientHello<'a, 'b>),
ServerHello(ServerHello<'a, 'b>),
ClientHello(ClientHello<'a>),
ServerHello(ServerHello<'a>),
}
impl<'a, 'b> ClientHello<'a, 'b> {
impl<'a> HandshakeData<'a> {
pub(crate) fn get_length(&self) -> u32 {
match self {
HandshakeData::ClientHello(data) => data.get_length(),
HandshakeData::ServerHello(data) => todo!(),
_ => 0,
}
}
}
impl<'a> ClientHello<'a> {
pub(crate) fn get_length(&self) -> u32 {
let mut length :u32 = 2; // TlsVersion size
length += 32; // Random size
@ -103,16 +128,16 @@ impl<'a, 'b> ClientHello<'a, 'b> {
}
}
#[derive(Clone, Copy)]
pub(crate) struct ServerHello<'a, 'b> {
#[derive(Debug, Clone, Copy)]
pub(crate) struct ServerHello<'a> {
pub(crate) version: TlsVersion,
pub(crate) random: [u8; 32],
pub(crate) random: &'a[u8],
pub(crate) session_id_echo_length: u8,
pub(crate) session_id_echo: [u8; 32],
pub(crate) session_id_echo: &'a[u8],
pub(crate) cipher_suite: CipherSuite,
pub(crate) compression_method: u8, // Always 0
pub(crate) extension_length: u16,
pub(crate) extensions: &'a[Extension<'b>],
pub(crate) extensions: &'a[Extension<'a>],
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]
@ -148,7 +173,7 @@ impl ExtensionType {
}
}
#[derive(Clone, Copy)]
#[derive(Debug, Clone, Copy)]
pub(crate) struct Extension<'a> {
pub(crate) extension_type: ExtensionType,
pub(crate) length: u16,