From 9d0084171ff5e0cda58094b0e8b7d5bff0b16597 Mon Sep 17 00:00:00 2001 From: whitequark Date: Tue, 22 Aug 2017 22:32:05 +0000 Subject: [PATCH] Rework responses to TCP packets and factor in RST replies to TcpSocket. --- src/iface/ethernet.rs | 150 ++++++++++++++++++------------------------ src/socket/tcp.rs | 54 ++++++++++++--- src/wire/ip.rs | 17 +++++ src/wire/tcp.rs | 20 ++++-- 4 files changed, 142 insertions(+), 99 deletions(-) diff --git a/src/iface/ethernet.rs b/src/iface/ethernet.rs index f197ab5..8895d91 100644 --- a/src/iface/ethernet.rs +++ b/src/iface/ethernet.rs @@ -27,7 +27,7 @@ enum Response<'a> { Nop, Arp(ArpRepr), Icmpv4(Ipv4Repr, Icmpv4Repr<'a>), - Tcpv4(Ipv4Repr, TcpRepr<'a>) + Tcp(IpRepr, TcpRepr<'a>) } impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> { @@ -220,10 +220,10 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> { match ipv4_repr.protocol { IpProtocol::Icmp => Self::process_icmpv4(ipv4_repr, ipv4_packet.payload()), - IpProtocol::Tcp => - Self::process_tcpv4(sockets, timestamp, ipv4_repr, ipv4_packet.payload()), IpProtocol::Udp => Self::process_udpv4(sockets, timestamp, ipv4_repr, ipv4_packet.payload()), + IpProtocol::Tcp => + Self::process_tcp(sockets, timestamp, ipv4_repr.into(), ipv4_packet.payload()), _ if handled_by_raw_socket => Ok(Response::Nop), _ => { @@ -307,11 +307,9 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> { Ok(Response::Icmpv4(ipv4_reply_repr, icmp_reply_repr)) } - fn process_tcpv4<'frame>(sockets: &mut SocketSet, timestamp: u64, - ipv4_repr: Ipv4Repr, ip_payload: &'frame [u8]) -> - Result> { - let ip_repr = IpRepr::Ipv4(ipv4_repr); - + fn process_tcp<'frame>(sockets: &mut SocketSet, timestamp: u64, + ip_repr: IpRepr, ip_payload: &'frame [u8]) -> + Result> { for tcp_socket in sockets.iter_mut().filter_map( >::try_as_socket) { match tcp_socket.process(timestamp, &ip_repr, ip_payload) { @@ -327,99 +325,81 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> { // The packet wasn't handled by a socket, send a TCP RST packet. let tcp_packet = TcpPacket::new_checked(ip_payload)?; - if tcp_packet.rst() { - // Don't reply to a TCP RST packet with another TCP RST packet. - return Ok(Response::Nop) + let tcp_repr = TcpRepr::parse(&tcp_packet, &ip_repr.src_addr(), &ip_repr.dst_addr())?; + if tcp_repr.control == TcpControl::Rst { + // Never reply to a TCP RST packet with another TCP RST packet. + Ok(Response::Nop) + } else { + let (ip_reply_repr, tcp_reply_repr) = TcpSocket::rst_reply(&ip_repr, &tcp_repr); + Ok(Response::Tcp(ip_reply_repr, tcp_reply_repr)) } - let tcp_reply_repr = TcpRepr { - src_port: tcp_packet.dst_port(), - dst_port: tcp_packet.src_port(), - control: TcpControl::Rst, - push: false, - seq_number: tcp_packet.ack_number(), - ack_number: Some(tcp_packet.seq_number() + - tcp_packet.segment_len()), - window_len: 0, - max_seg_size: None, - payload: &[] - }; - let ipv4_reply_repr = Ipv4Repr { - src_addr: ipv4_repr.dst_addr, - dst_addr: ipv4_repr.src_addr, - protocol: IpProtocol::Tcp, - payload_len: tcp_reply_repr.buffer_len() - }; - Ok(Response::Tcpv4(ipv4_reply_repr, tcp_reply_repr)) } fn send_response(&mut self, timestamp: u64, response: Response) -> Result<()> { - macro_rules! ip_response { - ($tx_buffer:ident, $frame:ident, $ip_repr:ident) => ({ - let dst_hardware_addr = - match self.arp_cache.lookup(&$ip_repr.dst_addr.into()) { - None => return Err(Error::Unaddressable), - Some(hardware_addr) => hardware_addr - }; + macro_rules! emit_packet { + (Ethernet, $buffer_len:expr, |$frame:ident| $code:stmt) => ({ + let tx_len = EthernetFrame::<&[u8]>::buffer_len($buffer_len); + let mut tx_buffer = self.device.transmit(timestamp, tx_len)?; + debug_assert!(tx_buffer.as_ref().len() == tx_len); - let tx_len = EthernetFrame::<&[u8]>::buffer_len($ip_repr.buffer_len() + - $ip_repr.payload_len); - $tx_buffer = self.device.transmit(timestamp, tx_len)?; - debug_assert!($tx_buffer.as_ref().len() == tx_len); - - $frame = EthernetFrame::new(&mut $tx_buffer); + let mut $frame = EthernetFrame::new(&mut tx_buffer); $frame.set_src_addr(self.hardware_addr); - $frame.set_dst_addr(dst_hardware_addr); - $frame.set_ethertype(EthernetProtocol::Ipv4); - let mut ip_packet = Ipv4Packet::new($frame.payload_mut()); - $ip_repr.emit(&mut ip_packet); - ip_packet + $code + + Ok(()) + }); + + (Ip, $ip_repr:expr, |$payload:ident| $code:stmt) => ({ + let ip_repr = $ip_repr.lower(&self.protocol_addrs)?; + match self.arp_cache.lookup(&ip_repr.dst_addr()) { + None => Err(Error::Unaddressable), + Some(dst_hardware_addr) => { + emit_packet!(Ethernet, ip_repr.total_len(), |frame| { + frame.set_dst_addr(dst_hardware_addr); + match ip_repr { + IpRepr::Ipv4(_) => frame.set_ethertype(EthernetProtocol::Ipv4), + _ => unreachable!() + } + + ip_repr.emit(frame.payload_mut()); + + let $payload = &mut frame.payload_mut()[ip_repr.buffer_len()..]; + $code + }) + } + } }) } match response { - Response::Arp(repr) => { - let tx_len = EthernetFrame::<&[u8]>::buffer_len(repr.buffer_len()); - let mut tx_buffer = self.device.transmit(timestamp, tx_len)?; - debug_assert!(tx_buffer.as_ref().len() == tx_len); + Response::Arp(arp_repr) => { + let dst_hardware_addr = + match arp_repr { + ArpRepr::EthernetIpv4 { target_hardware_addr, .. } => target_hardware_addr, + _ => unreachable!() + }; - let mut frame = EthernetFrame::new(&mut tx_buffer); - frame.set_src_addr(self.hardware_addr); - frame.set_dst_addr(match repr { - ArpRepr::EthernetIpv4 { target_hardware_addr, .. } => target_hardware_addr, - _ => unreachable!() - }); - frame.set_ethertype(EthernetProtocol::Arp); + emit_packet!(Ethernet, arp_repr.buffer_len(), |frame| { + frame.set_dst_addr(dst_hardware_addr); + frame.set_ethertype(EthernetProtocol::Arp); - let mut packet = ArpPacket::new(frame.payload_mut()); - repr.emit(&mut packet); - - Ok(()) + let mut packet = ArpPacket::new(frame.payload_mut()); + arp_repr.emit(&mut packet); + }) }, - - Response::Icmpv4(ip_repr, icmp_repr) => { - let mut tx_buffer; - let mut frame; - let mut ip_packet = ip_response!(tx_buffer, frame, ip_repr); - let mut icmp_packet = Icmpv4Packet::new(ip_packet.payload_mut()); - icmp_repr.emit(&mut icmp_packet); - Ok(()) + Response::Icmpv4(ipv4_repr, icmpv4_repr) => { + emit_packet!(Ip, IpRepr::Ipv4(ipv4_repr), |payload| { + icmpv4_repr.emit(&mut Icmpv4Packet::new(payload)); + }) } - - Response::Tcpv4(ip_repr, tcp_repr) => { - let mut tx_buffer; - let mut frame; - let mut ip_packet = ip_response!(tx_buffer, frame, ip_repr); - let mut tcp_packet = TcpPacket::new(ip_packet.payload_mut()); - tcp_repr.emit(&mut tcp_packet, - &IpAddress::Ipv4(ip_repr.src_addr), - &IpAddress::Ipv4(ip_repr.dst_addr)); - Ok(()) - } - - Response::Nop => { - Ok(()) + Response::Tcp(ip_repr, tcp_repr) => { + emit_packet!(Ip, ip_repr, |payload| { + tcp_repr.emit(&mut TcpPacket::new(payload), + &ip_repr.src_addr(), &ip_repr.dst_addr()); + }) } + Response::Nop => Ok(()) } } diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index ea48dbb..246f096 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -285,10 +285,10 @@ impl<'a> TcpSocket<'a> { listen_address: IpAddress::default(), local_endpoint: IpEndpoint::default(), remote_endpoint: IpEndpoint::default(), - local_seq_no: TcpSeqNumber(0), - remote_seq_no: TcpSeqNumber(0), - remote_last_seq: TcpSeqNumber(0), - remote_last_ack: TcpSeqNumber(0), + local_seq_no: TcpSeqNumber::default(), + remote_seq_no: TcpSeqNumber::default(), + remote_last_seq: TcpSeqNumber::default(), + remote_last_ack: TcpSeqNumber::default(), remote_win_len: 0, remote_mss: DEFAULT_MSS, retransmit: Retransmit::new(), @@ -335,10 +335,10 @@ impl<'a> TcpSocket<'a> { self.listen_address = IpAddress::default(); self.local_endpoint = IpEndpoint::default(); self.remote_endpoint = IpEndpoint::default(); - self.local_seq_no = TcpSeqNumber(0); - self.remote_seq_no = TcpSeqNumber(0); - self.remote_last_seq = TcpSeqNumber(0); - self.remote_last_ack = TcpSeqNumber(0); + self.local_seq_no = TcpSeqNumber::default(); + self.remote_seq_no = TcpSeqNumber::default(); + self.remote_last_seq = TcpSeqNumber::default(); + self.remote_last_ack = TcpSeqNumber::default(); self.remote_win_len = 0; self.remote_mss = DEFAULT_MSS; self.retransmit.reset(); @@ -681,6 +681,44 @@ impl<'a> TcpSocket<'a> { self.state = state } + pub(crate) fn reply(ip_repr: &IpRepr, tcp_repr: &TcpRepr) -> (IpRepr, TcpRepr<'static>) { + let tcp_reply_repr = TcpRepr { + src_port: tcp_repr.dst_port, + dst_port: tcp_repr.src_port, + control: TcpControl::None, + push: false, + seq_number: TcpSeqNumber(0), + ack_number: None, + window_len: 0, + max_seg_size: None, + payload: &[] + }; + let ip_reply_repr = IpRepr::Unspecified { + src_addr: ip_repr.dst_addr(), + dst_addr: ip_repr.src_addr(), + protocol: IpProtocol::Tcp, + payload_len: tcp_reply_repr.buffer_len() + }; + (ip_reply_repr, tcp_reply_repr) + } + + pub(crate) fn rst_reply(ip_repr: &IpRepr, tcp_repr: &TcpRepr) -> (IpRepr, TcpRepr<'static>) { + debug_assert!(tcp_repr.control != TcpControl::Rst); + + let (ip_reply_repr, mut tcp_reply_repr) = Self::reply(ip_repr, tcp_repr); + + // See https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for explanation + // of why we sometimes send an RST and sometimes an RST|ACK + tcp_reply_repr.control = TcpControl::Rst; + tcp_reply_repr.seq_number = tcp_repr.ack_number.unwrap_or_default(); + if tcp_repr.control == TcpControl::Syn { + tcp_reply_repr.ack_number = Some(tcp_repr.seq_number + + tcp_repr.segment_len()); + } + + (ip_reply_repr, tcp_reply_repr) + } + pub(crate) fn process(&mut self, timestamp: u64, ip_repr: &IpRepr, payload: &[u8]) -> Result<()> { debug_assert!(ip_repr.protocol() == IpProtocol::Tcp); diff --git a/src/wire/ip.rs b/src/wire/ip.rs index 4e6fdfd..a091114 100644 --- a/src/wire/ip.rs +++ b/src/wire/ip.rs @@ -177,6 +177,12 @@ pub enum IpRepr { __Nonexhaustive } +impl From for IpRepr { + fn from(repr: Ipv4Repr) -> IpRepr { + IpRepr::Ipv4(repr) + } +} + impl IpRepr { /// Return the protocol version. pub fn version(&self) -> Version { @@ -323,6 +329,17 @@ impl IpRepr { unreachable!() } } + + /// Return the total length of a packet that will be emitted from this + /// high-level representation. + /// + /// This is the same as `repr.buffer_len() + repr.payload_len()`. + /// + /// # Panics + /// This function panics if invoked on an unspecified representation. + pub fn total_len(&self) -> usize { + self.buffer_len() + self.payload_len() + } } pub mod checksum { diff --git a/src/wire/tcp.rs b/src/wire/tcp.rs index 71b8be5..ef46dcb 100644 --- a/src/wire/tcp.rs +++ b/src/wire/tcp.rs @@ -9,7 +9,7 @@ use super::ip::checksum; /// /// A sequence number is a monotonically advancing integer modulo 232. /// Sequence numbers do not have a discontiguity when compared pairwise across a signed overflow. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)] pub struct SeqNumber(pub i32); impl fmt::Display for SeqNumber { @@ -275,7 +275,6 @@ impl> Packet { } /// Return the length of the segment, in terms of sequence space. - #[inline] pub fn segment_len(&self) -> usize { let data = self.buffer.as_ref(); let mut length = data.len() - self.header_len() as usize; @@ -695,10 +694,9 @@ impl<'a> Repr<'a> { } /// Emit a high-level representation into a Transmission Control Protocol packet. - pub fn emit(&self, packet: &mut Packet<&mut T>, - src_addr: &IpAddress, - dst_addr: &IpAddress) - where T: AsRef<[u8]> + AsMut<[u8]> { + pub fn emit(&self, packet: &mut Packet<&mut T>, + src_addr: &IpAddress, dst_addr: &IpAddress) + where T: AsRef<[u8]> + AsMut<[u8]> + ?Sized { packet.set_src_port(self.src_port); packet.set_dst_port(self.dst_port); packet.set_seq_number(self.seq_number); @@ -727,6 +725,16 @@ impl<'a> Repr<'a> { packet.payload_mut().copy_from_slice(self.payload); packet.fill_checksum(src_addr, dst_addr) } + + /// Return the length of the segment, in terms of sequence space. + pub fn segment_len(&self) -> usize { + let mut length = self.payload.len(); + match self.control { + Control::Syn | Control::Fin => length += 1, + _ => () + } + length + } } impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {