diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index d4adfbe..e9885da 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -56,11 +56,6 @@ impl fmt::Display for State { } } -/// Initial sequence number. This used to be 0, but some servers don't behave correctly -/// with that, so we use a non-zero starting sequence number. TODO: randomize instead. -/// https://github.com/smoltcp-rs/smoltcp/issues/489 -const INITIAL_SEQ_NO: TcpSeqNumber = TcpSeqNumber(42); - // Conservative initial RTT estimate. const RTTE_INITIAL_RTT: u32 = 300; const RTTE_INITIAL_DEV: u32 = 100; @@ -430,7 +425,7 @@ impl<'a> TcpSocket<'a> { listen_address: IpAddress::default(), local_endpoint: IpEndpoint::default(), remote_endpoint: IpEndpoint::default(), - local_seq_no: INITIAL_SEQ_NO, + local_seq_no: TcpSeqNumber::default(), remote_seq_no: TcpSeqNumber::default(), remote_last_seq: TcpSeqNumber::default(), remote_last_ack: None, @@ -658,7 +653,7 @@ impl<'a> TcpSocket<'a> { self.listen_address = IpAddress::default(); self.local_endpoint = IpEndpoint::default(); self.remote_endpoint = IpEndpoint::default(); - self.local_seq_no = INITIAL_SEQ_NO; + self.local_seq_no = TcpSeqNumber::default(); self.remote_seq_no = TcpSeqNumber::default(); self.remote_last_seq = TcpSeqNumber::default(); self.remote_last_ack = None; @@ -751,18 +746,24 @@ impl<'a> TcpSocket<'a> { ..local_endpoint }; - // Carry over the local sequence number. - let local_seq_no = self.local_seq_no; - self.reset(); self.local_endpoint = local_endpoint; self.remote_endpoint = remote_endpoint; - self.local_seq_no = local_seq_no; - self.remote_last_seq = local_seq_no; self.set_state(State::SynSent); + + let seq = Self::random_seq_no(); + self.local_seq_no = seq; + self.remote_last_seq = seq; Ok(()) } + fn random_seq_no() -> TcpSeqNumber { + #[cfg(test)] + return TcpSeqNumber(10000); + #[cfg(not(test))] + return TcpSeqNumber(crate::rand::rand_u32() as i32); + } + /// Close the transmit half of the full-duplex connection. /// /// Note that there is no corresponding function for the receive half of the full-duplex @@ -1575,8 +1576,7 @@ impl<'a> TcpSocket<'a> { self.local_endpoint = IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port); self.remote_endpoint = IpEndpoint::new(ip_repr.src_addr(), repr.src_port); - // FIXME: use something more secure here - self.local_seq_no = TcpSeqNumber(!repr.seq_number.0); + self.local_seq_no = Self::random_seq_no(); self.remote_seq_no = repr.seq_number + 1; self.remote_last_seq = self.local_seq_no; self.remote_has_sack = repr.sack_permitted;