diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index 3df037e..016bff1 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -996,6 +996,7 @@ impl<'a> TcpSocket<'a> { self.local_endpoint = IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port); self.remote_seq_no = repr.seq_number + 1; self.remote_last_seq = self.local_seq_no + 1; + self.remote_last_ack = Some(repr.seq_number); if let Some(max_seg_size) = repr.max_seg_size { self.remote_mss = max_seg_size as usize; } @@ -1063,6 +1064,7 @@ impl<'a> TcpSocket<'a> { (State::LastAck, TcpControl::None) => { // Clear the remote endpoint, or we'll send an RST there. self.set_state(State::Closed); + self.local_endpoint = IpEndpoint::default(); self.remote_endpoint = IpEndpoint::default(); } @@ -1154,18 +1156,35 @@ impl<'a> TcpSocket<'a> { } } - fn seq_to_transmit(&self, control: TcpControl) -> bool { - self.remote_last_seq < self.local_seq_no + self.tx_buffer.len() + control.len() + fn seq_to_transmit(&self) -> bool { + let control; + match self.state { + State::SynSent | State::SynReceived => + control = TcpControl::Syn, + State::FinWait1 | State::LastAck => + control = TcpControl::Fin, + _ => control = TcpControl::None + } + + if self.remote_win_len > 0 { + self.remote_last_seq < self.local_seq_no + self.tx_buffer.len() + control.len() + } else { + false + } } fn ack_to_transmit(&self) -> bool { if let Some(remote_last_ack) = self.remote_last_ack { remote_last_ack < self.remote_seq_no + self.rx_buffer.len() } else { - true + false } } + fn window_to_update(&self) -> bool { + self.rx_buffer.window() as u16 > self.remote_last_win + } + pub(crate) fn dispatch(&mut self, timestamp: u64, limits: &DeviceLimits, emit: F) -> Result<()> where F: FnOnce((IpRepr, TcpRepr)) -> Result<()> { @@ -1182,12 +1201,13 @@ impl<'a> TcpSocket<'a> { self.remote_last_ts = Some(timestamp); } + // Check if any state needs to be changed because of a timer. if self.timed_out(timestamp) { // If a timeout expires, we should abort the connection. net_debug!("[{}]{}:{}: timeout exceeded", self.debug_id, self.local_endpoint, self.remote_endpoint); self.set_state(State::Closed); - } else if !self.seq_to_transmit(TcpControl::None) { + } else if !self.seq_to_transmit() { if let Some(retransmit_delta) = self.timer.should_retransmit(timestamp) { // If a retransmit timer expired, we should resend data starting at the last ACK. net_debug!("[{}]{}:{}: retransmitting at t+{}ms", @@ -1197,6 +1217,25 @@ impl<'a> TcpSocket<'a> { } } + // Decide whether we're sending a packet. + if self.seq_to_transmit() { + // If we have data to transmit and it fits into partner's window, do it. + } else if self.ack_to_transmit() { + // If we have data to acknowledge, do it. + } else if self.window_to_update() { + // If we have window length increase to advertise, do it. + } else if self.state == State::Closed { + // If we need to abort the connection, do it. + } else if self.timer.should_retransmit(timestamp).is_some() { + // If we have packets to retransmit, do it. + } else if self.timer.should_keep_alive(timestamp) { + // If we need to transmit a keep-alive packet, do it. + } else if self.timer.should_close(timestamp) { + // If we have spent enough time in the TIME-WAIT state, close the socket. + } else { + return Err(Error::Exhausted) + } + // Construct the lowered IP representation. // We might need this to calculate the MSS, so do it early. let mut ip_repr = IpRepr::Unspecified { @@ -1279,22 +1318,6 @@ impl<'a> TcpSocket<'a> { } } - if self.seq_to_transmit(repr.control) && repr.segment_len() > 0 { - // If we have data to transmit and it fits into partner's window, do it. - } else if self.ack_to_transmit() && repr.ack_number.is_some() { - // If we have data to acknowledge, do it. - } else if repr.window_len > self.remote_last_win { - // If we have window length increase to advertise, do it. - } else if self.timer.should_retransmit(timestamp).is_some() { - // If we have packets to retransmit, do it. - } else if self.timer.should_keep_alive(timestamp) { - // If we need to transmit a keep-alive packet, do it. - } else if repr.control == TcpControl::Rst { - // If we need to abort the connection, do it. - } else { - return Err(Error::Exhausted) - } - if repr.payload.len() > 0 { net_trace!("[{}]{}:{}: tx buffer: reading {} octets at offset {}", self.debug_id, self.local_endpoint, self.remote_endpoint, @@ -1315,6 +1338,10 @@ impl<'a> TcpSocket<'a> { flags); } + // There might be more than one reason to send a packet. E.g. the keep-alive timer + // has expired, and we also have data in transmit buffer. Since any packet that occupies + // sequence space will elicit an ACK, we only need to send an explicit packet if we + // couldn't fill the sequence space with anything. let is_keep_alive; if self.timer.should_keep_alive(timestamp) && repr.is_empty() { net_trace!("[{}]{}:{}: sending a keep-alive", @@ -1327,13 +1354,20 @@ impl<'a> TcpSocket<'a> { } if repr.control == TcpControl::Syn { - // See RFC 6691 for an explanation of this calculation. + // Fill the MSS option. See RFC 6691 for an explanation of this calculation. let mut max_segment_size = limits.max_transmission_unit; max_segment_size -= ip_repr.buffer_len(); max_segment_size -= repr.header_len(); repr.max_seg_size = Some(max_segment_size as u16); } + // Actually send the packet. If this succeeds, it means the packet is in + // the device buffer, and its transmission is imminent. If not, we might have + // a number of problems, e.g. we need neighbor discovery. + // + // Bailing out if the packet isn't placed in the device buffer allows us + // to not waste time waiting for the retransmit timer on packets that we know + // for sure will not be successfully transmitted. ip_repr.set_payload_len(repr.buffer_len()); emit((ip_repr, repr))?; @@ -1341,7 +1375,8 @@ impl<'a> TcpSocket<'a> { // the keep-alive timer. self.timer.rewind_keep_alive(timestamp, self.keep_alive); - // Leave the rest of the state intact if sending a keep-alive packet. + // Leave the rest of the state intact if sending a keep-alive packet, since those + // carry a fake segment. if is_keep_alive { return Ok(()) } // We've sent a packet successfully, so we can update the internal state now. @@ -1349,15 +1384,14 @@ impl<'a> TcpSocket<'a> { self.remote_last_ack = repr.ack_number; self.remote_last_win = repr.window_len; - if !self.seq_to_transmit(repr.control) && repr.segment_len() > 0 { + if !self.seq_to_transmit() && repr.segment_len() > 0 { // If we've transmitted all data we could (and there was something at all, // data or flag, to transmit, not just an ACK), wind up the retransmit timer. self.timer.set_for_retransmit(timestamp); } - if repr.control == TcpControl::Rst { - // When aborting a connection, forget about it after sending - // the RST packet once. + if self.state == State::Closed { + // When aborting a connection, forget about it after sending a single RST packet. self.local_endpoint = IpEndpoint::default(); self.remote_endpoint = IpEndpoint::default(); } @@ -1365,34 +1399,38 @@ impl<'a> TcpSocket<'a> { Ok(()) } - fn has_data_or_fin_to_transmit(&self) -> bool { - match self.state { - State::FinWait1 | State::LastAck => true, - _ if !self.tx_buffer.is_empty() => true, - _ => false - } - } - pub(crate) fn poll_at(&self) -> Option { - let timeout_poll_at; - match (self.remote_last_ts, self.timeout) { - // If we're transmitting or retransmitting data, we need to poll at the moment - // when the timeout would expire. - (Some(remote_last_ts), Some(timeout)) if self.has_data_or_fin_to_transmit() => - timeout_poll_at = Some(remote_last_ts + timeout), - // If we're transitioning from a long period of inactivity, and have a timeout set, - // request an invocation of dispatch(); that will update self.remote_last_ts. - (None, Some(_timeout)) => - timeout_poll_at = Some(0), - // Otherwise we have no timeout. - (_, _) => - timeout_poll_at = None - } + // The logic here mirrors the beginning of dispatch() closely. + if !self.remote_endpoint.is_specified() { + // No one to talk to, nothing to transmit. + None + } else if self.remote_last_ts.is_none() { + // Socket stopped being quiet recently, we need to acquire a timestamp. + Some(0) + } else if self.state == State::Closed { + // Socket was aborted, we have an RST packet to transmit. + Some(0) + } else if self.seq_to_transmit() || self.ack_to_transmit() || self.window_to_update() { + // We have a data or flag packet to transmit. + Some(0) + } else { + let timeout_poll_at; + match (self.remote_last_ts, self.timeout) { + // If we're transmitting or retransmitting data, we need to poll at the moment + // when the timeout would expire. + (Some(remote_last_ts), Some(timeout)) => + timeout_poll_at = Some(remote_last_ts + timeout), + // Otherwise we have no timeout. + (_, _) => + timeout_poll_at = None + } - [self.timer.poll_at(), timeout_poll_at] - .iter() - .filter_map(|x| *x) - .min() + // We wait for the earliest of our timers to fire. + [self.timer.poll_at(), timeout_poll_at] + .iter() + .filter_map(|x| *x) + .min() + } } } @@ -3114,6 +3152,13 @@ mod test { // Tests for timeouts. // =========================================================================================// + #[test] + fn test_listen_timeout() { + let mut s = socket_listen(); + s.set_timeout(Some(100)); + assert_eq!(s.poll_at(), None); + } + #[test] fn test_connect_timeout() { let mut s = socket(); @@ -3143,7 +3188,7 @@ mod test { let mut s = socket_established(); s.set_timeout(Some(200)); recv!(s, time 250, Err(Error::Exhausted)); - assert_eq!(s.poll_at(), None); + assert_eq!(s.poll_at(), Some(450)); s.send_slice(b"abcdef").unwrap(); assert_eq!(s.poll_at(), Some(0)); recv!(s, time 255, Ok(TcpRepr { @@ -3247,6 +3292,22 @@ mod test { assert_eq!(s.state, State::Closed); } + #[test] + fn test_closed_timeout() { + let mut s = socket_established(); + s.set_timeout(Some(200)); + s.remote_last_ts = Some(100); + s.abort(); + assert_eq!(s.poll_at(), Some(0)); + recv!(s, time 100, Ok(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + })); + assert_eq!(s.poll_at(), None); + } + // =========================================================================================// // Tests for keep-alive. // =========================================================================================//