From cd894460f57179a22a1e8979830bd056b276f2bf Mon Sep 17 00:00:00 2001 From: whitequark Date: Sun, 5 Mar 2017 03:52:47 +0000 Subject: [PATCH] Implement the TCP SYN-SENT state. --- src/socket/tcp.rs | 153 +++++++++++++++++++++++++++++++++++++++++++--- src/wire/ip.rs | 6 ++ 2 files changed, 151 insertions(+), 8 deletions(-) diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index 71d830e..b15fc29 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -25,6 +25,11 @@ impl<'a> SocketBuffer<'a> { } } + fn clear(&mut self) { + self.read_at = 0; + self.length = 0; + } + fn capacity(&self) -> usize { self.storage.len() } @@ -253,6 +258,8 @@ pub struct TcpSocket<'a> { debug_id: usize } +const DEFAULT_MSS: usize = 536; + impl<'a> TcpSocket<'a> { /// Create a socket using the given buffers. pub fn new(rx_buffer: T, tx_buffer: T) -> Socket<'a, 'static> @@ -273,7 +280,7 @@ impl<'a> TcpSocket<'a> { remote_last_seq: TcpSeqNumber(0), remote_last_ack: TcpSeqNumber(0), remote_win_len: 0, - remote_mss: 536, + remote_mss: DEFAULT_MSS, retransmit: Retransmit::new(), tx_buffer: tx_buffer.into(), rx_buffer: rx_buffer.into(), @@ -311,20 +318,77 @@ impl<'a> TcpSocket<'a> { self.state } + fn reset(&mut self) { + self.listen_address = IpAddress::default(); + self.local_endpoint = IpEndpoint::default(); + self.remote_endpoint = IpEndpoint::default(); + self.local_seq_no = TcpSeqNumber(0); + self.remote_seq_no = TcpSeqNumber(0); + self.remote_last_seq = TcpSeqNumber(0); + self.remote_last_ack = TcpSeqNumber(0); + self.remote_win_len = 0; + self.remote_win_len = 0; + self.remote_mss = DEFAULT_MSS; + self.retransmit.reset(); + self.tx_buffer.clear(); + self.rx_buffer.clear(); + } + /// Start listening on the given endpoint. /// /// This function returns an error if the socket was open; see [is_open](#method.is_open). - pub fn listen>(&mut self, endpoint: T) -> Result<(), ()> { - if self.is_open() { return Err(()) } + /// It also returns an error if the specified port is zero. + pub fn listen(&mut self, local_endpoint: T) -> Result<(), ()> + where T: Into { + let local_endpoint = local_endpoint.into(); - let endpoint = endpoint.into(); - self.listen_address = endpoint.addr; - self.local_endpoint = endpoint; + if self.is_open() { return Err(()) } + if local_endpoint.port == 0 { return Err(()) } + + self.reset(); + self.listen_address = local_endpoint.addr; + self.local_endpoint = local_endpoint; self.remote_endpoint = IpEndpoint::default(); self.set_state(State::Listen); Ok(()) } + /// Connect to a given endpoint. + /// + /// The local port must be provided explicitly. Assuming `fn get_ephemeral_port() -> u16` + /// allocates a port from the 49152 to 65535 range, a connection may be established as follows: + /// + /// ```rust,ignore + /// socket.connect((IpAddress::v4(10, 0, 0, 1), 80), get_ephemeral_port()) + /// ``` + /// + /// The local address may optionally be provided. + /// + /// This function returns an error if the socket was open; see [is_open](#method.is_open). + /// It also returns an error if the local or remote port is zero, or if + /// the local or remote address is unspecified. + pub fn connect(&mut self, remote_endpoint: T, local_endpoint: U) -> Result<(), ()> + where T: Into, U: Into { + let remote_endpoint = remote_endpoint.into(); + let local_endpoint = local_endpoint.into(); + + if self.is_open() { return Err(()) } + if remote_endpoint.port == 0 { return Err(()) } + if remote_endpoint.addr.is_unspecified() { return Err(()) } + if local_endpoint.port == 0 { return Err(()) } + if local_endpoint.addr.is_unspecified() { return Err(()) } + + // Carry over the local sequence number. + let local_seq_no = self.local_seq_no; + + self.reset(); + self.local_endpoint = local_endpoint; + self.remote_endpoint = remote_endpoint; + self.local_seq_no = local_seq_no; + self.set_state(State::SynSent); + Ok(()) + } + /// Close the transmit half of the full-duplex connection. /// /// Note that there is no corresponding function for the receive half of the full-duplex @@ -715,6 +779,23 @@ impl<'a> TcpSocket<'a> { self.retransmit.reset(); } + // SYN|ACK packets in the SYN-SENT state change it to ESTABLISHED. + (State::SynSent, TcpRepr { + control: TcpControl::Syn, seq_number, ack_number: Some(_), + max_seg_size, .. + }) => { + net_trace!("[{}]{}:{}: received SYN|ACK", + self.debug_id, self.local_endpoint, self.remote_endpoint); + self.remote_last_seq = self.local_seq_no + 1; + self.remote_seq_no = seq_number + 1; + self.remote_last_ack = seq_number; + if let Some(max_seg_size) = max_seg_size { + self.remote_mss = max_seg_size as usize; + } + self.set_state(State::Established); + self.retransmit.reset(); + } + // ACK packets in ESTABLISHED state reset the retransmit timer. (State::Established, TcpRepr { control: TcpControl::None, .. }) => { self.retransmit.reset() @@ -962,8 +1043,10 @@ impl<'a> TcpSocket<'a> { self.retransmit.delay); } - repr.ack_number = Some(ack_number); - self.remote_last_ack = ack_number; + if self.state != State::SynSent { + repr.ack_number = Some(ack_number); + self.remote_last_ack = ack_number; + } // Remember the header length before enabling the MSS option, since that option // only affects SYN packets. @@ -1249,6 +1332,12 @@ mod test { sanity!(s, socket_listen()); } + #[test] + fn test_listen_validation() { + let mut s = socket(); + assert_eq!(s.listen(0), Err(())); + } + #[test] fn test_listen_syn() { let mut s = socket_listen(); @@ -1358,6 +1447,54 @@ mod test { s } + #[test] + fn test_connect_validation() { + let mut s = socket(); + assert_eq!(s.connect(REMOTE_END, (IpAddress::Unspecified, 80)), Err(())); + assert_eq!(s.connect(REMOTE_END, (IpAddress::v4(0, 0, 0, 0), 80)), Err(())); + assert_eq!(s.connect(REMOTE_END, (IpAddress::v4(10, 0, 0, 0), 0)), Err(())); + assert_eq!(s.connect((IpAddress::Unspecified, 80), LOCAL_END), Err(())); + assert_eq!(s.connect((IpAddress::v4(0, 0, 0, 0), 80), LOCAL_END), Err(())); + assert_eq!(s.connect((IpAddress::v4(10, 0, 0, 0), 0), LOCAL_END), Err(())); + } + + #[test] + fn test_syn_sent_sanity() { + let mut s = socket(); + s.local_seq_no = LOCAL_SEQ; + s.connect(REMOTE_END, LOCAL_END).unwrap(); + sanity!(s, socket_syn_sent()); + } + + #[test] + fn test_syn_sent_syn_ack() { + let mut s = socket_syn_sent(); + recv!(s, [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(1480), + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + max_seg_size: Some(1400), + ..SEND_TEMPL + }); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }]); + assert_eq!(s.state, State::Established); + sanity!(s, TcpSocket { + retransmit: Retransmit { resend_at: 100, delay: 100 }, + ..socket_established() + }); + } + #[test] fn test_syn_sent_rst() { let mut s = socket_syn_sent(); diff --git a/src/wire/ip.rs b/src/wire/ip.rs index 3e624b8..badec4e 100644 --- a/src/wire/ip.rs +++ b/src/wire/ip.rs @@ -112,6 +112,12 @@ impl From for Endpoint { } } +impl> From<(T, u16)> for Endpoint { + fn from((addr, port): (T, u16)) -> Endpoint { + Endpoint { addr: addr.into(), port: port } + } +} + /// An IP packet representation. /// /// This enum abstracts the various versions of IP packets. It either contains a concrete