tls: insert iface polling

This commit is contained in:
occheung 2020-10-11 23:41:02 +08:00
parent ba40c0780c
commit 70527da8c4
6 changed files with 161 additions and 37 deletions

View File

@ -1,7 +1,7 @@
[package]
name = "smoltcp-tls"
version = "0.1.0"
authors = ["occheung <occheung@connect.ust.hk>"]
authors = ["occheung"]
edition = "2018"
[dependencies]
@ -15,7 +15,7 @@ log = {version = "0.4.11"}
[dependencies.smoltcp]
version = "0.6.0"
default-features = false
features = ["proto-ipv4", "proto-ipv6", "socket-tcp"]
features = ["ethernet", "proto-ipv4", "proto-ipv6", "socket-tcp", "alloc"]
[dependencies.rand_core]
version = "0.5.1"

View File

@ -1,9 +1,14 @@
#![no_std]
#[macro_use]
extern crate alloc;
pub mod tls;
pub mod tls_packet;
pub mod parse;
// TODO: Implement errors
// Details: Encapsulate smoltcp & nom errors
pub enum Error {
PropagatedError(smoltcp::Error),
ParsingError()

View File

@ -3,6 +3,36 @@ use smoltcp::socket::TcpSocketBuffer;
use smoltcp::socket::SocketSet;
use smoltcp::wire::Ipv4Address;
use rand_core::RngCore;
use rand_core::CryptoRng;
use rand_core::impls;
use rand_core::Error;
struct CountingRng(u64);
impl RngCore for CountingRng {
fn next_u32(&mut self) -> u32 {
self.next_u64() as u32
}
fn next_u64(&mut self) -> u64 {
self.0 += 1;
self.0
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
impls::fill_bytes_via_next(self, dest)
}
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
Ok(self.fill_bytes(dest))
}
}
impl CryptoRng for CountingRng {}
static mut RNG: CountingRng = CountingRng(0);
fn main() {
let mut socket_set_entries: [_; 8] = Default::default();
let mut sockets = SocketSet::new(&mut socket_set_entries[..]);
@ -10,13 +40,14 @@ fn main() {
let mut tx_storage = [0; 4096];
let mut rx_storage = [0; 4096];
let mut tls_socket = {
let mut tls_socket = unsafe {
let tx_buffer = TcpSocketBuffer::new(&mut tx_storage[..]);
let rx_buffer = TcpSocketBuffer::new(&mut rx_storage[..]);
TlsSocket::new(
&mut sockets,
rx_buffer,
tx_buffer
tx_buffer,
&mut RNG,
)
};

View File

@ -11,9 +11,9 @@ use byteorder::{ByteOrder, NetworkEndian, BigEndian};
use crate::tls_packet::*;
use core::convert::TryFrom;
use heapless::{ Vec, consts::* };
use alloc::vec::Vec;
fn parse_tls(bytes: &[u8]) -> IResult<&[u8], TlsRepr> {
pub(crate) fn parse_tls_repr(bytes: &[u8]) -> IResult<&[u8], TlsRepr> {
let content_type = take(1_usize);
let version = take(2_usize);
let length = take(2_usize);
@ -65,7 +65,9 @@ fn parse_handshake(bytes: &[u8]) -> IResult<&[u8], HandshakeRepr> {
use crate::tls_packet::HandshakeType::*;
match repr.msg_type {
ServerHello => {
todo!()
let (rest, data) = parse_server_hello(bytes)?;
repr.handshake_data = data;
Ok((rest, repr))
},
_ => todo!()
}
@ -90,27 +92,26 @@ fn parse_server_hello(bytes: &[u8]) -> IResult<&[u8], HandshakeData> {
let (mut rest, (cipher_suite, compression_method, extension_length)) =
tuple((cipher_suite, compression_method, extension_length))(rest)?;
let mut extension_length = NetworkEndian::read_u16(extension_length);
let mut server_hello = ServerHello {
version: TlsVersion::try_from(NetworkEndian::read_u16(version)).unwrap(),
random,
session_id_echo_length: session_id_echo_length,
session_id_echo_length,
session_id_echo,
cipher_suite: CipherSuite::try_from(NetworkEndian::read_u16(cipher_suite)).unwrap(),
compression_method: compression_method[0],
extension_length,
extensions: &[]
extension_length: NetworkEndian::read_u16(extension_length),
extensions: Vec::new(),
};
let mut extension_vec: Vec<Extension, U32> = Vec::new();
let mut extension_vec: Vec<Extension> = Vec::new();
let mut extension_length: i32 = server_hello.extension_length.into();
while extension_length >= 0 {
let (rem, extension) = parse_extension(rest)?;
rest = rem;
extension_length -= extension.get_length();
extension_length -= i32::try_from(extension.get_length()).unwrap();
// Todo:: Proper error
if extension_vec.push(extension).is_err() || extension_length < 0 {
if extension_length < 0 {
todo!()
}
}

View File

@ -10,19 +10,22 @@ use smoltcp::wire::Ipv4Address;
use smoltcp::wire::IpEndpoint;
use smoltcp::Result;
use smoltcp::Error;
use smoltcp::iface::EthernetInterface;
use smoltcp::time::Instant;
use smoltcp::phy::Device;
use byteorder::{ByteOrder, NetworkEndian, BigEndian};
use heapless::Vec;
use heapless::consts::*;
use core::convert::TryInto;
use core::convert::TryFrom;
use rand_core::{RngCore, CryptoRng};
use p256::{EncodedPoint, AffinePoint, ecdh::EphemeralSecret};
use alloc::vec::{ self, Vec };
use crate::tls_packet::*;
use crate::parse::parse_tls_repr;
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[allow(non_camel_case_types)]
@ -81,7 +84,15 @@ impl<R: RngCore + CryptoRng> TlsSocket<R> {
}
}
pub fn tls_connect(&mut self, sockets: &mut SocketSet) -> Result<bool> {
pub fn tls_connect<DeviceT>(
&mut self,
iface: EthernetInterface<DeviceT>,
sockets: &mut SocketSet,
now: Instant
) -> Result<bool>
where
DeviceT: for<'d> Device<'d>
{
// Check tcp_socket connectivity
{
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
@ -183,13 +194,13 @@ impl<R: RngCore + CryptoRng> TlsSocket<R> {
compression_method_length: 1,
compression_methods: 0,
extension_length: supported_versions_extension.get_length(),
extensions: &[
extensions: vec![
supported_versions_extension,
signature_algorithms_extension,
supported_groups_extension,
psk_key_exchange_modes_extension,
key_share_extension,
],
key_share_extension
]
};
client_hello.extension_length = {
@ -244,10 +255,25 @@ impl<R: RngCore + CryptoRng> TlsSocket<R> {
}
// Generic inner recv method, through TCP socket
fn recv_tls_repr<'a>(&'a mut self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result<TlsRepr<'a>> {
fn recv_tls_repr<'a>(&'a mut self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result<Vec::<TlsRepr>> {
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
let size = tcp_socket.recv_slice(byte_array)?;
todo!()
tcp_socket.recv_slice(byte_array)?;
let mut vec: Vec<TlsRepr> = Vec::new();
let mut bytes: &[u8] = byte_array;
loop {
match parse_tls_repr(bytes) {
Ok((rest, repr)) => {
vec.push(repr);
if rest.len() == 0 {
return Ok(vec);
} else {
bytes = rest;
}
},
_ => return Err(Error::Unrecognized),
};
}
}
}
@ -317,7 +343,7 @@ impl<'a> TlsBuffer<'a> {
Ok(slice)
}
fn enqueue_tls_repr(&mut self, tls_repr: TlsRepr) -> Result<()> {
fn enqueue_tls_repr(&mut self, tls_repr: TlsRepr<'a>) -> Result<()> {
self.write_u8(tls_repr.content_type.into())?;
self.write_u16(tls_repr.version.into())?;
self.write_u16(tls_repr.length)?;
@ -332,13 +358,13 @@ impl<'a> TlsBuffer<'a> {
Ok(())
}
fn enqueue_handshake_repr(&mut self, handshake_repr: HandshakeRepr) -> Result<()> {
fn enqueue_handshake_repr(&mut self, handshake_repr: HandshakeRepr<'a>) -> Result<()> {
self.write_u8(handshake_repr.msg_type.into())?;
self.write_u24(handshake_repr.length)?;
self.enqueue_handshake_data(handshake_repr.handshake_data)
}
fn enqueue_handshake_data(&mut self, handshake_data: HandshakeData) -> Result<()> {
fn enqueue_handshake_data(&mut self, handshake_data: HandshakeData<'a>) -> Result<()> {
match handshake_data {
HandshakeData::ClientHello(client_hello) => {
self.enqueue_client_hello(client_hello)
@ -349,7 +375,7 @@ impl<'a> TlsBuffer<'a> {
}
}
fn enqueue_client_hello(&mut self, client_hello: ClientHello) -> Result<()> {
fn enqueue_client_hello(&mut self, client_hello: ClientHello<'a>) -> Result<()> {
self.write_u16(client_hello.version.into())?;
self.write(&client_hello.random)?;
self.write_u8(client_hello.session_id_length)?;
@ -364,7 +390,7 @@ impl<'a> TlsBuffer<'a> {
self.enqueue_extensions(client_hello.extensions)
}
fn enqueue_extensions(&mut self, extensions: &[Extension]) -> Result<()> {
fn enqueue_extensions(&mut self, extensions: Vec<Extension<'a>>) -> Result<()> {
for extension in extensions {
self.write_u16(extension.extension_type.into())?;
self.write_u16(extension.length)?;

View File

@ -1,9 +1,15 @@
use byteorder::{ByteOrder, NetworkEndian, BigEndian};
use num_enum::IntoPrimitive;
use num_enum::TryFromPrimitive;
use rand_core::RngCore;
use rand_core::CryptoRng;
use core::convert::TryFrom;
use core::convert::TryInto;
use alloc::vec::Vec;
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]
#[repr(u8)]
pub(crate) enum TlsContentType {
@ -26,7 +32,7 @@ pub(crate) enum TlsVersion {
Tls13 = 0x0304,
}
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone)]
pub(crate) struct TlsRepr<'a> {
pub(crate) content_type: TlsContentType,
pub(crate) version: TlsVersion,
@ -35,6 +41,25 @@ pub(crate) struct TlsRepr<'a> {
pub(crate) handshake: Option<HandshakeRepr<'a>>
}
impl<'a> TlsRepr<'a> {
pub(crate) fn new() -> Self {
TlsRepr {
content_type: TlsContentType::Invalid,
version: TlsVersion::Tls12,
length: 0,
payload: None,
handshake: None,
}
}
pub(crate) fn client_hello(mut self) -> Self {
self.content_type = TlsContentType::Handshake;
self.version = TlsVersion::Tls10;
// TODO: Fill in handshake field
self
}
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]
#[repr(u8)]
pub(crate) enum HandshakeType {
@ -53,7 +78,7 @@ pub(crate) enum HandshakeType {
MessageHash = 254,
}
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone)]
pub(crate) struct HandshakeRepr<'a> {
pub(crate) msg_type: HandshakeType,
pub(crate) length: u32,
@ -61,6 +86,14 @@ pub(crate) struct HandshakeRepr<'a> {
}
impl<'a, 'b> HandshakeRepr<'a> {
pub(self) fn new() -> Self {
HandshakeRepr {
msg_type: HandshakeType::Unknown,
length: 0,
handshake_data: HandshakeData::Uninitialized,
}
}
pub(crate) fn get_length(&self) -> u16 {
let mut length :u16 = 1; // Handshake Type
length += 3; // Length of Handshake data
@ -70,6 +103,7 @@ impl<'a, 'b> HandshakeRepr<'a> {
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]
#[allow(non_camel_case)]
#[repr(u16)]
pub(crate) enum CipherSuite {
TLS_AES_128_GCM_SHA256 = 0x1301,
@ -79,7 +113,7 @@ pub(crate) enum CipherSuite {
TLS_AES_128_CCM_8_SHA256 = 0x1305,
}
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone)]
pub(crate) struct ClientHello<'a> {
pub(crate) version: TlsVersion, // Legacy: Must be Tls12 (0x0303)
pub(crate) random: [u8; 32],
@ -90,10 +124,10 @@ pub(crate) struct ClientHello<'a> {
pub(crate) compression_method_length: u8, // Legacy: Must be 1, to contain a byte
pub(crate) compression_methods: u8, // Legacy: Must be 1 byte of 0
pub(crate) extension_length: u16,
pub(crate) extensions: &'a[Extension<'a>],
pub(crate) extensions: Vec<Extension<'a>>,
}
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone)]
pub(crate) enum HandshakeData<'a> {
Uninitialized,
ClientHello(ClientHello<'a>),
@ -111,6 +145,33 @@ impl<'a> HandshakeData<'a> {
}
impl<'a> ClientHello<'a> {
pub(self) fn new<T>(rng: &mut T) -> Self
where
T: RngCore + CryptoRng
{
let mut client_hello = ClientHello {
version: TlsVersion::Tls12,
random: [0; 32],
session_id_length: 32,
session_id: [0; 32],
cipher_suites_length: 0,
cipher_suites: &[
CipherSuite::TLS_AES_128_GCM_SHA256,
CipherSuite::TLS_AES_256_GCM_SHA384,
CipherSuite::TLS_CHACHA20_POLY1305_SHA256,
],
compression_method_length: 1,
compression_methods: 0,
extension_length: 0,
extensions: Vec::new(),
};
rng.fill_bytes(&mut client_hello.random);
rng.fill_bytes(&mut client_hello.session_id);
client_hello
}
pub(crate) fn get_length(&self) -> u32 {
let mut length :u32 = 2; // TlsVersion size
length += 32; // Random size
@ -128,7 +189,7 @@ impl<'a> ClientHello<'a> {
}
}
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone)]
pub(crate) struct ServerHello<'a> {
pub(crate) version: TlsVersion,
pub(crate) random: &'a[u8],
@ -137,7 +198,7 @@ pub(crate) struct ServerHello<'a> {
pub(crate) cipher_suite: CipherSuite,
pub(crate) compression_method: u8, // Always 0
pub(crate) extension_length: u16,
pub(crate) extensions: &'a[Extension<'a>],
pub(crate) extensions: Vec<Extension<'a>>,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]