tls: insert iface polling
This commit is contained in:
parent
ba40c0780c
commit
70527da8c4
|
@ -1,7 +1,7 @@
|
|||
[package]
|
||||
name = "smoltcp-tls"
|
||||
version = "0.1.0"
|
||||
authors = ["occheung <occheung@connect.ust.hk>"]
|
||||
authors = ["occheung"]
|
||||
edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
|
@ -15,7 +15,7 @@ log = {version = "0.4.11"}
|
|||
[dependencies.smoltcp]
|
||||
version = "0.6.0"
|
||||
default-features = false
|
||||
features = ["proto-ipv4", "proto-ipv6", "socket-tcp"]
|
||||
features = ["ethernet", "proto-ipv4", "proto-ipv6", "socket-tcp", "alloc"]
|
||||
|
||||
[dependencies.rand_core]
|
||||
version = "0.5.1"
|
||||
|
|
|
@ -1,9 +1,14 @@
|
|||
#![no_std]
|
||||
|
||||
#[macro_use]
|
||||
extern crate alloc;
|
||||
|
||||
pub mod tls;
|
||||
pub mod tls_packet;
|
||||
pub mod parse;
|
||||
|
||||
// TODO: Implement errors
|
||||
// Details: Encapsulate smoltcp & nom errors
|
||||
pub enum Error {
|
||||
PropagatedError(smoltcp::Error),
|
||||
ParsingError()
|
||||
|
|
35
src/main.rs
35
src/main.rs
|
@ -3,6 +3,36 @@ use smoltcp::socket::TcpSocketBuffer;
|
|||
use smoltcp::socket::SocketSet;
|
||||
use smoltcp::wire::Ipv4Address;
|
||||
|
||||
use rand_core::RngCore;
|
||||
use rand_core::CryptoRng;
|
||||
use rand_core::impls;
|
||||
use rand_core::Error;
|
||||
|
||||
struct CountingRng(u64);
|
||||
|
||||
impl RngCore for CountingRng {
|
||||
fn next_u32(&mut self) -> u32 {
|
||||
self.next_u64() as u32
|
||||
}
|
||||
|
||||
fn next_u64(&mut self) -> u64 {
|
||||
self.0 += 1;
|
||||
self.0
|
||||
}
|
||||
|
||||
fn fill_bytes(&mut self, dest: &mut [u8]) {
|
||||
impls::fill_bytes_via_next(self, dest)
|
||||
}
|
||||
|
||||
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
|
||||
Ok(self.fill_bytes(dest))
|
||||
}
|
||||
}
|
||||
|
||||
impl CryptoRng for CountingRng {}
|
||||
|
||||
static mut RNG: CountingRng = CountingRng(0);
|
||||
|
||||
fn main() {
|
||||
let mut socket_set_entries: [_; 8] = Default::default();
|
||||
let mut sockets = SocketSet::new(&mut socket_set_entries[..]);
|
||||
|
@ -10,13 +40,14 @@ fn main() {
|
|||
let mut tx_storage = [0; 4096];
|
||||
let mut rx_storage = [0; 4096];
|
||||
|
||||
let mut tls_socket = {
|
||||
let mut tls_socket = unsafe {
|
||||
let tx_buffer = TcpSocketBuffer::new(&mut tx_storage[..]);
|
||||
let rx_buffer = TcpSocketBuffer::new(&mut rx_storage[..]);
|
||||
TlsSocket::new(
|
||||
&mut sockets,
|
||||
rx_buffer,
|
||||
tx_buffer
|
||||
tx_buffer,
|
||||
&mut RNG,
|
||||
)
|
||||
};
|
||||
|
||||
|
|
23
src/parse.rs
23
src/parse.rs
|
@ -11,9 +11,9 @@ use byteorder::{ByteOrder, NetworkEndian, BigEndian};
|
|||
use crate::tls_packet::*;
|
||||
use core::convert::TryFrom;
|
||||
|
||||
use heapless::{ Vec, consts::* };
|
||||
use alloc::vec::Vec;
|
||||
|
||||
fn parse_tls(bytes: &[u8]) -> IResult<&[u8], TlsRepr> {
|
||||
pub(crate) fn parse_tls_repr(bytes: &[u8]) -> IResult<&[u8], TlsRepr> {
|
||||
let content_type = take(1_usize);
|
||||
let version = take(2_usize);
|
||||
let length = take(2_usize);
|
||||
|
@ -65,7 +65,9 @@ fn parse_handshake(bytes: &[u8]) -> IResult<&[u8], HandshakeRepr> {
|
|||
use crate::tls_packet::HandshakeType::*;
|
||||
match repr.msg_type {
|
||||
ServerHello => {
|
||||
todo!()
|
||||
let (rest, data) = parse_server_hello(bytes)?;
|
||||
repr.handshake_data = data;
|
||||
Ok((rest, repr))
|
||||
},
|
||||
_ => todo!()
|
||||
}
|
||||
|
@ -90,27 +92,26 @@ fn parse_server_hello(bytes: &[u8]) -> IResult<&[u8], HandshakeData> {
|
|||
let (mut rest, (cipher_suite, compression_method, extension_length)) =
|
||||
tuple((cipher_suite, compression_method, extension_length))(rest)?;
|
||||
|
||||
let mut extension_length = NetworkEndian::read_u16(extension_length);
|
||||
|
||||
let mut server_hello = ServerHello {
|
||||
version: TlsVersion::try_from(NetworkEndian::read_u16(version)).unwrap(),
|
||||
random,
|
||||
session_id_echo_length: session_id_echo_length,
|
||||
session_id_echo_length,
|
||||
session_id_echo,
|
||||
cipher_suite: CipherSuite::try_from(NetworkEndian::read_u16(cipher_suite)).unwrap(),
|
||||
compression_method: compression_method[0],
|
||||
extension_length,
|
||||
extensions: &[]
|
||||
extension_length: NetworkEndian::read_u16(extension_length),
|
||||
extensions: Vec::new(),
|
||||
};
|
||||
|
||||
let mut extension_vec: Vec<Extension, U32> = Vec::new();
|
||||
let mut extension_vec: Vec<Extension> = Vec::new();
|
||||
let mut extension_length: i32 = server_hello.extension_length.into();
|
||||
while extension_length >= 0 {
|
||||
let (rem, extension) = parse_extension(rest)?;
|
||||
rest = rem;
|
||||
extension_length -= extension.get_length();
|
||||
extension_length -= i32::try_from(extension.get_length()).unwrap();
|
||||
|
||||
// Todo:: Proper error
|
||||
if extension_vec.push(extension).is_err() || extension_length < 0 {
|
||||
if extension_length < 0 {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
|
56
src/tls.rs
56
src/tls.rs
|
@ -10,19 +10,22 @@ use smoltcp::wire::Ipv4Address;
|
|||
use smoltcp::wire::IpEndpoint;
|
||||
use smoltcp::Result;
|
||||
use smoltcp::Error;
|
||||
use smoltcp::iface::EthernetInterface;
|
||||
use smoltcp::time::Instant;
|
||||
use smoltcp::phy::Device;
|
||||
|
||||
use byteorder::{ByteOrder, NetworkEndian, BigEndian};
|
||||
|
||||
use heapless::Vec;
|
||||
use heapless::consts::*;
|
||||
|
||||
use core::convert::TryInto;
|
||||
use core::convert::TryFrom;
|
||||
|
||||
use rand_core::{RngCore, CryptoRng};
|
||||
use p256::{EncodedPoint, AffinePoint, ecdh::EphemeralSecret};
|
||||
|
||||
use alloc::vec::{ self, Vec };
|
||||
|
||||
use crate::tls_packet::*;
|
||||
use crate::parse::parse_tls_repr;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
#[allow(non_camel_case_types)]
|
||||
|
@ -81,7 +84,15 @@ impl<R: RngCore + CryptoRng> TlsSocket<R> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn tls_connect(&mut self, sockets: &mut SocketSet) -> Result<bool> {
|
||||
pub fn tls_connect<DeviceT>(
|
||||
&mut self,
|
||||
iface: EthernetInterface<DeviceT>,
|
||||
sockets: &mut SocketSet,
|
||||
now: Instant
|
||||
) -> Result<bool>
|
||||
where
|
||||
DeviceT: for<'d> Device<'d>
|
||||
{
|
||||
// Check tcp_socket connectivity
|
||||
{
|
||||
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
|
||||
|
@ -183,13 +194,13 @@ impl<R: RngCore + CryptoRng> TlsSocket<R> {
|
|||
compression_method_length: 1,
|
||||
compression_methods: 0,
|
||||
extension_length: supported_versions_extension.get_length(),
|
||||
extensions: &[
|
||||
extensions: vec![
|
||||
supported_versions_extension,
|
||||
signature_algorithms_extension,
|
||||
supported_groups_extension,
|
||||
psk_key_exchange_modes_extension,
|
||||
key_share_extension,
|
||||
],
|
||||
key_share_extension
|
||||
]
|
||||
};
|
||||
|
||||
client_hello.extension_length = {
|
||||
|
@ -244,10 +255,25 @@ impl<R: RngCore + CryptoRng> TlsSocket<R> {
|
|||
}
|
||||
|
||||
// Generic inner recv method, through TCP socket
|
||||
fn recv_tls_repr<'a>(&'a mut self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result<TlsRepr<'a>> {
|
||||
fn recv_tls_repr<'a>(&'a mut self, sockets: &mut SocketSet, byte_array: &'a mut [u8]) -> Result<Vec::<TlsRepr>> {
|
||||
let mut tcp_socket = sockets.get::<TcpSocket>(self.tcp_handle);
|
||||
let size = tcp_socket.recv_slice(byte_array)?;
|
||||
todo!()
|
||||
tcp_socket.recv_slice(byte_array)?;
|
||||
let mut vec: Vec<TlsRepr> = Vec::new();
|
||||
|
||||
let mut bytes: &[u8] = byte_array;
|
||||
loop {
|
||||
match parse_tls_repr(bytes) {
|
||||
Ok((rest, repr)) => {
|
||||
vec.push(repr);
|
||||
if rest.len() == 0 {
|
||||
return Ok(vec);
|
||||
} else {
|
||||
bytes = rest;
|
||||
}
|
||||
},
|
||||
_ => return Err(Error::Unrecognized),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -317,7 +343,7 @@ impl<'a> TlsBuffer<'a> {
|
|||
Ok(slice)
|
||||
}
|
||||
|
||||
fn enqueue_tls_repr(&mut self, tls_repr: TlsRepr) -> Result<()> {
|
||||
fn enqueue_tls_repr(&mut self, tls_repr: TlsRepr<'a>) -> Result<()> {
|
||||
self.write_u8(tls_repr.content_type.into())?;
|
||||
self.write_u16(tls_repr.version.into())?;
|
||||
self.write_u16(tls_repr.length)?;
|
||||
|
@ -332,13 +358,13 @@ impl<'a> TlsBuffer<'a> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn enqueue_handshake_repr(&mut self, handshake_repr: HandshakeRepr) -> Result<()> {
|
||||
fn enqueue_handshake_repr(&mut self, handshake_repr: HandshakeRepr<'a>) -> Result<()> {
|
||||
self.write_u8(handshake_repr.msg_type.into())?;
|
||||
self.write_u24(handshake_repr.length)?;
|
||||
self.enqueue_handshake_data(handshake_repr.handshake_data)
|
||||
}
|
||||
|
||||
fn enqueue_handshake_data(&mut self, handshake_data: HandshakeData) -> Result<()> {
|
||||
fn enqueue_handshake_data(&mut self, handshake_data: HandshakeData<'a>) -> Result<()> {
|
||||
match handshake_data {
|
||||
HandshakeData::ClientHello(client_hello) => {
|
||||
self.enqueue_client_hello(client_hello)
|
||||
|
@ -349,7 +375,7 @@ impl<'a> TlsBuffer<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
fn enqueue_client_hello(&mut self, client_hello: ClientHello) -> Result<()> {
|
||||
fn enqueue_client_hello(&mut self, client_hello: ClientHello<'a>) -> Result<()> {
|
||||
self.write_u16(client_hello.version.into())?;
|
||||
self.write(&client_hello.random)?;
|
||||
self.write_u8(client_hello.session_id_length)?;
|
||||
|
@ -364,7 +390,7 @@ impl<'a> TlsBuffer<'a> {
|
|||
self.enqueue_extensions(client_hello.extensions)
|
||||
}
|
||||
|
||||
fn enqueue_extensions(&mut self, extensions: &[Extension]) -> Result<()> {
|
||||
fn enqueue_extensions(&mut self, extensions: Vec<Extension<'a>>) -> Result<()> {
|
||||
for extension in extensions {
|
||||
self.write_u16(extension.extension_type.into())?;
|
||||
self.write_u16(extension.length)?;
|
||||
|
|
|
@ -1,9 +1,15 @@
|
|||
use byteorder::{ByteOrder, NetworkEndian, BigEndian};
|
||||
use num_enum::IntoPrimitive;
|
||||
use num_enum::TryFromPrimitive;
|
||||
|
||||
use rand_core::RngCore;
|
||||
use rand_core::CryptoRng;
|
||||
|
||||
use core::convert::TryFrom;
|
||||
use core::convert::TryInto;
|
||||
|
||||
use alloc::vec::Vec;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]
|
||||
#[repr(u8)]
|
||||
pub(crate) enum TlsContentType {
|
||||
|
@ -26,7 +32,7 @@ pub(crate) enum TlsVersion {
|
|||
Tls13 = 0x0304,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct TlsRepr<'a> {
|
||||
pub(crate) content_type: TlsContentType,
|
||||
pub(crate) version: TlsVersion,
|
||||
|
@ -35,6 +41,25 @@ pub(crate) struct TlsRepr<'a> {
|
|||
pub(crate) handshake: Option<HandshakeRepr<'a>>
|
||||
}
|
||||
|
||||
impl<'a> TlsRepr<'a> {
|
||||
pub(crate) fn new() -> Self {
|
||||
TlsRepr {
|
||||
content_type: TlsContentType::Invalid,
|
||||
version: TlsVersion::Tls12,
|
||||
length: 0,
|
||||
payload: None,
|
||||
handshake: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn client_hello(mut self) -> Self {
|
||||
self.content_type = TlsContentType::Handshake;
|
||||
self.version = TlsVersion::Tls10;
|
||||
// TODO: Fill in handshake field
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]
|
||||
#[repr(u8)]
|
||||
pub(crate) enum HandshakeType {
|
||||
|
@ -53,7 +78,7 @@ pub(crate) enum HandshakeType {
|
|||
MessageHash = 254,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct HandshakeRepr<'a> {
|
||||
pub(crate) msg_type: HandshakeType,
|
||||
pub(crate) length: u32,
|
||||
|
@ -61,6 +86,14 @@ pub(crate) struct HandshakeRepr<'a> {
|
|||
}
|
||||
|
||||
impl<'a, 'b> HandshakeRepr<'a> {
|
||||
pub(self) fn new() -> Self {
|
||||
HandshakeRepr {
|
||||
msg_type: HandshakeType::Unknown,
|
||||
length: 0,
|
||||
handshake_data: HandshakeData::Uninitialized,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_length(&self) -> u16 {
|
||||
let mut length :u16 = 1; // Handshake Type
|
||||
length += 3; // Length of Handshake data
|
||||
|
@ -70,6 +103,7 @@ impl<'a, 'b> HandshakeRepr<'a> {
|
|||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]
|
||||
#[allow(non_camel_case)]
|
||||
#[repr(u16)]
|
||||
pub(crate) enum CipherSuite {
|
||||
TLS_AES_128_GCM_SHA256 = 0x1301,
|
||||
|
@ -79,7 +113,7 @@ pub(crate) enum CipherSuite {
|
|||
TLS_AES_128_CCM_8_SHA256 = 0x1305,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ClientHello<'a> {
|
||||
pub(crate) version: TlsVersion, // Legacy: Must be Tls12 (0x0303)
|
||||
pub(crate) random: [u8; 32],
|
||||
|
@ -90,10 +124,10 @@ pub(crate) struct ClientHello<'a> {
|
|||
pub(crate) compression_method_length: u8, // Legacy: Must be 1, to contain a byte
|
||||
pub(crate) compression_methods: u8, // Legacy: Must be 1 byte of 0
|
||||
pub(crate) extension_length: u16,
|
||||
pub(crate) extensions: &'a[Extension<'a>],
|
||||
pub(crate) extensions: Vec<Extension<'a>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum HandshakeData<'a> {
|
||||
Uninitialized,
|
||||
ClientHello(ClientHello<'a>),
|
||||
|
@ -111,6 +145,33 @@ impl<'a> HandshakeData<'a> {
|
|||
}
|
||||
|
||||
impl<'a> ClientHello<'a> {
|
||||
pub(self) fn new<T>(rng: &mut T) -> Self
|
||||
where
|
||||
T: RngCore + CryptoRng
|
||||
{
|
||||
let mut client_hello = ClientHello {
|
||||
version: TlsVersion::Tls12,
|
||||
random: [0; 32],
|
||||
session_id_length: 32,
|
||||
session_id: [0; 32],
|
||||
cipher_suites_length: 0,
|
||||
cipher_suites: &[
|
||||
CipherSuite::TLS_AES_128_GCM_SHA256,
|
||||
CipherSuite::TLS_AES_256_GCM_SHA384,
|
||||
CipherSuite::TLS_CHACHA20_POLY1305_SHA256,
|
||||
],
|
||||
compression_method_length: 1,
|
||||
compression_methods: 0,
|
||||
extension_length: 0,
|
||||
extensions: Vec::new(),
|
||||
};
|
||||
|
||||
rng.fill_bytes(&mut client_hello.random);
|
||||
rng.fill_bytes(&mut client_hello.session_id);
|
||||
|
||||
client_hello
|
||||
}
|
||||
|
||||
pub(crate) fn get_length(&self) -> u32 {
|
||||
let mut length :u32 = 2; // TlsVersion size
|
||||
length += 32; // Random size
|
||||
|
@ -128,7 +189,7 @@ impl<'a> ClientHello<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ServerHello<'a> {
|
||||
pub(crate) version: TlsVersion,
|
||||
pub(crate) random: &'a[u8],
|
||||
|
@ -137,7 +198,7 @@ pub(crate) struct ServerHello<'a> {
|
|||
pub(crate) cipher_suite: CipherSuite,
|
||||
pub(crate) compression_method: u8, // Always 0
|
||||
pub(crate) extension_length: u16,
|
||||
pub(crate) extensions: &'a[Extension<'a>],
|
||||
pub(crate) extensions: Vec<Extension<'a>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, IntoPrimitive, TryFromPrimitive)]
|
||||
|
|
Loading…
Reference in New Issue