diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index d72b24c..3df037e 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -1365,19 +1365,34 @@ impl<'a> TcpSocket<'a> { Ok(()) } + fn has_data_or_fin_to_transmit(&self) -> bool { + match self.state { + State::FinWait1 | State::LastAck => true, + _ if !self.tx_buffer.is_empty() => true, + _ => false + } + } + pub(crate) fn poll_at(&self) -> Option { - self.timer.poll_at() - .or_else(|| { - match (self.remote_last_ts, self.timeout) { - (Some(remote_last_ts), Some(timeout)) - if !self.tx_buffer.is_empty() => - Some(remote_last_ts + timeout), - (None, Some(_timeout)) => - Some(0), - (_, _) => - None - } - }) + let timeout_poll_at; + match (self.remote_last_ts, self.timeout) { + // If we're transmitting or retransmitting data, we need to poll at the moment + // when the timeout would expire. + (Some(remote_last_ts), Some(timeout)) if self.has_data_or_fin_to_transmit() => + timeout_poll_at = Some(remote_last_ts + timeout), + // If we're transitioning from a long period of inactivity, and have a timeout set, + // request an invocation of dispatch(); that will update self.remote_last_ts. + (None, Some(_timeout)) => + timeout_poll_at = Some(0), + // Otherwise we have no timeout. + (_, _) => + timeout_poll_at = None + } + + [self.timer.poll_at(), timeout_poll_at] + .iter() + .filter_map(|x| *x) + .min() } } @@ -3192,6 +3207,46 @@ mod test { assert_eq!(s.state, State::Closed); } + #[test] + fn test_fin_wait_1_timeout() { + let mut s = socket_fin_wait_1(); + s.set_timeout(Some(200)); + recv!(s, time 100, Ok(TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + })); + assert_eq!(s.poll_at(), Some(200)); + recv!(s, time 400, Ok(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + })); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_last_ack_timeout() { + let mut s = socket_last_ack(); + s.set_timeout(Some(200)); + recv!(s, time 100, Ok(TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + })); + assert_eq!(s.poll_at(), Some(200)); + recv!(s, time 400, Ok(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + })); + assert_eq!(s.state, State::Closed); + } + // =========================================================================================// // Tests for keep-alive. // =========================================================================================//