From bddb5f9127307b029e373e18c8233cd21622d43b Mon Sep 17 00:00:00 2001 From: whitequark Date: Tue, 20 Dec 2016 19:51:52 +0000 Subject: [PATCH] Implement TCP server sockets. --- examples/smoltcpserver.rs | 46 +++++++++++++------- src/managed.rs | 4 +- src/socket/mod.rs | 16 +++++++ src/socket/tcp.rs | 90 +++++++++++++++++++++++++++++++++++++++ src/socket/udp.rs | 2 +- src/wire/ip.rs | 20 ++++----- 6 files changed, 149 insertions(+), 29 deletions(-) diff --git a/examples/smoltcpserver.rs b/examples/smoltcpserver.rs index 54a2051..24ca124 100644 --- a/examples/smoltcpserver.rs +++ b/examples/smoltcpserver.rs @@ -7,6 +7,7 @@ use smoltcp::phy::{Tracer, TapInterface}; use smoltcp::wire::{EthernetFrame, EthernetAddress, IpAddress, IpEndpoint}; use smoltcp::iface::{SliceArpCache, EthernetInterface}; use smoltcp::socket::{UdpSocket, AsSocket, UdpSocketBuffer, UdpPacketBuffer}; +use smoltcp::socket::{TcpListener}; fn main() { let ifname = env::args().nth(1).unwrap(); @@ -15,14 +16,18 @@ fn main() { let device = Tracer::<_, EthernetFrame<&[u8]>>::new(device); let arp_cache = SliceArpCache::new(vec![Default::default(); 8]); + let endpoint = IpEndpoint::new(IpAddress::default(), 6969); + let udp_rx_buffer = UdpSocketBuffer::new(vec![UdpPacketBuffer::new(vec![0; 2048])]); let udp_tx_buffer = UdpSocketBuffer::new(vec![UdpPacketBuffer::new(vec![0; 2048])]); - let endpoint = IpEndpoint::new(IpAddress::default(), 6969); let udp_socket = UdpSocket::new(endpoint, udp_rx_buffer, udp_tx_buffer); + let tcp_backlog = vec![None]; + let tcp_listener = TcpListener::new(endpoint, tcp_backlog); + let hardware_addr = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]); let protocol_addrs = [IpAddress::v4(192, 168, 69, 1)]; - let sockets = [udp_socket]; + let sockets = [udp_socket, tcp_listener]; let mut iface = EthernetInterface::new(device, arp_cache, hardware_addr, protocol_addrs, sockets); @@ -32,22 +37,31 @@ fn main() { Err(e) => println!("error {}", e) } - let udp_socket = iface.sockets()[0].as_socket(); - let client = match udp_socket.recv() { - Ok((endpoint, data)) => { - println!("data {:?} from {}", data, endpoint); - Some(endpoint) + { + let udp_socket: &mut UdpSocket = iface.sockets()[0].as_socket(); + let udp_client = match udp_socket.recv() { + Ok((endpoint, data)) => { + println!("data {:?} from {}", data, endpoint); + Some(endpoint) + } + Err(Error::Exhausted) => { + None + } + Err(e) => { + println!("error {}", e); + None + } + }; + if let Some(endpoint) = udp_client { + udp_socket.send_slice(endpoint, "hihihi".as_bytes()).unwrap() } - Err(Error::Exhausted) => { - None + } + + { + let tcp_listener: &mut TcpListener = iface.sockets()[1].as_socket(); + if let Some(stream) = tcp_listener.accept() { + println!("client from {}", stream.remote_end()) } - Err(e) => { - println!("error {}", e); - None - } - }; - if let Some(endpoint) = client { - udp_socket.send_slice(endpoint, "hihihi".as_bytes()).unwrap() } } } diff --git a/src/managed.rs b/src/managed.rs index 0ea04e1..23dfdec 100644 --- a/src/managed.rs +++ b/src/managed.rs @@ -35,8 +35,8 @@ impl<'a, T: 'a + fmt::Debug + ?Sized> fmt::Debug for Managed<'a, T> { } } -impl<'a, 'b: 'a, T: 'b + ?Sized> From<&'b mut T> for Managed<'b, T> { - fn from(value: &'b mut T) -> Self { +impl<'a, T: 'a + ?Sized> From<&'a mut T> for Managed<'a, T> { + fn from(value: &'a mut T) -> Self { Managed::Borrowed(value) } } diff --git a/src/socket/mod.rs b/src/socket/mod.rs index ebc55a6..867398b 100644 --- a/src/socket/mod.rs +++ b/src/socket/mod.rs @@ -21,6 +21,8 @@ pub use self::udp::SocketBuffer as UdpSocketBuffer; pub use self::udp::UdpSocket as UdpSocket; pub use self::tcp::SocketBuffer as TcpSocketBuffer; +pub use self::tcp::Incoming as TcpIncoming; +pub use self::tcp::Listener as TcpListener; /// A packet representation. /// @@ -49,6 +51,7 @@ pub trait PacketRepr { /// since the lower layers treat the packet as an opaque octet sequence. pub enum Socket<'a, 'b: 'a> { Udp(UdpSocket<'a, 'b>), + TcpServer(TcpListener<'a>), #[doc(hidden)] __Nonexhaustive } @@ -67,6 +70,8 @@ impl<'a, 'b> Socket<'a, 'b> { match self { &mut Socket::Udp(ref mut socket) => socket.collect(src_addr, dst_addr, protocol, payload), + &mut Socket::TcpServer(ref mut socket) => + socket.collect(src_addr, dst_addr, protocol, payload), &mut Socket::__Nonexhaustive => unreachable!() } } @@ -84,6 +89,8 @@ impl<'a, 'b> Socket<'a, 'b> { match self { &mut Socket::Udp(ref mut socket) => socket.dispatch(f), + &mut Socket::TcpServer(_) => + Err(Error::Exhausted), &mut Socket::__Nonexhaustive => unreachable!() } } @@ -105,3 +112,12 @@ impl<'a, 'b> AsSocket> for Socket<'a, 'b> { } } } + +impl<'a, 'b> AsSocket> for Socket<'a, 'b> { + fn as_socket(&mut self) -> &mut TcpListener<'a> { + match self { + &mut Socket::TcpServer(ref mut socket) => socket, + _ => panic!(".as_socket:: called on wrong socket type") + } + } +} diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index 8c69ae2..cf6d59f 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -1,4 +1,8 @@ +use Error; use Managed; +use wire::{IpProtocol, IpAddress, IpEndpoint}; +use wire::{TcpPacket, TcpRepr, TcpControl}; +use socket::{Socket}; /// A TCP stream ring buffer. #[derive(Debug)] @@ -56,6 +60,92 @@ impl<'a> SocketBuffer<'a> { } } +/// A description of incoming TCP connection. +#[derive(Debug)] +pub struct Incoming { + local_end: IpEndpoint, + remote_end: IpEndpoint, + seq_number: u32 +} + +impl Incoming { + /// Return the local endpoint. + pub fn local_end(&self) -> IpEndpoint { + self.local_end + } + + /// Return the remote endpoint. + pub fn remote_end(&self) -> IpEndpoint { + self.remote_end + } +} + +/// A Transmission Control Protocol server socket. +#[derive(Debug)] +pub struct Listener<'a> { + endpoint: IpEndpoint, + backlog: Managed<'a, [Option]>, + accept_at: usize, + length: usize +} + +impl<'a> Listener<'a> { + /// Create a server socket with the given backlog. + pub fn new(endpoint: IpEndpoint, backlog: T) -> Socket<'a, 'static> + where T: Into]>> { + Socket::TcpServer(Listener { + endpoint: endpoint, + backlog: backlog.into(), + accept_at: 0, + length: 0 + }) + } + + /// Accept a connection from this server socket, + pub fn accept(&mut self) -> Option { + if self.length == 0 { return None } + + let accept_at = self.accept_at; + self.accept_at = (self.accept_at + 1) % self.backlog.len(); + self.length -= 1; + + self.backlog[accept_at].take() + } + + /// See [Socket::collect](enum.Socket.html#method.collect). + pub fn collect(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress, + protocol: IpProtocol, payload: &[u8]) + -> Result<(), Error> { + if protocol != IpProtocol::Tcp { return Err(Error::Rejected) } + + let packet = try!(TcpPacket::new(payload)); + let repr = try!(TcpRepr::parse(&packet, src_addr, dst_addr)); + + if repr.dst_port != self.endpoint.port { return Err(Error::Rejected) } + if !self.endpoint.addr.is_unspecified() { + if self.endpoint.addr != *dst_addr { return Err(Error::Rejected) } + } + + match (repr.control, repr.ack_number) { + (TcpControl::Syn, None) => { + if self.length == self.backlog.len() { return Err(Error::Exhausted) } + + let inject_at = (self.accept_at + self.length) % self.backlog.len(); + self.length += 1; + + assert!(self.backlog[inject_at].is_none()); + self.backlog[inject_at] = Some(Incoming { + local_end: IpEndpoint::new(*dst_addr, repr.dst_port), + remote_end: IpEndpoint::new(*src_addr, repr.src_port), + seq_number: repr.seq_number + }); + Ok(()) + } + _ => Err(Error::Rejected) + } + } +} + #[cfg(test)] mod test { use super::*; diff --git a/src/socket/udp.rs b/src/socket/udp.rs index 4b3c0f0..9b40a92 100644 --- a/src/socket/udp.rs +++ b/src/socket/udp.rs @@ -17,7 +17,7 @@ impl<'a> PacketBuffer<'a> { pub fn new(payload: T) -> PacketBuffer<'a> where T: Into> { PacketBuffer { - endpoint: IpEndpoint::INVALID, + endpoint: IpEndpoint::UNSPECIFIED, size: 0, payload: payload.into() } diff --git a/src/wire/ip.rs b/src/wire/ip.rs index 88b408c..322e6cc 100644 --- a/src/wire/ip.rs +++ b/src/wire/ip.rs @@ -25,9 +25,9 @@ impl fmt::Display for Protocol { /// An internetworking address. #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] pub enum Address { - /// An invalid address. + /// An unspecified address. /// May be used as a placeholder for storage where the address is not assigned yet. - Invalid, + Unspecified, /// An IPv4 address. Ipv4(Ipv4Address) } @@ -41,23 +41,23 @@ impl Address { /// Query whether the address is a valid unicast address. pub fn is_unicast(&self) -> bool { match self { - &Address::Invalid => false, - &Address::Ipv4(addr) => addr.is_unicast() + &Address::Unspecified => false, + &Address::Ipv4(addr) => addr.is_unicast() } } /// Query whether the address falls into the "unspecified" range. pub fn is_unspecified(&self) -> bool { match self { - &Address::Invalid => false, - &Address::Ipv4(addr) => addr.is_unspecified() + &Address::Unspecified => true, + &Address::Ipv4(addr) => addr.is_unspecified() } } } impl Default for Address { fn default() -> Address { - Address::Invalid + Address::Unspecified } } @@ -70,8 +70,8 @@ impl From for Address { impl fmt::Display for Address { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - &Address::Invalid => write!(f, "(invalid)"), - &Address::Ipv4(addr) => write!(f, "{}", addr) + &Address::Unspecified => write!(f, "*"), + &Address::Ipv4(addr) => write!(f, "{}", addr) } } } @@ -84,7 +84,7 @@ pub struct Endpoint { } impl Endpoint { - pub const INVALID: Endpoint = Endpoint { addr: Address::Invalid, port: 0 }; + pub const UNSPECIFIED: Endpoint = Endpoint { addr: Address::Unspecified, port: 0 }; /// Create an endpoint address from given address and port. pub fn new(addr: Address, port: u16) -> Endpoint {