From 497aa5919ab66f7512daca2cf43a5996c33c4752 Mon Sep 17 00:00:00 2001 From: whitequark Date: Mon, 23 Jan 2017 21:02:28 +0000 Subject: [PATCH] Correctly treat TCP ACKs that acknowledge both data and a FIN. --- src/socket/tcp.rs | 66 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 48 insertions(+), 18 deletions(-) diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index 3cc8911..ca5efda 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -575,6 +575,18 @@ impl<'a> TcpSocket<'a> { if !self.remote_endpoint.addr.is_unspecified() && self.remote_endpoint.addr != ip_repr.src_addr() { return Err(Error::Rejected) } + // Consider how much the sequence number space differs from the transmit buffer space. + let (sent_syn, sent_fin) = match self.state { + // In SYN-SENT or SYN-RECEIVED, we've just sent a SYN. + State::SynSent | State::SynReceived => (true, false), + // In FIN-WAIT-1, LAST-ACK, or CLOSING, we've just sent a FIN. + State::FinWait1 | State::LastAck | State::Closing => (false, true), + // In all other states we've already got acknowledgemetns for + // all of the control flags we sent. + _ => (false, false) + }; + let control_len = (sent_syn as usize) + (sent_fin as usize); + // Reject unacceptable acknowledgements. match (self.state, repr) { // The initial SYN (or whatever) cannot contain an acknowledgement. @@ -609,16 +621,7 @@ impl<'a> TcpSocket<'a> { return Err(Error::Malformed) } // Every acknowledgement must be for transmitted but unacknowledged data. - (state, TcpRepr { ack_number: Some(ack_number), .. }) => { - let control_len = match state { - // In SYN-SENT or SYN-RECEIVED, we've just sent a SYN. - State::SynSent | State::SynReceived => 1, - // In FIN-WAIT-1, LAST-ACK, or CLOSING, we've just sent a FIN. - State::FinWait1 | State::LastAck | State::Closing => 1, - // In all other states we've already got acknowledgemetns for - // all of the control flags we sent. - _ => 0 - }; + (_, TcpRepr { ack_number: Some(ack_number), .. }) => { let unacknowledged = self.tx_buffer.len() + control_len; if !(ack_number >= self.local_seq_no && ack_number <= (self.local_seq_no + unacknowledged)) { @@ -708,7 +711,6 @@ impl<'a> TcpSocket<'a> { // ACK packets in the SYN-RECEIVED state change it to ESTABLISHED. (State::SynReceived, TcpRepr { control: TcpControl::None, .. }) => { - self.local_seq_no += 1; self.set_state(State::Established); self.retransmit.reset(); } @@ -725,7 +727,6 @@ impl<'a> TcpSocket<'a> { // ACK packets in FIN-WAIT-1 state change it to FIN-WAIT-2. (State::FinWait1, TcpRepr { control: TcpControl::None, .. }) => { - self.local_seq_no += 1; self.set_state(State::FinWait2); } @@ -745,7 +746,6 @@ impl<'a> TcpSocket<'a> { // ACK packets in CLOSING state change it to TIME-WAIT. (State::Closing, TcpRepr { control: TcpControl::None, .. }) => { - self.local_seq_no += 1; self.set_state(State::TimeWait); self.retransmit.reset(); } @@ -758,7 +758,6 @@ impl<'a> TcpSocket<'a> { // Clear the remote endpoint, or we'll send an RST there. self.set_state(State::Closed); self.remote_endpoint = IpEndpoint::default(); - self.local_seq_no += 1; } _ => { @@ -770,13 +769,23 @@ impl<'a> TcpSocket<'a> { // Dequeue acknowledged octets. if let Some(ack_number) = repr.ack_number { - let ack_length = ack_number - self.local_seq_no; - if ack_length > 0 { + let mut ack_len = ack_number - self.local_seq_no; + // There could have been no data sent before the SYN, so we always remove it + // from the sequence space. + if sent_syn { + ack_len -= 1 + } + // We could've sent data before the FIN, so only remove FIN from the sequence + // space if all of that data is acknowledged. + if sent_fin && self.tx_buffer.len() + 1 == ack_len { + ack_len -= 1 + } + if ack_len > 0 { net_trace!("[{}]{}:{}: tx buffer: dequeueing {} octets (now {})", self.debug_id, self.local_endpoint, self.remote_endpoint, - ack_length, self.tx_buffer.len() - ack_length); + ack_len, self.tx_buffer.len() - ack_len); } - self.tx_buffer.advance(ack_length); + self.tx_buffer.advance(ack_len); self.local_seq_no = ack_number; } @@ -1962,6 +1971,27 @@ mod test { }]) } + #[test] + fn test_mutual_close_with_data() { + let mut s = socket_established(); + s.send_slice(b"abcdef").unwrap(); + s.close(); + assert_eq!(s.state, State::FinWait1); + recv!(s, [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 1), + ..SEND_TEMPL + }); + } + // =========================================================================================// // Tests for retransmission on packet loss. // =========================================================================================//