From 1ad8f9c9bd0689ffc0200b41328f6a0c967f4767 Mon Sep 17 00:00:00 2001 From: whitequark Date: Tue, 20 Dec 2016 22:57:21 +0000 Subject: [PATCH] Implement conversion of incoming TCP connections into TCP streams. --- examples/smoltcpserver.rs | 16 +++-- src/iface/ethernet.rs | 8 +-- src/socket/mod.rs | 24 ++++--- src/socket/tcp.rs | 127 +++++++++++++++++++++++++++++++++----- src/socket/udp.rs | 2 +- 5 files changed, 142 insertions(+), 35 deletions(-) diff --git a/examples/smoltcpserver.rs b/examples/smoltcpserver.rs index 24ca124..4c4ee5b 100644 --- a/examples/smoltcpserver.rs +++ b/examples/smoltcpserver.rs @@ -7,7 +7,7 @@ use smoltcp::phy::{Tracer, TapInterface}; use smoltcp::wire::{EthernetFrame, EthernetAddress, IpAddress, IpEndpoint}; use smoltcp::iface::{SliceArpCache, EthernetInterface}; use smoltcp::socket::{UdpSocket, AsSocket, UdpSocketBuffer, UdpPacketBuffer}; -use smoltcp::socket::{TcpListener}; +use smoltcp::socket::{TcpListener, TcpStreamBuffer}; fn main() { let ifname = env::args().nth(1).unwrap(); @@ -27,7 +27,7 @@ fn main() { let hardware_addr = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]); let protocol_addrs = [IpAddress::v4(192, 168, 69, 1)]; - let sockets = [udp_socket, tcp_listener]; + let sockets = vec![udp_socket, tcp_listener]; let mut iface = EthernetInterface::new(device, arp_cache, hardware_addr, protocol_addrs, sockets); @@ -57,11 +57,15 @@ fn main() { } } - { + if let Some(incoming) = { let tcp_listener: &mut TcpListener = iface.sockets()[1].as_socket(); - if let Some(stream) = tcp_listener.accept() { - println!("client from {}", stream.remote_end()) - } + tcp_listener.accept() + } { + println!("client from {}", incoming.remote_end()); + + let tcp_rx_buffer = TcpStreamBuffer::new(vec![0; 8192]); + let tcp_tx_buffer = TcpStreamBuffer::new(vec![0; 4096]); + iface.sockets().push(incoming.into_stream(tcp_rx_buffer, tcp_tx_buffer)); } } } diff --git a/src/iface/ethernet.rs b/src/iface/ethernet.rs index 8da9313..7162573 100644 --- a/src/iface/ethernet.rs +++ b/src/iface/ethernet.rs @@ -95,8 +95,8 @@ impl<'a, 'b: 'a, /// /// # Panics /// This function panics if any of the addresses is not unicast. - pub fn update_protocol_addrs(&mut self, f: F) { - f(self.protocol_addrs.borrow_mut()); + pub fn update_protocol_addrs(&mut self, f: F) { + f(&mut self.protocol_addrs); Self::check_protocol_addrs(self.protocol_addrs.borrow()) } @@ -107,8 +107,8 @@ impl<'a, 'b: 'a, } /// Get the set of sockets owned by the interface. - pub fn sockets(&mut self) -> &mut [Socket<'a, 'b>] { - self.sockets.borrow_mut() + pub fn sockets(&mut self) -> &mut SocketsT { + &mut self.sockets } /// Receive and process a packet, if available. diff --git a/src/socket/mod.rs b/src/socket/mod.rs index 867398b..6a5e732 100644 --- a/src/socket/mod.rs +++ b/src/socket/mod.rs @@ -20,7 +20,8 @@ pub use self::udp::PacketBuffer as UdpPacketBuffer; pub use self::udp::SocketBuffer as UdpSocketBuffer; pub use self::udp::UdpSocket as UdpSocket; -pub use self::tcp::SocketBuffer as TcpSocketBuffer; +pub use self::tcp::StreamBuffer as TcpStreamBuffer; +pub use self::tcp::Stream as TcpStream; pub use self::tcp::Incoming as TcpIncoming; pub use self::tcp::Listener as TcpListener; @@ -50,8 +51,9 @@ pub trait PacketRepr { /// not yet known and the packet storage has to be allocated; but the `&PacketRepr` is sufficient /// since the lower layers treat the packet as an opaque octet sequence. pub enum Socket<'a, 'b: 'a> { - Udp(UdpSocket<'a, 'b>), - TcpServer(TcpListener<'a>), + UdpSocket(UdpSocket<'a, 'b>), + TcpStream(TcpStream<'a>), + TcpListener(TcpListener<'a>), #[doc(hidden)] __Nonexhaustive } @@ -68,9 +70,11 @@ impl<'a, 'b> Socket<'a, 'b> { protocol: IpProtocol, payload: &[u8]) -> Result<(), Error> { match self { - &mut Socket::Udp(ref mut socket) => + &mut Socket::UdpSocket(ref mut socket) => socket.collect(src_addr, dst_addr, protocol, payload), - &mut Socket::TcpServer(ref mut socket) => + &mut Socket::TcpStream(ref mut socket) => + socket.collect(src_addr, dst_addr, protocol, payload), + &mut Socket::TcpListener(ref mut socket) => socket.collect(src_addr, dst_addr, protocol, payload), &mut Socket::__Nonexhaustive => unreachable!() } @@ -87,9 +91,11 @@ impl<'a, 'b> Socket<'a, 'b> { IpProtocol, &PacketRepr) -> Result<(), Error>) -> Result<(), Error> { match self { - &mut Socket::Udp(ref mut socket) => + &mut Socket::UdpSocket(ref mut socket) => socket.dispatch(f), - &mut Socket::TcpServer(_) => + &mut Socket::TcpStream(ref mut socket) => + socket.dispatch(f), + &mut Socket::TcpListener(_) => Err(Error::Exhausted), &mut Socket::__Nonexhaustive => unreachable!() } @@ -107,7 +113,7 @@ pub trait AsSocket { impl<'a, 'b> AsSocket> for Socket<'a, 'b> { fn as_socket(&mut self) -> &mut UdpSocket<'a, 'b> { match self { - &mut Socket::Udp(ref mut socket) => socket, + &mut Socket::UdpSocket(ref mut socket) => socket, _ => panic!(".as_socket:: called on wrong socket type") } } @@ -116,7 +122,7 @@ impl<'a, 'b> AsSocket> for Socket<'a, 'b> { impl<'a, 'b> AsSocket> for Socket<'a, 'b> { fn as_socket(&mut self) -> &mut TcpListener<'a> { match self { - &mut Socket::TcpServer(ref mut socket) => socket, + &mut Socket::TcpListener(ref mut socket) => socket, _ => panic!(".as_socket:: called on wrong socket type") } } diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index cf6d59f..7bfefe6 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -2,21 +2,21 @@ use Error; use Managed; use wire::{IpProtocol, IpAddress, IpEndpoint}; use wire::{TcpPacket, TcpRepr, TcpControl}; -use socket::{Socket}; +use socket::{Socket, PacketRepr}; /// A TCP stream ring buffer. #[derive(Debug)] -pub struct SocketBuffer<'a> { +pub struct StreamBuffer<'a> { storage: Managed<'a, [u8]>, read_at: usize, length: usize } -impl<'a> SocketBuffer<'a> { +impl<'a> StreamBuffer<'a> { /// Create a packet buffer with the given storage. - pub fn new(storage: T) -> SocketBuffer<'a> + pub fn new(storage: T) -> StreamBuffer<'a> where T: Into> { - SocketBuffer { + StreamBuffer { storage: storage.into(), read_at: 0, length: 0 @@ -60,24 +60,119 @@ impl<'a> SocketBuffer<'a> { } } -/// A description of incoming TCP connection. -#[derive(Debug)] -pub struct Incoming { - local_end: IpEndpoint, - remote_end: IpEndpoint, - seq_number: u32 +impl<'a> Into> for Managed<'a, [u8]> { + fn into(self) -> StreamBuffer<'a> { + StreamBuffer::new(self) + } } -impl Incoming { +/// A Transmission Control Protocol data stream. +#[derive(Debug)] +pub struct Stream<'a> { + local_end: IpEndpoint, + remote_end: IpEndpoint, + local_seq: u32, + remote_seq: u32, + rx_buffer: StreamBuffer<'a>, + tx_buffer: StreamBuffer<'a> +} + +impl<'a> Stream<'a> { /// Return the local endpoint. + #[inline(always)] pub fn local_end(&self) -> IpEndpoint { self.local_end } /// Return the remote endpoint. + #[inline(always)] pub fn remote_end(&self) -> IpEndpoint { self.remote_end } + + /// See [Socket::collect](enum.Socket.html#method.collect). + pub fn collect(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress, + protocol: IpProtocol, payload: &[u8]) + -> Result<(), Error> { + if protocol != IpProtocol::Tcp { return Err(Error::Rejected) } + + let packet = try!(TcpPacket::new(payload)); + let repr = try!(TcpRepr::parse(&packet, src_addr, dst_addr)); + + if self.local_end != IpEndpoint::new(*dst_addr, repr.dst_port) { + return Err(Error::Rejected) + } + if self.remote_end != IpEndpoint::new(*src_addr, repr.src_port) { + return Err(Error::Rejected) + } + + // FIXME: process + Ok(()) + } + + /// See [Socket::dispatch](enum.Socket.html#method.dispatch). + pub fn dispatch(&mut self, _f: &mut FnMut(&IpAddress, &IpAddress, + IpProtocol, &PacketRepr) -> Result<(), Error>) + -> Result<(), Error> { + // FIXME: process + // f(&self.local_end.addr, + // &self.remote_end.addr, + // IpProtocol::Tcp, + // &TcpRepr { + // src_port: self.local_end.port, + // dst_port: self.remote_end.port, + // payload: &packet_buf.as_ref()[..] + // }) + + Ok(()) + } +} + +impl<'a> PacketRepr for TcpRepr<'a> { + fn buffer_len(&self) -> usize { + self.buffer_len() + } + + fn emit(&self, src_addr: &IpAddress, dst_addr: &IpAddress, payload: &mut [u8]) { + let mut packet = TcpPacket::new(payload).expect("undersized payload"); + self.emit(&mut packet, src_addr, dst_addr) + } +} + +/// A description of incoming TCP connection. +#[derive(Debug)] +pub struct Incoming { + local_end: IpEndpoint, + remote_end: IpEndpoint, + local_seq: u32, + remote_seq: u32 +} + +impl Incoming { + /// Return the local endpoint. + #[inline(always)] + pub fn local_end(&self) -> IpEndpoint { + self.local_end + } + + /// Return the remote endpoint. + #[inline(always)] + pub fn remote_end(&self) -> IpEndpoint { + self.remote_end + } + + /// Convert into a data stream using the given buffers. + pub fn into_stream<'a, T>(self, rx_buffer: T, tx_buffer: T) -> Socket<'a, 'static> + where T: Into> { + Socket::TcpStream(Stream { + rx_buffer: rx_buffer.into(), + tx_buffer: tx_buffer.into(), + local_end: self.local_end, + remote_end: self.remote_end, + local_seq: self.local_seq, + remote_seq: self.remote_seq + }) + } } /// A Transmission Control Protocol server socket. @@ -93,7 +188,7 @@ impl<'a> Listener<'a> { /// Create a server socket with the given backlog. pub fn new(endpoint: IpEndpoint, backlog: T) -> Socket<'a, 'static> where T: Into]>> { - Socket::TcpServer(Listener { + Socket::TcpListener(Listener { endpoint: endpoint, backlog: backlog.into(), accept_at: 0, @@ -137,7 +232,9 @@ impl<'a> Listener<'a> { self.backlog[inject_at] = Some(Incoming { local_end: IpEndpoint::new(*dst_addr, repr.dst_port), remote_end: IpEndpoint::new(*src_addr, repr.src_port), - seq_number: repr.seq_number + // FIXME: choose something more secure? + local_seq: !repr.seq_number, + remote_seq: repr.seq_number }); Ok(()) } @@ -152,7 +249,7 @@ mod test { #[test] fn test_buffer() { - let mut buffer = SocketBuffer::new(vec![0; 8]); // ........ + let mut buffer = StreamBuffer::new(vec![0; 8]); // ........ buffer.enqueue(6).copy_from_slice(b"foobar"); // foobar.. assert_eq!(buffer.dequeue(3), b"foo"); // ...bar.. buffer.enqueue(6).copy_from_slice(b"ba"); // ...barba diff --git a/src/socket/udp.rs b/src/socket/udp.rs index 9b40a92..26adfaa 100644 --- a/src/socket/udp.rs +++ b/src/socket/udp.rs @@ -115,7 +115,7 @@ impl<'a, 'b> UdpSocket<'a, 'b> { pub fn new(endpoint: IpEndpoint, rx_buffer: SocketBuffer<'a, 'b>, tx_buffer: SocketBuffer<'a, 'b>) -> Socket<'a, 'b> { - Socket::Udp(UdpSocket { + Socket::UdpSocket(UdpSocket { endpoint: endpoint, rx_buffer: rx_buffer, tx_buffer: tx_buffer