Rework responses to TCP packets and factor in RST replies to TcpSocket.
This commit is contained in:
parent
7c9a072dd2
commit
9d0084171f
|
@ -27,7 +27,7 @@ enum Response<'a> {
|
|||
Nop,
|
||||
Arp(ArpRepr),
|
||||
Icmpv4(Ipv4Repr, Icmpv4Repr<'a>),
|
||||
Tcpv4(Ipv4Repr, TcpRepr<'a>)
|
||||
Tcp(IpRepr, TcpRepr<'a>)
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
|
||||
|
@ -220,10 +220,10 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
|
|||
match ipv4_repr.protocol {
|
||||
IpProtocol::Icmp =>
|
||||
Self::process_icmpv4(ipv4_repr, ipv4_packet.payload()),
|
||||
IpProtocol::Tcp =>
|
||||
Self::process_tcpv4(sockets, timestamp, ipv4_repr, ipv4_packet.payload()),
|
||||
IpProtocol::Udp =>
|
||||
Self::process_udpv4(sockets, timestamp, ipv4_repr, ipv4_packet.payload()),
|
||||
IpProtocol::Tcp =>
|
||||
Self::process_tcp(sockets, timestamp, ipv4_repr.into(), ipv4_packet.payload()),
|
||||
_ if handled_by_raw_socket =>
|
||||
Ok(Response::Nop),
|
||||
_ => {
|
||||
|
@ -307,11 +307,9 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
|
|||
Ok(Response::Icmpv4(ipv4_reply_repr, icmp_reply_repr))
|
||||
}
|
||||
|
||||
fn process_tcpv4<'frame>(sockets: &mut SocketSet, timestamp: u64,
|
||||
ipv4_repr: Ipv4Repr, ip_payload: &'frame [u8]) ->
|
||||
Result<Response<'frame>> {
|
||||
let ip_repr = IpRepr::Ipv4(ipv4_repr);
|
||||
|
||||
fn process_tcp<'frame>(sockets: &mut SocketSet, timestamp: u64,
|
||||
ip_repr: IpRepr, ip_payload: &'frame [u8]) ->
|
||||
Result<Response<'frame>> {
|
||||
for tcp_socket in sockets.iter_mut().filter_map(
|
||||
<Socket as AsSocket<TcpSocket>>::try_as_socket) {
|
||||
match tcp_socket.process(timestamp, &ip_repr, ip_payload) {
|
||||
|
@ -327,99 +325,81 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
|
|||
|
||||
// The packet wasn't handled by a socket, send a TCP RST packet.
|
||||
let tcp_packet = TcpPacket::new_checked(ip_payload)?;
|
||||
if tcp_packet.rst() {
|
||||
// Don't reply to a TCP RST packet with another TCP RST packet.
|
||||
return Ok(Response::Nop)
|
||||
let tcp_repr = TcpRepr::parse(&tcp_packet, &ip_repr.src_addr(), &ip_repr.dst_addr())?;
|
||||
if tcp_repr.control == TcpControl::Rst {
|
||||
// Never reply to a TCP RST packet with another TCP RST packet.
|
||||
Ok(Response::Nop)
|
||||
} else {
|
||||
let (ip_reply_repr, tcp_reply_repr) = TcpSocket::rst_reply(&ip_repr, &tcp_repr);
|
||||
Ok(Response::Tcp(ip_reply_repr, tcp_reply_repr))
|
||||
}
|
||||
let tcp_reply_repr = TcpRepr {
|
||||
src_port: tcp_packet.dst_port(),
|
||||
dst_port: tcp_packet.src_port(),
|
||||
control: TcpControl::Rst,
|
||||
push: false,
|
||||
seq_number: tcp_packet.ack_number(),
|
||||
ack_number: Some(tcp_packet.seq_number() +
|
||||
tcp_packet.segment_len()),
|
||||
window_len: 0,
|
||||
max_seg_size: None,
|
||||
payload: &[]
|
||||
};
|
||||
let ipv4_reply_repr = Ipv4Repr {
|
||||
src_addr: ipv4_repr.dst_addr,
|
||||
dst_addr: ipv4_repr.src_addr,
|
||||
protocol: IpProtocol::Tcp,
|
||||
payload_len: tcp_reply_repr.buffer_len()
|
||||
};
|
||||
Ok(Response::Tcpv4(ipv4_reply_repr, tcp_reply_repr))
|
||||
}
|
||||
|
||||
fn send_response(&mut self, timestamp: u64, response: Response) -> Result<()> {
|
||||
macro_rules! ip_response {
|
||||
($tx_buffer:ident, $frame:ident, $ip_repr:ident) => ({
|
||||
let dst_hardware_addr =
|
||||
match self.arp_cache.lookup(&$ip_repr.dst_addr.into()) {
|
||||
None => return Err(Error::Unaddressable),
|
||||
Some(hardware_addr) => hardware_addr
|
||||
};
|
||||
macro_rules! emit_packet {
|
||||
(Ethernet, $buffer_len:expr, |$frame:ident| $code:stmt) => ({
|
||||
let tx_len = EthernetFrame::<&[u8]>::buffer_len($buffer_len);
|
||||
let mut tx_buffer = self.device.transmit(timestamp, tx_len)?;
|
||||
debug_assert!(tx_buffer.as_ref().len() == tx_len);
|
||||
|
||||
let tx_len = EthernetFrame::<&[u8]>::buffer_len($ip_repr.buffer_len() +
|
||||
$ip_repr.payload_len);
|
||||
$tx_buffer = self.device.transmit(timestamp, tx_len)?;
|
||||
debug_assert!($tx_buffer.as_ref().len() == tx_len);
|
||||
|
||||
$frame = EthernetFrame::new(&mut $tx_buffer);
|
||||
let mut $frame = EthernetFrame::new(&mut tx_buffer);
|
||||
$frame.set_src_addr(self.hardware_addr);
|
||||
$frame.set_dst_addr(dst_hardware_addr);
|
||||
$frame.set_ethertype(EthernetProtocol::Ipv4);
|
||||
|
||||
let mut ip_packet = Ipv4Packet::new($frame.payload_mut());
|
||||
$ip_repr.emit(&mut ip_packet);
|
||||
ip_packet
|
||||
$code
|
||||
|
||||
Ok(())
|
||||
});
|
||||
|
||||
(Ip, $ip_repr:expr, |$payload:ident| $code:stmt) => ({
|
||||
let ip_repr = $ip_repr.lower(&self.protocol_addrs)?;
|
||||
match self.arp_cache.lookup(&ip_repr.dst_addr()) {
|
||||
None => Err(Error::Unaddressable),
|
||||
Some(dst_hardware_addr) => {
|
||||
emit_packet!(Ethernet, ip_repr.total_len(), |frame| {
|
||||
frame.set_dst_addr(dst_hardware_addr);
|
||||
match ip_repr {
|
||||
IpRepr::Ipv4(_) => frame.set_ethertype(EthernetProtocol::Ipv4),
|
||||
_ => unreachable!()
|
||||
}
|
||||
|
||||
ip_repr.emit(frame.payload_mut());
|
||||
|
||||
let $payload = &mut frame.payload_mut()[ip_repr.buffer_len()..];
|
||||
$code
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
match response {
|
||||
Response::Arp(repr) => {
|
||||
let tx_len = EthernetFrame::<&[u8]>::buffer_len(repr.buffer_len());
|
||||
let mut tx_buffer = self.device.transmit(timestamp, tx_len)?;
|
||||
debug_assert!(tx_buffer.as_ref().len() == tx_len);
|
||||
Response::Arp(arp_repr) => {
|
||||
let dst_hardware_addr =
|
||||
match arp_repr {
|
||||
ArpRepr::EthernetIpv4 { target_hardware_addr, .. } => target_hardware_addr,
|
||||
_ => unreachable!()
|
||||
};
|
||||
|
||||
let mut frame = EthernetFrame::new(&mut tx_buffer);
|
||||
frame.set_src_addr(self.hardware_addr);
|
||||
frame.set_dst_addr(match repr {
|
||||
ArpRepr::EthernetIpv4 { target_hardware_addr, .. } => target_hardware_addr,
|
||||
_ => unreachable!()
|
||||
});
|
||||
frame.set_ethertype(EthernetProtocol::Arp);
|
||||
emit_packet!(Ethernet, arp_repr.buffer_len(), |frame| {
|
||||
frame.set_dst_addr(dst_hardware_addr);
|
||||
frame.set_ethertype(EthernetProtocol::Arp);
|
||||
|
||||
let mut packet = ArpPacket::new(frame.payload_mut());
|
||||
repr.emit(&mut packet);
|
||||
|
||||
Ok(())
|
||||
let mut packet = ArpPacket::new(frame.payload_mut());
|
||||
arp_repr.emit(&mut packet);
|
||||
})
|
||||
},
|
||||
|
||||
Response::Icmpv4(ip_repr, icmp_repr) => {
|
||||
let mut tx_buffer;
|
||||
let mut frame;
|
||||
let mut ip_packet = ip_response!(tx_buffer, frame, ip_repr);
|
||||
let mut icmp_packet = Icmpv4Packet::new(ip_packet.payload_mut());
|
||||
icmp_repr.emit(&mut icmp_packet);
|
||||
Ok(())
|
||||
Response::Icmpv4(ipv4_repr, icmpv4_repr) => {
|
||||
emit_packet!(Ip, IpRepr::Ipv4(ipv4_repr), |payload| {
|
||||
icmpv4_repr.emit(&mut Icmpv4Packet::new(payload));
|
||||
})
|
||||
}
|
||||
|
||||
Response::Tcpv4(ip_repr, tcp_repr) => {
|
||||
let mut tx_buffer;
|
||||
let mut frame;
|
||||
let mut ip_packet = ip_response!(tx_buffer, frame, ip_repr);
|
||||
let mut tcp_packet = TcpPacket::new(ip_packet.payload_mut());
|
||||
tcp_repr.emit(&mut tcp_packet,
|
||||
&IpAddress::Ipv4(ip_repr.src_addr),
|
||||
&IpAddress::Ipv4(ip_repr.dst_addr));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Response::Nop => {
|
||||
Ok(())
|
||||
Response::Tcp(ip_repr, tcp_repr) => {
|
||||
emit_packet!(Ip, ip_repr, |payload| {
|
||||
tcp_repr.emit(&mut TcpPacket::new(payload),
|
||||
&ip_repr.src_addr(), &ip_repr.dst_addr());
|
||||
})
|
||||
}
|
||||
Response::Nop => Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -285,10 +285,10 @@ impl<'a> TcpSocket<'a> {
|
|||
listen_address: IpAddress::default(),
|
||||
local_endpoint: IpEndpoint::default(),
|
||||
remote_endpoint: IpEndpoint::default(),
|
||||
local_seq_no: TcpSeqNumber(0),
|
||||
remote_seq_no: TcpSeqNumber(0),
|
||||
remote_last_seq: TcpSeqNumber(0),
|
||||
remote_last_ack: TcpSeqNumber(0),
|
||||
local_seq_no: TcpSeqNumber::default(),
|
||||
remote_seq_no: TcpSeqNumber::default(),
|
||||
remote_last_seq: TcpSeqNumber::default(),
|
||||
remote_last_ack: TcpSeqNumber::default(),
|
||||
remote_win_len: 0,
|
||||
remote_mss: DEFAULT_MSS,
|
||||
retransmit: Retransmit::new(),
|
||||
|
@ -335,10 +335,10 @@ impl<'a> TcpSocket<'a> {
|
|||
self.listen_address = IpAddress::default();
|
||||
self.local_endpoint = IpEndpoint::default();
|
||||
self.remote_endpoint = IpEndpoint::default();
|
||||
self.local_seq_no = TcpSeqNumber(0);
|
||||
self.remote_seq_no = TcpSeqNumber(0);
|
||||
self.remote_last_seq = TcpSeqNumber(0);
|
||||
self.remote_last_ack = TcpSeqNumber(0);
|
||||
self.local_seq_no = TcpSeqNumber::default();
|
||||
self.remote_seq_no = TcpSeqNumber::default();
|
||||
self.remote_last_seq = TcpSeqNumber::default();
|
||||
self.remote_last_ack = TcpSeqNumber::default();
|
||||
self.remote_win_len = 0;
|
||||
self.remote_mss = DEFAULT_MSS;
|
||||
self.retransmit.reset();
|
||||
|
@ -681,6 +681,44 @@ impl<'a> TcpSocket<'a> {
|
|||
self.state = state
|
||||
}
|
||||
|
||||
pub(crate) fn reply(ip_repr: &IpRepr, tcp_repr: &TcpRepr) -> (IpRepr, TcpRepr<'static>) {
|
||||
let tcp_reply_repr = TcpRepr {
|
||||
src_port: tcp_repr.dst_port,
|
||||
dst_port: tcp_repr.src_port,
|
||||
control: TcpControl::None,
|
||||
push: false,
|
||||
seq_number: TcpSeqNumber(0),
|
||||
ack_number: None,
|
||||
window_len: 0,
|
||||
max_seg_size: None,
|
||||
payload: &[]
|
||||
};
|
||||
let ip_reply_repr = IpRepr::Unspecified {
|
||||
src_addr: ip_repr.dst_addr(),
|
||||
dst_addr: ip_repr.src_addr(),
|
||||
protocol: IpProtocol::Tcp,
|
||||
payload_len: tcp_reply_repr.buffer_len()
|
||||
};
|
||||
(ip_reply_repr, tcp_reply_repr)
|
||||
}
|
||||
|
||||
pub(crate) fn rst_reply(ip_repr: &IpRepr, tcp_repr: &TcpRepr) -> (IpRepr, TcpRepr<'static>) {
|
||||
debug_assert!(tcp_repr.control != TcpControl::Rst);
|
||||
|
||||
let (ip_reply_repr, mut tcp_reply_repr) = Self::reply(ip_repr, tcp_repr);
|
||||
|
||||
// See https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for explanation
|
||||
// of why we sometimes send an RST and sometimes an RST|ACK
|
||||
tcp_reply_repr.control = TcpControl::Rst;
|
||||
tcp_reply_repr.seq_number = tcp_repr.ack_number.unwrap_or_default();
|
||||
if tcp_repr.control == TcpControl::Syn {
|
||||
tcp_reply_repr.ack_number = Some(tcp_repr.seq_number +
|
||||
tcp_repr.segment_len());
|
||||
}
|
||||
|
||||
(ip_reply_repr, tcp_reply_repr)
|
||||
}
|
||||
|
||||
pub(crate) fn process(&mut self, timestamp: u64, ip_repr: &IpRepr,
|
||||
payload: &[u8]) -> Result<()> {
|
||||
debug_assert!(ip_repr.protocol() == IpProtocol::Tcp);
|
||||
|
|
|
@ -177,6 +177,12 @@ pub enum IpRepr {
|
|||
__Nonexhaustive
|
||||
}
|
||||
|
||||
impl From<Ipv4Repr> for IpRepr {
|
||||
fn from(repr: Ipv4Repr) -> IpRepr {
|
||||
IpRepr::Ipv4(repr)
|
||||
}
|
||||
}
|
||||
|
||||
impl IpRepr {
|
||||
/// Return the protocol version.
|
||||
pub fn version(&self) -> Version {
|
||||
|
@ -323,6 +329,17 @@ impl IpRepr {
|
|||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the total length of a packet that will be emitted from this
|
||||
/// high-level representation.
|
||||
///
|
||||
/// This is the same as `repr.buffer_len() + repr.payload_len()`.
|
||||
///
|
||||
/// # Panics
|
||||
/// This function panics if invoked on an unspecified representation.
|
||||
pub fn total_len(&self) -> usize {
|
||||
self.buffer_len() + self.payload_len()
|
||||
}
|
||||
}
|
||||
|
||||
pub mod checksum {
|
||||
|
|
|
@ -9,7 +9,7 @@ use super::ip::checksum;
|
|||
///
|
||||
/// A sequence number is a monotonically advancing integer modulo 2<sup>32</sup>.
|
||||
/// Sequence numbers do not have a discontiguity when compared pairwise across a signed overflow.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)]
|
||||
pub struct SeqNumber(pub i32);
|
||||
|
||||
impl fmt::Display for SeqNumber {
|
||||
|
@ -275,7 +275,6 @@ impl<T: AsRef<[u8]>> Packet<T> {
|
|||
}
|
||||
|
||||
/// Return the length of the segment, in terms of sequence space.
|
||||
#[inline]
|
||||
pub fn segment_len(&self) -> usize {
|
||||
let data = self.buffer.as_ref();
|
||||
let mut length = data.len() - self.header_len() as usize;
|
||||
|
@ -695,10 +694,9 @@ impl<'a> Repr<'a> {
|
|||
}
|
||||
|
||||
/// Emit a high-level representation into a Transmission Control Protocol packet.
|
||||
pub fn emit<T: ?Sized>(&self, packet: &mut Packet<&mut T>,
|
||||
src_addr: &IpAddress,
|
||||
dst_addr: &IpAddress)
|
||||
where T: AsRef<[u8]> + AsMut<[u8]> {
|
||||
pub fn emit<T>(&self, packet: &mut Packet<&mut T>,
|
||||
src_addr: &IpAddress, dst_addr: &IpAddress)
|
||||
where T: AsRef<[u8]> + AsMut<[u8]> + ?Sized {
|
||||
packet.set_src_port(self.src_port);
|
||||
packet.set_dst_port(self.dst_port);
|
||||
packet.set_seq_number(self.seq_number);
|
||||
|
@ -727,6 +725,16 @@ impl<'a> Repr<'a> {
|
|||
packet.payload_mut().copy_from_slice(self.payload);
|
||||
packet.fill_checksum(src_addr, dst_addr)
|
||||
}
|
||||
|
||||
/// Return the length of the segment, in terms of sequence space.
|
||||
pub fn segment_len(&self) -> usize {
|
||||
let mut length = self.payload.len();
|
||||
match self.control {
|
||||
Control::Syn | Control::Fin => length += 1,
|
||||
_ => ()
|
||||
}
|
||||
length
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
|
||||
|
|
Loading…
Reference in New Issue