diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index 1263b1d..732ccfa 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -727,6 +727,29 @@ impl<'a> TcpSocket<'a> { } } + // Compute the amount of acknowledged octets, removing the SYN and FIN bits + // from the sequence space. + let mut ack_len = 0; + let mut ack_of_fin = false; + if repr.control != TcpControl::Rst { + if let Some(ack_number) = repr.ack_number { + ack_len = ack_number - self.local_seq_no; + // There could have been no data sent before the SYN, so we always remove it + // from the sequence space. + if sent_syn { + ack_len -= 1 + } + // We could've sent data before the FIN, so only remove FIN from the sequence + // space if all of that data is acknowledged. + if sent_fin && self.tx_buffer.len() + 1 == ack_len { + ack_len -= 1; + net_trace!("[{}]{}:{}: received ACK of FIN", + self.debug_id, self.local_endpoint, self.remote_endpoint); + ack_of_fin = true; + } + } + } + // Validate and update the state. match (self.state, repr) { // RSTs are ignored in the LISTEN state. @@ -809,9 +832,9 @@ impl<'a> TcpSocket<'a> { } // ACK packets in FIN-WAIT-1 state change it to FIN-WAIT-2, if we've already - // sent everything in the transmit buffer, and reset the retransmit timer. + // sent everything in the transmit buffer. If not, they reset the retransmit timer. (State::FinWait1, TcpRepr { control: TcpControl::None, .. }) => { - if self.tx_buffer.empty() { + if ack_of_fin { self.set_state(State::FinWait2); } else { self.retransmit.reset(); @@ -821,7 +844,11 @@ impl<'a> TcpSocket<'a> { // FIN packets in FIN-WAIT-1 state change it to CLOSING. (State::FinWait1, TcpRepr { control: TcpControl::Fin, .. }) => { self.remote_seq_no += 1; - self.set_state(State::Closing); + if ack_of_fin { + self.set_state(State::TimeWait); + } else { + self.set_state(State::Closing); + } self.retransmit.reset(); } @@ -834,8 +861,11 @@ impl<'a> TcpSocket<'a> { // ACK packets in CLOSING state change it to TIME-WAIT. (State::Closing, TcpRepr { control: TcpControl::None, .. }) => { - self.set_state(State::TimeWait); - self.retransmit.reset(); + if ack_of_fin { + self.set_state(State::TimeWait); + } else { + self.retransmit.reset(); + } } // ACK packets in CLOSE-WAIT state reset the retransmit timer. @@ -858,34 +888,16 @@ impl<'a> TcpSocket<'a> { } // Dequeue acknowledged octets. - if let Some(ack_number) = repr.ack_number { - let mut ack_len = ack_number - self.local_seq_no; - // There could have been no data sent before the SYN, so we always remove it - // from the sequence space. - if sent_syn { - ack_len -= 1 - } - // We could've sent data before the FIN, so only remove FIN from the sequence - // space if all of that data is acknowledged. - if sent_fin && self.tx_buffer.len() + 1 == ack_len { - ack_len -= 1; - net_trace!("[{}]{}:{}: received ACK of FIN", - self.debug_id, self.local_endpoint, self.remote_endpoint); - // If we've just switched from the FIN-WAIT-1 state to the CLOSING state - // because we've received a FIN, and the same packet simultaneously acknowledges - // the FIN we've sent, this is our only opportunity to move to the TIME-WAIT state. - match self.state { - State::Closing => - self.set_state(State::TimeWait), - _ => () - } - } - if ack_len > 0 { - net_trace!("[{}]{}:{}: tx buffer: dequeueing {} octets (now {})", - self.debug_id, self.local_endpoint, self.remote_endpoint, - ack_len, self.tx_buffer.len() - ack_len); - } + if ack_len > 0 { + net_trace!("[{}]{}:{}: tx buffer: dequeueing {} octets (now {})", + self.debug_id, self.local_endpoint, self.remote_endpoint, + ack_len, self.tx_buffer.len() - ack_len); self.tx_buffer.advance(ack_len); + } + + // We've processed everything in the incoming segment, so advance the local + // sequence number past it. + if let Some(ack_number) = repr.ack_number { self.local_seq_no = ack_number; } @@ -1016,6 +1028,7 @@ impl<'a> TcpSocket<'a> { net_trace!("[{}]{}:{}: sending FIN|ACK", self.debug_id, self.local_endpoint, self.remote_endpoint); repr.control = TcpControl::Fin; + self.remote_last_seq += 1; should_send = true; } _ => () @@ -1489,10 +1502,7 @@ mod test { ..RECV_TEMPL }]); assert_eq!(s.state, State::Established); - sanity!(s, TcpSocket { - retransmit: Retransmit { resend_at: 100, delay: 100 }, - ..socket_established() - }); + sanity!(s, socket_established(), retransmit: false); } #[test] @@ -1781,7 +1791,10 @@ mod test { ..SEND_TEMPL }); assert_eq!(s.state, State::FinWait2); - sanity!(&s, socket_fin_wait_2(), retransmit: false); + sanity!(s, TcpSocket { + remote_last_seq: LOCAL_SEQ + 1 + 1, + ..socket_fin_wait_2() + }, retransmit: false); } #[test] @@ -1800,7 +1813,10 @@ mod test { ..SEND_TEMPL }); assert_eq!(s.state, State::Closing); - sanity!(s, socket_closing()); + sanity!(s, TcpSocket { + remote_last_seq: LOCAL_SEQ + 1 + 1, + ..socket_closing() + }); } #[test] @@ -1885,7 +1901,7 @@ mod test { ..SEND_TEMPL }); assert_eq!(s.state, State::TimeWait); - sanity!(s, socket_time_wait(true)); + sanity!(s, socket_time_wait(true), retransmit: false); } #[test] @@ -2175,7 +2191,7 @@ mod test { } #[test] - fn test_mutual_close_with_data() { + fn test_mutual_close_with_data_1() { let mut s = socket_established(); s.send_slice(b"abcdef").unwrap(); s.close(); @@ -2195,6 +2211,39 @@ mod test { }); } + #[test] + fn test_mutual_close_with_data_2() { + let mut s = socket_established(); + s.send_slice(b"abcdef").unwrap(); + s.close(); + assert_eq!(s.state, State::FinWait1); + recv!(s, [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::FinWait2); + send!(s, TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 1), + ..SEND_TEMPL + }); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }]); + assert_eq!(s.state, State::TimeWait); + } + // =========================================================================================// // Tests for retransmission on packet loss. // =========================================================================================//