Completely redo the logic of TCP socket polling.

The previous implementation made no sense. It is obvious that
poll_at() should use the same mechanisms to decide whether dispatch()
should be called as dispatch() itself uses to decide whether to send
anything.

This fixes numerous busy-looping issues that arise if the return
value of poll() is used for waiting.
This commit is contained in:
whitequark 2017-09-24 11:26:51 +00:00
parent 96b284a30f
commit 32d720831a
1 changed file with 114 additions and 53 deletions

View File

@ -996,6 +996,7 @@ impl<'a> TcpSocket<'a> {
self.local_endpoint = IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port); self.local_endpoint = IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port);
self.remote_seq_no = repr.seq_number + 1; self.remote_seq_no = repr.seq_number + 1;
self.remote_last_seq = self.local_seq_no + 1; self.remote_last_seq = self.local_seq_no + 1;
self.remote_last_ack = Some(repr.seq_number);
if let Some(max_seg_size) = repr.max_seg_size { if let Some(max_seg_size) = repr.max_seg_size {
self.remote_mss = max_seg_size as usize; self.remote_mss = max_seg_size as usize;
} }
@ -1063,6 +1064,7 @@ impl<'a> TcpSocket<'a> {
(State::LastAck, TcpControl::None) => { (State::LastAck, TcpControl::None) => {
// Clear the remote endpoint, or we'll send an RST there. // Clear the remote endpoint, or we'll send an RST there.
self.set_state(State::Closed); self.set_state(State::Closed);
self.local_endpoint = IpEndpoint::default();
self.remote_endpoint = IpEndpoint::default(); self.remote_endpoint = IpEndpoint::default();
} }
@ -1154,18 +1156,35 @@ impl<'a> TcpSocket<'a> {
} }
} }
fn seq_to_transmit(&self, control: TcpControl) -> bool { fn seq_to_transmit(&self) -> bool {
self.remote_last_seq < self.local_seq_no + self.tx_buffer.len() + control.len() let control;
match self.state {
State::SynSent | State::SynReceived =>
control = TcpControl::Syn,
State::FinWait1 | State::LastAck =>
control = TcpControl::Fin,
_ => control = TcpControl::None
}
if self.remote_win_len > 0 {
self.remote_last_seq < self.local_seq_no + self.tx_buffer.len() + control.len()
} else {
false
}
} }
fn ack_to_transmit(&self) -> bool { fn ack_to_transmit(&self) -> bool {
if let Some(remote_last_ack) = self.remote_last_ack { if let Some(remote_last_ack) = self.remote_last_ack {
remote_last_ack < self.remote_seq_no + self.rx_buffer.len() remote_last_ack < self.remote_seq_no + self.rx_buffer.len()
} else { } else {
true false
} }
} }
/// Reports whether the receive window we could advertise right now is strictly
/// larger than the window value recorded in `remote_last_win`.
fn window_to_update(&self) -> bool {
    // Same narrowing cast as the original; the comparison is done in `u16`.
    let current_window = self.rx_buffer.window() as u16;
    self.remote_last_win < current_window
}
pub(crate) fn dispatch<F>(&mut self, timestamp: u64, limits: &DeviceLimits, pub(crate) fn dispatch<F>(&mut self, timestamp: u64, limits: &DeviceLimits,
emit: F) -> Result<()> emit: F) -> Result<()>
where F: FnOnce((IpRepr, TcpRepr)) -> Result<()> { where F: FnOnce((IpRepr, TcpRepr)) -> Result<()> {
@ -1182,12 +1201,13 @@ impl<'a> TcpSocket<'a> {
self.remote_last_ts = Some(timestamp); self.remote_last_ts = Some(timestamp);
} }
// Check if any state needs to be changed because of a timer.
if self.timed_out(timestamp) { if self.timed_out(timestamp) {
// If a timeout expires, we should abort the connection. // If a timeout expires, we should abort the connection.
net_debug!("[{}]{}:{}: timeout exceeded", net_debug!("[{}]{}:{}: timeout exceeded",
self.debug_id, self.local_endpoint, self.remote_endpoint); self.debug_id, self.local_endpoint, self.remote_endpoint);
self.set_state(State::Closed); self.set_state(State::Closed);
} else if !self.seq_to_transmit(TcpControl::None) { } else if !self.seq_to_transmit() {
if let Some(retransmit_delta) = self.timer.should_retransmit(timestamp) { if let Some(retransmit_delta) = self.timer.should_retransmit(timestamp) {
// If a retransmit timer expired, we should resend data starting at the last ACK. // If a retransmit timer expired, we should resend data starting at the last ACK.
net_debug!("[{}]{}:{}: retransmitting at t+{}ms", net_debug!("[{}]{}:{}: retransmitting at t+{}ms",
@ -1197,6 +1217,25 @@ impl<'a> TcpSocket<'a> {
} }
} }
// Decide whether we're sending a packet.
if self.seq_to_transmit() {
// If we have data to transmit and it fits into partner's window, do it.
} else if self.ack_to_transmit() {
// If we have data to acknowledge, do it.
} else if self.window_to_update() {
// If we have a window length increase to advertise, do it.
} else if self.state == State::Closed {
// If we need to abort the connection, do it.
} else if self.timer.should_retransmit(timestamp).is_some() {
// If we have packets to retransmit, do it.
} else if self.timer.should_keep_alive(timestamp) {
// If we need to transmit a keep-alive packet, do it.
} else if self.timer.should_close(timestamp) {
// If we have spent enough time in the TIME-WAIT state, close the socket.
} else {
return Err(Error::Exhausted)
}
// Construct the lowered IP representation. // Construct the lowered IP representation.
// We might need this to calculate the MSS, so do it early. // We might need this to calculate the MSS, so do it early.
let mut ip_repr = IpRepr::Unspecified { let mut ip_repr = IpRepr::Unspecified {
@ -1279,22 +1318,6 @@ impl<'a> TcpSocket<'a> {
} }
} }
if self.seq_to_transmit(repr.control) && repr.segment_len() > 0 {
// If we have data to transmit and it fits into partner's window, do it.
} else if self.ack_to_transmit() && repr.ack_number.is_some() {
// If we have data to acknowledge, do it.
} else if repr.window_len > self.remote_last_win {
// If we have window length increase to advertise, do it.
} else if self.timer.should_retransmit(timestamp).is_some() {
// If we have packets to retransmit, do it.
} else if self.timer.should_keep_alive(timestamp) {
// If we need to transmit a keep-alive packet, do it.
} else if repr.control == TcpControl::Rst {
// If we need to abort the connection, do it.
} else {
return Err(Error::Exhausted)
}
if repr.payload.len() > 0 { if repr.payload.len() > 0 {
net_trace!("[{}]{}:{}: tx buffer: reading {} octets at offset {}", net_trace!("[{}]{}:{}: tx buffer: reading {} octets at offset {}",
self.debug_id, self.local_endpoint, self.remote_endpoint, self.debug_id, self.local_endpoint, self.remote_endpoint,
@ -1315,6 +1338,10 @@ impl<'a> TcpSocket<'a> {
flags); flags);
} }
// There might be more than one reason to send a packet. E.g. the keep-alive timer
// has expired, and we also have data in transmit buffer. Since any packet that occupies
// sequence space will elicit an ACK, we only need to send an explicit packet if we
// couldn't fill the sequence space with anything.
let is_keep_alive; let is_keep_alive;
if self.timer.should_keep_alive(timestamp) && repr.is_empty() { if self.timer.should_keep_alive(timestamp) && repr.is_empty() {
net_trace!("[{}]{}:{}: sending a keep-alive", net_trace!("[{}]{}:{}: sending a keep-alive",
@ -1327,13 +1354,20 @@ impl<'a> TcpSocket<'a> {
} }
if repr.control == TcpControl::Syn { if repr.control == TcpControl::Syn {
// See RFC 6691 for an explanation of this calculation. // Fill the MSS option. See RFC 6691 for an explanation of this calculation.
let mut max_segment_size = limits.max_transmission_unit; let mut max_segment_size = limits.max_transmission_unit;
max_segment_size -= ip_repr.buffer_len(); max_segment_size -= ip_repr.buffer_len();
max_segment_size -= repr.header_len(); max_segment_size -= repr.header_len();
repr.max_seg_size = Some(max_segment_size as u16); repr.max_seg_size = Some(max_segment_size as u16);
} }
// Actually send the packet. If this succeeds, it means the packet is in
// the device buffer, and its transmission is imminent. If not, we might have
// a number of problems, e.g. we need neighbor discovery.
//
// Bailing out if the packet isn't placed in the device buffer allows us
// to not waste time waiting for the retransmit timer on packets that we know
// for sure will not be successfully transmitted.
ip_repr.set_payload_len(repr.buffer_len()); ip_repr.set_payload_len(repr.buffer_len());
emit((ip_repr, repr))?; emit((ip_repr, repr))?;
@ -1341,7 +1375,8 @@ impl<'a> TcpSocket<'a> {
// the keep-alive timer. // the keep-alive timer.
self.timer.rewind_keep_alive(timestamp, self.keep_alive); self.timer.rewind_keep_alive(timestamp, self.keep_alive);
// Leave the rest of the state intact if sending a keep-alive packet. // Leave the rest of the state intact if sending a keep-alive packet, since those
// carry a fake segment.
if is_keep_alive { return Ok(()) } if is_keep_alive { return Ok(()) }
// We've sent a packet successfully, so we can update the internal state now. // We've sent a packet successfully, so we can update the internal state now.
@ -1349,15 +1384,14 @@ impl<'a> TcpSocket<'a> {
self.remote_last_ack = repr.ack_number; self.remote_last_ack = repr.ack_number;
self.remote_last_win = repr.window_len; self.remote_last_win = repr.window_len;
if !self.seq_to_transmit(repr.control) && repr.segment_len() > 0 { if !self.seq_to_transmit() && repr.segment_len() > 0 {
// If we've transmitted all data we could (and there was something at all, // If we've transmitted all data we could (and there was something at all,
// data or flag, to transmit, not just an ACK), wind up the retransmit timer. // data or flag, to transmit, not just an ACK), wind up the retransmit timer.
self.timer.set_for_retransmit(timestamp); self.timer.set_for_retransmit(timestamp);
} }
if repr.control == TcpControl::Rst { if self.state == State::Closed {
// When aborting a connection, forget about it after sending // When aborting a connection, forget about it after sending a single RST packet.
// the RST packet once.
self.local_endpoint = IpEndpoint::default(); self.local_endpoint = IpEndpoint::default();
self.remote_endpoint = IpEndpoint::default(); self.remote_endpoint = IpEndpoint::default();
} }
@ -1365,34 +1399,38 @@ impl<'a> TcpSocket<'a> {
Ok(()) Ok(())
} }
fn has_data_or_fin_to_transmit(&self) -> bool {
match self.state {
State::FinWait1 | State::LastAck => true,
_ if !self.tx_buffer.is_empty() => true,
_ => false
}
}
pub(crate) fn poll_at(&self) -> Option<u64> { pub(crate) fn poll_at(&self) -> Option<u64> {
let timeout_poll_at; // The logic here mirrors the beginning of dispatch() closely.
match (self.remote_last_ts, self.timeout) { if !self.remote_endpoint.is_specified() {
// If we're transmitting or retransmitting data, we need to poll at the moment // No one to talk to, nothing to transmit.
// when the timeout would expire. None
(Some(remote_last_ts), Some(timeout)) if self.has_data_or_fin_to_transmit() => } else if self.remote_last_ts.is_none() {
timeout_poll_at = Some(remote_last_ts + timeout), // Socket stopped being quiet recently, we need to acquire a timestamp.
// If we're transitioning from a long period of inactivity, and have a timeout set, Some(0)
// request an invocation of dispatch(); that will update self.remote_last_ts. } else if self.state == State::Closed {
(None, Some(_timeout)) => // Socket was aborted, we have an RST packet to transmit.
timeout_poll_at = Some(0), Some(0)
// Otherwise we have no timeout. } else if self.seq_to_transmit() || self.ack_to_transmit() || self.window_to_update() {
(_, _) => // We have a data or flag packet to transmit.
timeout_poll_at = None Some(0)
} } else {
let timeout_poll_at;
match (self.remote_last_ts, self.timeout) {
// If we're transmitting or retransmitting data, we need to poll at the moment
// when the timeout would expire.
(Some(remote_last_ts), Some(timeout)) =>
timeout_poll_at = Some(remote_last_ts + timeout),
// Otherwise we have no timeout.
(_, _) =>
timeout_poll_at = None
}
[self.timer.poll_at(), timeout_poll_at] // We wait for the earliest of our timers to fire.
.iter() [self.timer.poll_at(), timeout_poll_at]
.filter_map(|x| *x) .iter()
.min() .filter_map(|x| *x)
.min()
}
} }
} }
@ -3114,6 +3152,13 @@ mod test {
// Tests for timeouts. // Tests for timeouts.
// =========================================================================================// // =========================================================================================//
// A listening socket with a timeout configured must not request a poll.
#[test]
fn test_listen_timeout() {
    let mut listener = socket_listen();
    listener.set_timeout(Some(100));
    let next_poll = listener.poll_at();
    assert_eq!(next_poll, None);
}
#[test] #[test]
fn test_connect_timeout() { fn test_connect_timeout() {
let mut s = socket(); let mut s = socket();
@ -3143,7 +3188,7 @@ mod test {
let mut s = socket_established(); let mut s = socket_established();
s.set_timeout(Some(200)); s.set_timeout(Some(200));
recv!(s, time 250, Err(Error::Exhausted)); recv!(s, time 250, Err(Error::Exhausted));
assert_eq!(s.poll_at(), None); assert_eq!(s.poll_at(), Some(450));
s.send_slice(b"abcdef").unwrap(); s.send_slice(b"abcdef").unwrap();
assert_eq!(s.poll_at(), Some(0)); assert_eq!(s.poll_at(), Some(0));
recv!(s, time 255, Ok(TcpRepr { recv!(s, time 255, Ok(TcpRepr {
@ -3247,6 +3292,22 @@ mod test {
assert_eq!(s.state, State::Closed); assert_eq!(s.state, State::Closed);
} }
// Checks that an aborted socket asks to be polled immediately, emits a single RST,
// and afterwards reports nothing left to wait for even though a timeout is configured.
#[test]
fn test_closed_timeout() {
let mut s = socket_established();
s.set_timeout(Some(200));
// Pretend a timestamp was already recorded, so that poll_at() does not return
// Some(0) merely because it needs to acquire one.
s.remote_last_ts = Some(100);
s.abort();
// The pending RST must be dispatched right away.
assert_eq!(s.poll_at(), Some(0));
recv!(s, time 100, Ok(TcpRepr {
control: TcpControl::Rst,
seq_number: LOCAL_SEQ + 1,
ack_number: Some(REMOTE_SEQ + 1),
..RECV_TEMPL
}));
// Once the RST is out, neither the configured timeout nor any timer should request a poll.
assert_eq!(s.poll_at(), None);
}
// =========================================================================================// // =========================================================================================//
// Tests for keep-alive. // Tests for keep-alive.
// =========================================================================================// // =========================================================================================//