tls: insert iface polling

This commit is contained in:
occheung 2020-10-11 23:41:02 +08:00
parent ba40c0780c
commit 70527da8c4
6 changed files with 161 additions and 37 deletions

View File

@ -1,7 +1,7 @@
[package] [package]
name = "smoltcp-tls" name = "smoltcp-tls"
version = "0.1.0" version = "0.1.0"
authors = ["occheung <occheung@connect.ust.hk>"] authors = ["occheung"]
edition = "2018" edition = "2018"
[dependencies] [dependencies]
@ -15,7 +15,7 @@ log = {version = "0.4.11"}
[dependencies.smoltcp] [dependencies.smoltcp]
version = "0.6.0" version = "0.6.0"
default-features = false default-features = false
features = ["proto-ipv4", "proto-ipv6", "socket-tcp"] features = ["ethernet", "proto-ipv4", "proto-ipv6", "socket-tcp", "alloc"]
[dependencies.rand_core] [dependencies.rand_core]
version = "0.5.1" version = "0.5.1"

View File

@ -1,9 +1,14 @@
#![no_std] #![no_std]
#[macro_use]
extern crate alloc;
pub mod tls; pub mod tls;
pub mod tls_packet; pub mod tls_packet;
pub mod parse; pub mod parse;
// TODO: Implement errors
// Details: Encapsulate smoltcp & nom errors
pub enum Error { pub enum Error {
PropagatedError(smoltcp::Error), PropagatedError(smoltcp::Error),
ParsingError() ParsingError()

View File

@ -3,6 +3,36 @@ use smoltcp::socket::TcpSocketBuffer;
use smoltcp::socket::SocketSet; use smoltcp::socket::SocketSet;
use smoltcp::wire::Ipv4Address; use smoltcp::wire::Ipv4Address;
use rand_core::RngCore;
use rand_core::CryptoRng;
use rand_core::impls;
use rand_core::Error;
struct CountingRng(u64);
impl RngCore for CountingRng {
fn next_u32(&mut self) -> u32 {
self.next_u64() as u32
}
fn next_u64(&mut self) -> u64 {
self.0 += 1;
self.0
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
impls::fill_bytes_via_next(self, dest)
}
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
Ok(self.fill_bytes(dest))
}
}
impl CryptoRng for CountingRng {}
static mut RNG: CountingRng = CountingRng(0);
fn main() { fn main() {
let mut socket_set_entries: [_; 8] = Default::default(); let mut socket_set_entries: [_; 8] = Default::default();
let mut sockets = SocketSet::new(&mut socket_set_entries[..]); let mut sockets = SocketSet::new(&mut socket_set_entries[..]);
@ -10,13 +40,14 @@ fn main() {
let mut tx_storage = [0; 4096]; let mut tx_storage = [0; 4096];
let mut rx_storage = [0; 4096]; let mut rx_storage = [0; 4096];
let mut tls_socket = { let mut tls_socket = unsafe {
let tx_buffer = TcpSocketBuffer::new(&mut tx_storage[..]); let tx_buffer = TcpSocketBuffer::new(&mut tx_storage[..]);
let rx_buffer = TcpSocketBuffer::new(&mut rx_storage[..]); let rx_buffer = TcpSocketBuffer::new(&mut rx_storage[..]);
TlsSocket::new( TlsSocket::new(
&mut sockets, &mut sockets,
rx_buffer, rx_buffer,
tx_buffer tx_buffer,
&mut RNG,
) )
}; };

View File

@ -11,9 +11,9 @@ use byteorder::{ByteOrder, NetworkEndian, BigEndian};
use crate::tls_packet::*; use crate::tls_packet::*;
use core::convert::TryFrom; use core::convert::TryFrom;
use heapless::{ Vec, consts::* }; use alloc::vec::Vec;
fn parse_tls(bytes: &[u8]) -> IResult<&[u8], TlsRepr> { pub(crate) fn parse_tls_repr(bytes: &[u8]) -> IResult<&[u8], TlsRepr> {
let content_type = take(1_usize); let content_type = take(1_usize);
let version = take(2_usize); let version = take(2_usize);
let length = take(2_usize); let length = take(2_usize);
@ -65,7 +65,9 @@ fn parse_handshake(bytes: &[u8]) -> IResult<&[u8], HandshakeRepr> {
use crate::tls_packet::HandshakeType::*; use crate::tls_packet::HandshakeType::*;
match repr.msg_type { match repr.msg_type {
ServerHello => { ServerHello => {
todo!() let (rest, data) = parse_server_hello(bytes)?;
repr.handshake_data = data;
Ok((rest, repr))
}, },
_ => todo!() _ => todo!()
} }
@ -90,27 +92,26 @@ fn parse_server_hello(bytes: &[u8]) -> IResult<&[u8], HandshakeData> {
let (mut rest, (cipher_suite, compression_method, extension_length)) = let (mut rest, (cipher_suite, compression_method, extension_length)) =
tuple((cipher_suite, compression_method, extension_length))(rest)?; tuple((cipher_suite, compression_method, extension_length))(rest)?;
let mut extension_length = NetworkEndian::read_u16(extension_length);
let mut server_hello = ServerHello { let mut server_hello = ServerHello {
version: TlsVersion::try_from(NetworkEndian::read_u16(version)).unwrap(), version: TlsVersion::try_from(NetworkEndian::read_u16(version)).unwrap(),
random, random,
session_id_echo_length: session_id_echo_length, session_id_echo_length,
session_id_echo, session_id_echo,
cipher_suite: CipherSuite::try_from(NetworkEndian::read_u16(cipher_suite)).unwrap(), cipher_suite: CipherSuite::try_from(NetworkEndian::read_u16(cipher_suite)).unwrap(),
compression_method: compression_method[0], compression_method: compression_method[0],
extension_length, extension_length: NetworkEndian::read_u16(extension_length),
extensions: &[] extensions: Vec::new(),
}; };
let mut extension_vec: Vec<Extension, U32> = Vec::new(); let mut extension_vec: Vec<Extension> = Vec::new();
let mut extension_length: i32 = server_hello.extension_length.into();
while extension_length >= 0 { while extension_length >= 0 {
let (rem, extension) = parse_extension(rest)?; let (rem, extension) = parse_extension(rest)?;
rest = rem; rest = rem;
extension_length -= extension.get_length(); extension_length -= i32::try_from(extension.get_length()).unwrap();
// Todo:: Proper error // Todo:: Proper error
if extension_vec.push(extension).is_err() || extension_length < 0 { if extension_length < 0 {
todo!() todo!()
} }
} }

View File

@ -10,19 +10,22 @@ use smoltcp::wire::Ipv4Address;
use smoltcp::wire::IpEndpoint; use smoltcp::wire::IpEndpoint;
use smoltcp::Result; use smoltcp::Result;
use smoltcp::Error; use smoltcp::Error;
use smoltcp::iface::EthernetInterface;
use smoltcp::time::Instant;
use smoltcp::phy::Device;
use byteorder::{ByteOrder, NetworkEndian, BigEndian}; use byteorder::{ByteOrder, NetworkEndian, BigEndian};
use heapless::Vec;
use heapless::consts::*;
use core::convert::TryInto; use core::convert::TryInto;
use core::convert::TryFrom; use core::convert::TryFrom;
use rand_core::{RngCore, CryptoRng}; use rand_core::{RngCore, CryptoRng};
use p256::{EncodedPoint, AffinePoint, ecdh::EphemeralSecret}; use p256::{EncodedPoint, AffinePoint, ecdh::EphemeralSecret};
use alloc::vec::{ self, Vec };
use crate::tls_packet::*; use crate::tls_packet::*;
use crate::parse::parse_tls_repr;
#[derive(Debug, PartialEq, Eq, Clone, Copy)] #[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
@ -81,7 +84,15 @@ impl<R: RngCore + CryptoRng> TlsSocket<R> {
} }
} }
pub fn tls_connect(&mut self, sockets: &mut SocketSet) -> Result<bool> { pub fn tls_connect<DeviceT>(
&mut self,
iface: EthernetInterface<DeviceT>,
sockets: &mut SocketSet,
now: Instant
) -> Result<bool>
where
DeviceT: for<'d> Device<'d>
{
// Check tcp_socket connectivity // Check tcp_socket connectivity
{ {
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle); let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
@ -183,13 +194,13 @@ impl<R: RngCore + CryptoRng> TlsSocket<R> {
compression_method_length: 1, compression_method_length: 1,
compression_methods: 0, compression_methods: 0,
extension_length: supported_versions_extension.get_length(), extension_length: supported_versions_extension.get_length(),
extensions: &[ extensions: vec![
supported_versions_extension, supported_versions_extension,
signature_algorithms_extension, signature_algorithms_extension,
supported_groups_extension, supported_groups_extension,
psk_key_exchange_modes_extension, psk_key_exchange_modes_extension,
key_share_extension, key_share_extension
], ]
}; };
client_hello.extension_length = { client_hello.extension_length = {
@ -244,10 +255,25 @@ impl<R: RngCore + CryptoRng> TlsSocket<R> {
} }
// Generic inner recv method, through TCP socket // Generic inner recv method, through TCP socket
fn recv_tls_repr<'a>(&'a mut self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result<TlsRepr<'a>> { fn recv_tls_repr<'a>(&'a mut self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result<Vec::<TlsRepr>> {
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle); let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
let size = tcp_socket.recv_slice(byte_array)?; tcp_socket.recv_slice(byte_array)?;
todo!() let mut vec: Vec<TlsRepr> = Vec::new();
let mut bytes: &[u8] = byte_array;
loop {
match parse_tls_repr(bytes) {
Ok((rest, repr)) => {
vec.push(repr);
if rest.len() == 0 {
return Ok(vec);
} else {
bytes = rest;
}
},
_ => return Err(Error::Unrecognized),
};
}
} }
} }
@ -317,7 +343,7 @@ impl<'a> TlsBuffer<'a> {
Ok(slice) Ok(slice)
} }
fn enqueue_tls_repr(&mut self, tls_repr: TlsRepr) -> Result<()> { fn enqueue_tls_repr(&mut self, tls_repr: TlsRepr<'a>) -> Result<()> {
self.write_u8(tls_repr.content_type.into())?; self.write_u8(tls_repr.content_type.into())?;
self.write_u16(tls_repr.version.into())?; self.write_u16(tls_repr.version.into())?;
self.write_u16(tls_repr.length)?; self.write_u16(tls_repr.length)?;
@ -332,13 +358,13 @@ impl<'a> TlsBuffer<'a> {
Ok(()) Ok(())
} }
fn enqueue_handshake_repr(&mut self, handshake_repr: HandshakeRepr) -> Result<()> { fn enqueue_handshake_repr(&mut self, handshake_repr: HandshakeRepr<'a>) -> Result<()> {
self.write_u8(handshake_repr.msg_type.into())?; self.write_u8(handshake_repr.msg_type.into())?;
self.write_u24(handshake_repr.length)?; self.write_u24(handshake_repr.length)?;
self.enqueue_handshake_data(handshake_repr.handshake_data) self.enqueue_handshake_data(handshake_repr.handshake_data)
} }
fn enqueue_handshake_data(&mut self, handshake_data: HandshakeData) -> Result<()> { fn enqueue_handshake_data(&mut self, handshake_data: HandshakeData<'a>) -> Result<()> {
match handshake_data { match handshake_data {
HandshakeData::ClientHello(client_hello) => { HandshakeData::ClientHello(client_hello) => {
self.enqueue_client_hello(client_hello) self.enqueue_client_hello(client_hello)
@ -349,7 +375,7 @@ impl<'a> TlsBuffer<'a> {
} }
} }
fn enqueue_client_hello(&mut self, client_hello: ClientHello) -> Result<()> { fn enqueue_client_hello(&mut self, client_hello: ClientHello<'a>) -> Result<()> {
self.write_u16(client_hello.version.into())?; self.write_u16(client_hello.version.into())?;
self.write(&client_hello.random)?; self.write(&client_hello.random)?;
self.write_u8(client_hello.session_id_length)?; self.write_u8(client_hello.session_id_length)?;
@ -364,7 +390,7 @@ impl<'a> TlsBuffer<'a> {
self.enqueue_extensions(client_hello.extensions) self.enqueue_extensions(client_hello.extensions)
} }
fn enqueue_extensions(&mut self, extensions: &[Extension]) -> Result<()> { fn enqueue_extensions(&mut self, extensions: Vec<Extension<'a>>) -> Result<()> {
for extension in extensions { for extension in extensions {
self.write_u16(extension.extension_type.into())?; self.write_u16(extension.extension_type.into())?;
self.write_u16(extension.length)?; self.write_u16(extension.length)?;

View File

@ -1,9 +1,15 @@
use byteorder::{ByteOrder, NetworkEndian, BigEndian}; use byteorder::{ByteOrder, NetworkEndian, BigEndian};
use num_enum::IntoPrimitive; use num_enum::IntoPrimitive;
use num_enum::TryFromPrimitive; use num_enum::TryFromPrimitive;
use rand_core::RngCore;
use rand_core::CryptoRng;
use core::convert::TryFrom; use core::convert::TryFrom;
use core::convert::TryInto; use core::convert::TryInto;
use alloc::vec::Vec;
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]
#[repr(u8)] #[repr(u8)]
pub(crate) enum TlsContentType { pub(crate) enum TlsContentType {
@ -26,7 +32,7 @@ pub(crate) enum TlsVersion {
Tls13 = 0x0304, Tls13 = 0x0304,
} }
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone)]
pub(crate) struct TlsRepr<'a> { pub(crate) struct TlsRepr<'a> {
pub(crate) content_type: TlsContentType, pub(crate) content_type: TlsContentType,
pub(crate) version: TlsVersion, pub(crate) version: TlsVersion,
@ -35,6 +41,25 @@ pub(crate) struct TlsRepr<'a> {
pub(crate) handshake: Option<HandshakeRepr<'a>> pub(crate) handshake: Option<HandshakeRepr<'a>>
} }
impl<'a> TlsRepr<'a> {
pub(crate) fn new() -> Self {
TlsRepr {
content_type: TlsContentType::Invalid,
version: TlsVersion::Tls12,
length: 0,
payload: None,
handshake: None,
}
}
pub(crate) fn client_hello(mut self) -> Self {
self.content_type = TlsContentType::Handshake;
self.version = TlsVersion::Tls10;
// TODO: Fill in handshake field
self
}
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]
#[repr(u8)] #[repr(u8)]
pub(crate) enum HandshakeType { pub(crate) enum HandshakeType {
@ -53,7 +78,7 @@ pub(crate) enum HandshakeType {
MessageHash = 254, MessageHash = 254,
} }
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone)]
pub(crate) struct HandshakeRepr<'a> { pub(crate) struct HandshakeRepr<'a> {
pub(crate) msg_type: HandshakeType, pub(crate) msg_type: HandshakeType,
pub(crate) length: u32, pub(crate) length: u32,
@ -61,6 +86,14 @@ pub(crate) struct HandshakeRepr<'a> {
} }
impl<'a, 'b> HandshakeRepr<'a> { impl<'a, 'b> HandshakeRepr<'a> {
pub(self) fn new() -> Self {
HandshakeRepr {
msg_type: HandshakeType::Unknown,
length: 0,
handshake_data: HandshakeData::Uninitialized,
}
}
pub(crate) fn get_length(&self) -> u16 { pub(crate) fn get_length(&self) -> u16 {
let mut length :u16 = 1; // Handshake Type let mut length :u16 = 1; // Handshake Type
length += 3; // Length of Handshake data length += 3; // Length of Handshake data
@ -70,6 +103,7 @@ impl<'a, 'b> HandshakeRepr<'a> {
} }
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]
#[allow(non_camel_case)]
#[repr(u16)] #[repr(u16)]
pub(crate) enum CipherSuite { pub(crate) enum CipherSuite {
TLS_AES_128_GCM_SHA256 = 0x1301, TLS_AES_128_GCM_SHA256 = 0x1301,
@ -79,7 +113,7 @@ pub(crate) enum CipherSuite {
TLS_AES_128_CCM_8_SHA256 = 0x1305, TLS_AES_128_CCM_8_SHA256 = 0x1305,
} }
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone)]
pub(crate) struct ClientHello<'a> { pub(crate) struct ClientHello<'a> {
pub(crate) version: TlsVersion, // Legacy: Must be Tls12 (0x0303) pub(crate) version: TlsVersion, // Legacy: Must be Tls12 (0x0303)
pub(crate) random: [u8; 32], pub(crate) random: [u8; 32],
@ -90,10 +124,10 @@ pub(crate) struct ClientHello<'a> {
pub(crate) compression_method_length: u8, // Legacy: Must be 1, to contain a byte pub(crate) compression_method_length: u8, // Legacy: Must be 1, to contain a byte
pub(crate) compression_methods: u8, // Legacy: Must be 1 byte of 0 pub(crate) compression_methods: u8, // Legacy: Must be 1 byte of 0
pub(crate) extension_length: u16, pub(crate) extension_length: u16,
pub(crate) extensions: &'a[Extension<'a>], pub(crate) extensions: Vec<Extension<'a>>,
} }
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone)]
pub(crate) enum HandshakeData<'a> { pub(crate) enum HandshakeData<'a> {
Uninitialized, Uninitialized,
ClientHello(ClientHello<'a>), ClientHello(ClientHello<'a>),
@ -111,6 +145,33 @@ impl<'a> HandshakeData<'a> {
} }
impl<'a> ClientHello<'a> { impl<'a> ClientHello<'a> {
pub(self) fn new<T>(rng: &mut T) -> Self
where
T: RngCore + CryptoRng
{
let mut client_hello = ClientHello {
version: TlsVersion::Tls12,
random: [0; 32],
session_id_length: 32,
session_id: [0; 32],
cipher_suites_length: 0,
cipher_suites: &[
CipherSuite::TLS_AES_128_GCM_SHA256,
CipherSuite::TLS_AES_256_GCM_SHA384,
CipherSuite::TLS_CHACHA20_POLY1305_SHA256,
],
compression_method_length: 1,
compression_methods: 0,
extension_length: 0,
extensions: Vec::new(),
};
rng.fill_bytes(&mut client_hello.random);
rng.fill_bytes(&mut client_hello.session_id);
client_hello
}
pub(crate) fn get_length(&self) -> u32 { pub(crate) fn get_length(&self) -> u32 {
let mut length :u32 = 2; // TlsVersion size let mut length :u32 = 2; // TlsVersion size
length += 32; // Random size length += 32; // Random size
@ -128,7 +189,7 @@ impl<'a> ClientHello<'a> {
} }
} }
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone)]
pub(crate) struct ServerHello<'a> { pub(crate) struct ServerHello<'a> {
pub(crate) version: TlsVersion, pub(crate) version: TlsVersion,
pub(crate) random: &'a[u8], pub(crate) random: &'a[u8],
@ -137,7 +198,7 @@ pub(crate) struct ServerHello<'a> {
pub(crate) cipher_suite: CipherSuite, pub(crate) cipher_suite: CipherSuite,
pub(crate) compression_method: u8, // Always 0 pub(crate) compression_method: u8, // Always 0
pub(crate) extension_length: u16, pub(crate) extension_length: u16,
pub(crate) extensions: &'a[Extension<'a>], pub(crate) extensions: Vec<Extension<'a>>,
} }
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)] #[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]