diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index c2c2f5e..f10cf5d 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -272,7 +272,7 @@ pub struct TcpSocket<'a> { remote_last_seq: TcpSeqNumber, /// The last acknowledgement number sent. /// I.e. in an idle socket, remote_seq_no+rx_buffer.len(). - remote_last_ack: TcpSeqNumber, + remote_last_ack: Option, /// The speculative remote window size. /// I.e. the actual remote window size minus the count of in-flight octets. remote_win_len: usize, @@ -304,7 +304,7 @@ impl<'a> TcpSocket<'a> { local_seq_no: TcpSeqNumber::default(), remote_seq_no: TcpSeqNumber::default(), remote_last_seq: TcpSeqNumber::default(), - remote_last_ack: TcpSeqNumber::default(), + remote_last_ack: None, remote_win_len: 0, remote_mss: DEFAULT_MSS, }) @@ -350,7 +350,7 @@ impl<'a> TcpSocket<'a> { self.local_seq_no = TcpSeqNumber::default(); self.remote_seq_no = TcpSeqNumber::default(); self.remote_last_seq = TcpSeqNumber::default(); - self.remote_last_ack = TcpSeqNumber::default(); + self.remote_last_ack = None; self.remote_win_len = 0; self.remote_mss = DEFAULT_MSS; self.timer.reset(); @@ -739,7 +739,7 @@ impl<'a> TcpSocket<'a> { // and an acknowledgment indicating the next sequence number expected // to be received. reply_repr.seq_number = self.remote_last_seq; - reply_repr.ack_number = Some(self.remote_last_ack); + reply_repr.ack_number = self.remote_last_ack; reply_repr.window_len = self.rx_buffer.window() as u16; (ip_reply_repr, reply_repr) @@ -974,7 +974,6 @@ impl<'a> TcpSocket<'a> { self.local_endpoint = IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port); self.remote_seq_no = repr.seq_number + 1; self.remote_last_seq = self.local_seq_no + 1; - self.remote_last_ack = repr.seq_number; if let Some(max_seg_size) = repr.max_seg_size { self.remote_mss = max_seg_size as usize; } @@ -1077,7 +1076,7 @@ impl<'a> TcpSocket<'a> { self.rx_buffer.enqueue_slice(repr.payload); // Send an acknowledgement. - self.remote_last_ack = self.remote_seq_no + self.rx_buffer.len(); + self.remote_last_ack = Some(self.remote_seq_no + self.rx_buffer.len()); Ok(Some(self.ack_reply(ip_repr, &repr))) } else { // No data to acknowledge; the logic to acknowledge SYN and FIN flags @@ -1091,7 +1090,11 @@ impl<'a> TcpSocket<'a> { } fn ack_to_transmit(&self) -> bool { - self.remote_last_ack < self.remote_seq_no + self.rx_buffer.len() + if let Some(remote_last_ack) = self.remote_last_ack { + remote_last_ack < self.remote_seq_no + self.rx_buffer.len() + } else { + true + } } pub(crate) fn dispatch(&mut self, timestamp: u64, limits: &DeviceLimits, @@ -1182,7 +1185,7 @@ impl<'a> TcpSocket<'a> { if self.seq_to_transmit(repr.control) && repr.segment_len() > 0 { // If we have data to transmit and it fits into partner's window, do it. - } else if self.ack_to_transmit() { + } else if self.ack_to_transmit() && repr.ack_number.is_some() { // If we have data to acknowledge, do it. } else if self.timer.should_retransmit(timestamp).is_some() { // If we have packets to retransmit, do it. @@ -1248,7 +1251,7 @@ impl<'a> TcpSocket<'a> { // We've sent a packet successfully, so we can update the internal state now. self.remote_last_seq = repr.seq_number + repr.segment_len(); - self.remote_last_ack = repr.ack_number.unwrap_or_default(); + self.remote_last_ack = repr.ack_number; if !self.seq_to_transmit(repr.control) && repr.segment_len() > 0 { // If we've transmitted all data could (and there was something at all, @@ -1648,7 +1651,7 @@ mod test { }))); assert_eq!(s.state, State::CloseWait); sanity!(s, TcpSocket { - remote_last_ack: REMOTE_SEQ + 1 + 6 + 1, + remote_last_ack: Some(REMOTE_SEQ + 1 + 6 + 1), ..socket_close_wait() }); } @@ -1843,7 +1846,7 @@ mod test { s.state = State::Established; s.local_seq_no = LOCAL_SEQ + 1; s.remote_last_seq = LOCAL_SEQ + 1; - s.remote_last_ack = REMOTE_SEQ + 1; + s.remote_last_ack = Some(REMOTE_SEQ + 1); s } @@ -2214,7 +2217,7 @@ mod test { s.state = State::TimeWait; s.remote_seq_no = REMOTE_SEQ + 1 + 1; if from_closing { - s.remote_last_ack = REMOTE_SEQ + 1 + 1; + s.remote_last_ack = Some(REMOTE_SEQ + 1 + 1); } s.timer = Timer::Close { expires_at: 1_000 + CLOSE_DELAY }; s @@ -2284,7 +2287,7 @@ mod test { let mut s = socket_established(); s.state = State::CloseWait; s.remote_seq_no = REMOTE_SEQ + 1 + 1; - s.remote_last_ack = REMOTE_SEQ + 1 + 1; + s.remote_last_ack = Some(REMOTE_SEQ + 1 + 1); s }