diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs
index c88f4ef..63a6903 100644
--- a/src/socket/tcp.rs
+++ b/src/socket/tcp.rs
@@ -54,6 +54,114 @@ impl fmt::Display for State {
     }
 }
 
+// Conservative initial RTT estimate and deviation, in milliseconds.
+const RTTE_INITIAL_RTT: u32 = 300;
+const RTTE_INITIAL_DEV: u32 = 100;
+
+// Minimum "safety margin" for the RTO that kicks in when the
+// variance gets very low.
+const RTTE_MIN_MARGIN: u32 = 5;
+
+// Lower and upper bounds of the retransmission timeout, in milliseconds.
+const RTTE_MIN_RTO: u32 = 10;
+const RTTE_MAX_RTO: u32 = 10000;
+
+/// Round-trip time estimator that drives the TCP retransmission timeout.
+#[derive(Debug, Clone, Copy)]
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+struct RttEstimator {
+    // Using u32 instead of Duration to save space (Duration is i64)
+    /// Smoothed round-trip time, in milliseconds.
+    rtt: u32,
+    /// Smoothed deviation of the round-trip time, in milliseconds.
+    deviation: u32,
+    /// Send time and sequence number of the segment being timed, if any.
+    timestamp: Option<(Instant, TcpSeqNumber)>,
+    /// Highest sequence number seen by `on_send`.
+    max_seq_sent: Option<TcpSeqNumber>,
+    /// Retransmissions since the last sample or forced estimate increase.
+    rto_count: u8,
+}
+
+impl Default for RttEstimator {
+    fn default() -> Self {
+        Self {
+            rtt: RTTE_INITIAL_RTT,
+            deviation: RTTE_INITIAL_DEV,
+            timestamp: None,
+            max_seq_sent: None,
+            rto_count: 0,
+        }
+    }
+}
+
+impl RttEstimator {
+    /// Returns the retransmission timeout: the RTT estimate plus a margin of four
+    /// deviations (at least `RTTE_MIN_MARGIN`), clamped to the RTO bounds.
+    fn retransmission_timeout(&self) -> Duration {
+        let margin = RTTE_MIN_MARGIN.max(self.deviation * 4);
+        let ms = (self.rtt + margin).max(RTTE_MIN_RTO).min(RTTE_MAX_RTO);
+        Duration::from_millis(ms as u64)
+    }
+
+    /// Folds a new RTT measurement (in milliseconds) into the estimate.
+    fn sample(&mut self, new_rtt: u32) {
+        // "Congestion Avoidance and Control", Van Jacobson, Michael J. Karels, 1988
+        self.rtt = (self.rtt * 7 + new_rtt + 7) / 8;
+        let diff = (self.rtt as i32 - new_rtt as i32).abs() as u32;
+        self.deviation = (self.deviation * 3 + diff + 3) / 4;
+
+        self.rto_count = 0;
+
+        let rto = self.retransmission_timeout().millis();
+        net_trace!("rtte: sample={:?} rtt={:?} dev={:?} rto={:?}", new_rtt, self.rtt, self.deviation, rto);
+    }
+
+    /// Starts timing a segment if `seq` is beyond everything sent so far and no measurement is in progress.
+    fn on_send(&mut self, timestamp: Instant, seq: TcpSeqNumber) {
+        if self.max_seq_sent.map(|max_seq_sent| seq > max_seq_sent).unwrap_or(true) {
+            self.max_seq_sent = Some(seq);
+            if self.timestamp.is_none() {
+                self.timestamp = Some((timestamp, seq));
+                net_trace!("rtte: sampling at seq={:?}", seq);
+            }
+        }
+    }
+
+    /// Completes the measurement in progress once `seq` reaches the timed sequence number.
+    fn on_ack(&mut self, timestamp: Instant, seq: TcpSeqNumber) {
+        if let Some((sent_timestamp, sent_seq)) = self.timestamp {
+            if seq >= sent_seq {
+                self.sample((timestamp - sent_timestamp).millis() as u32);
+                self.timestamp = None;
+            }
+        }
+    }
+
+    /// Aborts any measurement in progress and counts the retransmission.
+    fn on_retransmit(&mut self) {
+        if self.timestamp.is_some() {
+            net_trace!("rtte: abort sampling due to retransmit");
+        }
+        self.timestamp = None;
+        self.rto_count = self.rto_count.saturating_add(1);
+        if self.rto_count >= 3 {
+            // This happens in 2 scenarios:
+            // - The RTT is higher than the initial estimate
+            // - The network conditions change, suddenly making the RTT much higher
+            // In these cases, the estimator can get stuck, because it can't sample because
+            // all packets sent would incur a retransmit. To avoid this, force an estimate
+            // increase if we see 3 consecutive retransmissions without any successful sample.
+            self.rto_count = 0;
+            // Cap the doubled estimate: the RTO never exceeds RTTE_MAX_RTO anyway, and
+            // unbounded doubling would eventually overflow the u32.
+            self.rtt = RTTE_MAX_RTO.min(self.rtt * 2);
+            let rto = self.retransmission_timeout().millis();
+            net_trace!("rtte: too many retransmissions, increasing: rtt={:?} dev={:?} rto={:?}", self.rtt, self.deviation, rto);
+        }
+    }
+}
+
 #[derive(Debug, Clone, Copy, PartialEq)]
 enum Timer {
     Idle {
@@ -69,7 +177,6 @@ enum Timer {
     }
 }
 
-const RETRANSMIT_DELAY: Duration = Duration { millis: 100 };
 const CLOSE_DELAY: Duration = Duration { millis: 10_000 };
 
 impl Default for Timer {
@@ -140,12 +247,12 @@ impl Timer {
         }
     }
 
-    fn set_for_retransmit(&mut self, timestamp: Instant) {
+    fn set_for_retransmit(&mut self, timestamp: Instant, delay: Duration) {
         match *self {
             Timer::Idle { .. } | Timer::FastRetransmit { .. } => {
                 *self = Timer::Retransmit {
-                    expires_at: timestamp + RETRANSMIT_DELAY,
-                    delay: RETRANSMIT_DELAY,
+                    expires_at: timestamp + delay,
+                    delay: delay,
                 }
             }
             Timer::Retransmit { expires_at, delay }
@@ -189,6 +296,7 @@ pub struct TcpSocket<'a> {
     pub(crate) meta: SocketMeta,
     state: State,
     timer: Timer,
+    rtte: RttEstimator,
     assembler: Assembler,
     rx_buffer: SocketBuffer<'a>,
     rx_fin_received: bool,
@@ -279,6 +387,7 @@ impl<'a> TcpSocket<'a> {
             meta: SocketMeta::default(),
             state: State::Closed,
             timer: Timer::default(),
+            rtte: RttEstimator::default(),
             assembler: Assembler::new(rx_buffer.capacity()),
             tx_buffer: tx_buffer,
             rx_buffer: rx_buffer,
@@ -463,6 +572,7 @@ impl<'a> TcpSocket<'a> {
 
         self.state = State::Closed;
         self.timer = Timer::default();
+        self.rtte = RttEstimator::default();
         self.assembler = Assembler::new(self.rx_buffer.capacity());
         self.tx_buffer.clear();
         self.rx_buffer.clear();
@@ -1154,6 +1264,8 @@ impl<'a> TcpSocket<'a> {
                                self.meta.handle, self.local_endpoint, self.remote_endpoint);
                     ack_of_fin = true;
                 }
+
+                self.rtte.on_ack(timestamp, ack_number);
             }
         }
 
@@ -1538,6 +1650,7 @@ impl<'a> TcpSocket<'a> {
                            self.meta.handle, self.local_endpoint, self.remote_endpoint,
                            retransmit_delta);
                 self.remote_last_seq = self.local_seq_no;
+                self.rtte.on_retransmit();
             }
         }
 
@@ -1723,10 +1836,14 @@ impl<'a> TcpSocket<'a> {
         self.remote_last_ack = repr.ack_number;
         self.remote_last_win = repr.window_len;
 
+        if repr.segment_len() > 0 {
+            self.rtte.on_send(timestamp, repr.seq_number + repr.segment_len());
+        }
+
         if !self.seq_to_transmit() && repr.segment_len() > 0 {
             // If we've transmitted all data we could (and there was something at all,
             // data or flag, to transmit, not just an ACK), wind up the retransmit timer.
-            self.timer.set_for_retransmit(timestamp);
+            self.timer.set_for_retransmit(timestamp, self.rtte.retransmission_timeout());
         }
 
         if self.state == State::Closed {
@@ -3646,7 +3763,7 @@ mod test {
             ..RECV_TEMPL
         }));
         recv!(s, time 1050, Err(Error::Exhausted));
-        recv!(s, time 1100, Ok(TcpRepr {
+        recv!(s, time 2000, Ok(TcpRepr {
             seq_number: LOCAL_SEQ + 1,
             ack_number: Some(REMOTE_SEQ + 1),
             payload: &b"abcdef"[..],
@@ -3678,21 +3795,21 @@ mod test {
 
         recv!(s, time 50, Err(Error::Exhausted));
 
-        recv!(s, time 100, Ok(TcpRepr {
+        recv!(s, time 1000, Ok(TcpRepr {
             control: TcpControl::None,
             seq_number: LOCAL_SEQ + 1,
             ack_number: Some(REMOTE_SEQ + 1),
             payload: &b"abcdef"[..],
             ..RECV_TEMPL
         }), exact);
-        recv!(s, time 150, Ok(TcpRepr {
+        recv!(s, time 1500, Ok(TcpRepr {
             control: TcpControl::Psh,
             seq_number: LOCAL_SEQ + 1 + 6,
             ack_number: Some(REMOTE_SEQ + 1),
             payload: &b"012345"[..],
             ..RECV_TEMPL
         }), exact);
-        recv!(s, time 200, Err(Error::Exhausted));
+        recv!(s, time 1550, Err(Error::Exhausted));
     }
 
     #[test]
@@ -3705,7 +3822,7 @@ mod test {
             max_seg_size: Some(BASE_MSS),
             ..RECV_TEMPL
         }));
-        recv!(s, time 150, Ok(TcpRepr { // retransmit
+        recv!(s, time 750, Ok(TcpRepr { // retransmit
             control: TcpControl::Syn,
             seq_number: LOCAL_SEQ,
             ack_number: Some(REMOTE_SEQ + 1),
@@ -4527,9 +4644,9 @@ mod test {
     #[test]
     fn test_established_timeout() {
         let mut s = socket_established();
-        s.set_timeout(Some(Duration::from_millis(200)));
+        s.set_timeout(Some(Duration::from_millis(1000)));
         recv!(s, time 250, Err(Error::Exhausted));
-        assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(450)));
+        assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(1250)));
         s.send_slice(b"abcdef").unwrap();
         assert_eq!(s.poll_at(), PollAt::Now);
         recv!(s, time 255, Ok(TcpRepr {
@@ -4538,15 +4655,15 @@ mod test {
             payload: &b"abcdef"[..],
             ..RECV_TEMPL
         }));
-        assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(355)));
-        recv!(s, time 355, Ok(TcpRepr {
+        assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(955)));
+        recv!(s, time 955, Ok(TcpRepr {
             seq_number: LOCAL_SEQ + 1,
             ack_number: Some(REMOTE_SEQ + 1),
             payload: &b"abcdef"[..],
             ..RECV_TEMPL
         }));
-        assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(455)));
-        recv!(s, time 500, Ok(TcpRepr {
+        assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(1255)));
+        recv!(s, time 1255, Ok(TcpRepr {
             control: TcpControl::Rst,
             seq_number: LOCAL_SEQ + 1 + 6,
             ack_number: Some(REMOTE_SEQ + 1),
@@ -4596,15 +4713,14 @@ mod test {
     #[test]
     fn test_fin_wait_1_timeout() {
         let mut s = socket_fin_wait_1();
-        s.set_timeout(Some(Duration::from_millis(200)));
+        s.set_timeout(Some(Duration::from_millis(1000)));
         recv!(s, time 100, Ok(TcpRepr {
             control: TcpControl::Fin,
             seq_number: LOCAL_SEQ + 1,
             ack_number: Some(REMOTE_SEQ + 1),
             ..RECV_TEMPL
         }));
-        assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(200)));
-        recv!(s, time 400, Ok(TcpRepr {
+        recv!(s, time 1100, Ok(TcpRepr {
             control: TcpControl::Rst,
             seq_number: LOCAL_SEQ + 1 + 1,
             ack_number: Some(REMOTE_SEQ + 1),
@@ -4616,15 +4732,14 @@ mod test {
     #[test]
     fn test_last_ack_timeout() {
         let mut s = socket_last_ack();
-        s.set_timeout(Some(Duration::from_millis(200)));
+        s.set_timeout(Some(Duration::from_millis(1000)));
         recv!(s, time 100, Ok(TcpRepr {
             control: TcpControl::Fin,
             seq_number: LOCAL_SEQ + 1,
             ack_number: Some(REMOTE_SEQ + 1 + 1),
             ..RECV_TEMPL
         }));
-        assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(200)));
-        recv!(s, time 400, Ok(TcpRepr {
+        recv!(s, time 1100, Ok(TcpRepr {
             control: TcpControl::Rst,
             seq_number: LOCAL_SEQ + 1 + 1,
             ack_number: Some(REMOTE_SEQ + 1 + 1),
@@ -5052,13 +5167,14 @@ mod test {
 
     #[test]
     fn test_timer_retransmit() {
+        const RTO: Duration = Duration::from_millis(100);
         let mut r = Timer::default();
         assert_eq!(r.should_retransmit(Instant::from_secs(1)), None);
-        r.set_for_retransmit(Instant::from_millis(1000));
+        r.set_for_retransmit(Instant::from_millis(1000), RTO);
         assert_eq!(r.should_retransmit(Instant::from_millis(1000)), None);
         assert_eq!(r.should_retransmit(Instant::from_millis(1050)), None);
         assert_eq!(r.should_retransmit(Instant::from_millis(1101)), Some(Duration::from_millis(101)));
-        r.set_for_retransmit(Instant::from_millis(1101));
+        r.set_for_retransmit(Instant::from_millis(1101), RTO);
         assert_eq!(r.should_retransmit(Instant::from_millis(1101)), None);
         assert_eq!(r.should_retransmit(Instant::from_millis(1150)), None);
         assert_eq!(r.should_retransmit(Instant::from_millis(1200)), None);
@@ -5067,4 +5183,23 @@ mod test {
         assert_eq!(r.should_retransmit(Instant::from_millis(1350)), None);
     }
 
+    #[test]
+    fn test_rtt_estimator() {
+        #[cfg(feature = "log")]
+        init_logger();
+
+        let mut r = RttEstimator::default();
+
+        let rtos = &[
+            751, 766, 755, 731, 697, 656, 613, 567,
+            523, 484, 445, 411, 378, 350, 322, 299,
+            280, 261, 243, 229, 215, 206, 197, 188
+        ];
+
+        for &rto in rtos {
+            r.sample(100);
+            assert_eq!(r.retransmission_timeout(), Duration::from_millis(rto));
+        }
+    }
+
 }