diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index d06b43d..79d8f15 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -68,8 +68,8 @@ impl<'a> SocketBuffer<'a> { dest.copy_from_slice(data); } - fn clamp_reader(&self, mut size: usize) -> (usize, usize) { - let read_at = self.read_at; + fn clamp_reader(&self, offset: usize, mut size: usize) -> (usize, usize) { + let read_at = (self.read_at + offset) % self.storage.len(); // We can't dequeue more than was queued. if size > self.length { size = self.length } // We can't contiguously dequeue past the end of the storage. @@ -79,23 +79,24 @@ impl<'a> SocketBuffer<'a> { (read_at, size) } - fn peek(&self, size: usize) -> &[u8] { - let (read_at, size) = self.clamp_reader(size); + #[allow(dead_code)] // only used in tests + fn dequeue(&mut self, size: usize) -> &[u8] { + let (read_at, size) = self.clamp_reader(0, size); + self.read_at = (self.read_at + size) % self.storage.len(); + self.length -= size; + &self.storage[read_at..read_at + size] + } + + fn peek(&self, offset: usize, size: usize) -> &[u8] { + if offset > self.length { panic!("peeking {} octets past free space", offset) } + let (read_at, size) = self.clamp_reader(offset, size); &self.storage[read_at..read_at + size] } fn advance(&mut self, size: usize) { - let (read_at, size) = self.clamp_reader(size); - self.read_at = (read_at + size) % self.storage.len(); - self.length -= size; - } - - #[allow(dead_code)] // only used in tests - fn dequeue(&mut self, size: usize) -> &[u8] { - let (read_at, size) = self.clamp_reader(size); + if size > self.length { panic!("advancing {} octets into free space", size) } self.read_at = (self.read_at + size) % self.storage.len(); self.length -= size; - &self.storage[read_at..read_at + size] } } @@ -162,13 +163,33 @@ impl Retransmit { /// A Transmission Control Protocol data stream. #[derive(Debug)] pub struct TcpSocket<'a> { + /// State of the socket. state: State, + /// Address passed to `listen()`. `listen_address` is set when `listen()` is called and + /// used every time the socket is reset back to the `LISTEN` state. listen_address: IpAddress, + /// Current local endpoint. This is used for both filtering the incoming packets and + /// setting the source address. When listening or initiating connection on/from + /// an unspecified address, this field is updated with the chosen source address before + /// any packets are sent. local_endpoint: IpEndpoint, + /// Current remote endpoint. This is used for both filtering the incoming packets and + /// setting the destination address. remote_endpoint: IpEndpoint, + /// The sequence number corresponding to the beginning of the transmit buffer. + /// I.e. an ACK(local_seq_no+n) packet removes n bytes from the transmit buffer. local_seq_no: i32, + /// The sequence number corresponding to the beginning of the receive buffer. + /// I.e. userspace reading n bytes adds n to remote_seq_no. remote_seq_no: i32, + /// The last sequence number sent. + /// I.e. in an idle socket, local_seq_no+tx_buffer.len(). + remote_last_seq: i32, + /// The last acknowledgement number sent. + /// I.e. in an idle socket, remote_seq_no+rx_buffer.len(). remote_last_ack: i32, + /// The speculative remote window size. + /// I.e. the actual remote window size minus the count of in-flight octets. remote_win_len: usize, retransmit: Retransmit, rx_buffer: SocketBuffer<'a>, @@ -192,8 +213,9 @@ impl<'a> TcpSocket<'a> { remote_endpoint: IpEndpoint::default(), local_seq_no: 0, remote_seq_no: 0, - remote_win_len: 0, + remote_last_seq: 0, remote_last_ack: 0, + remote_win_len: 0, retransmit: Retransmit::new(), tx_buffer: tx_buffer.into(), rx_buffer: rx_buffer.into() @@ -252,7 +274,7 @@ impl<'a> TcpSocket<'a> { pub fn send(&mut self, size: usize) -> &mut [u8] { let buffer = self.tx_buffer.enqueue(size); if buffer.len() > 0 { - net_trace!("tcp:{}:{}: buffer to send {} octets", + net_trace!("tcp:{}:{}: tx buffer: enqueueing {} octets", self.local_endpoint, self.remote_endpoint, buffer.len()); } buffer @@ -431,6 +453,7 @@ impl<'a> TcpSocket<'a> { // SYN|ACK packets in the SYN_RECEIVED state change it to ESTABLISHED. (State::SynReceived, TcpRepr { control: TcpControl::None, .. }) => { + self.remote_last_seq = self.local_seq_no + 1; self.set_state(State::Established); self.retransmit.reset() } @@ -500,24 +523,36 @@ impl<'a> TcpSocket<'a> { repr.control = TcpControl::Syn; net_trace!("tcp:{}:{}: sending SYN|ACK", self.local_endpoint, self.remote_endpoint); - self.remote_last_ack = self.remote_seq_no; } State::Established => { - if self.tx_buffer.len() > 0 && self.remote_win_len > 0 { - if !self.retransmit.check() { return Err(Error::Exhausted) } + // See if we should send data to the remote end because: + // 1. the retransmit timer has expired, or... + let mut may_send = self.retransmit.check(); + // 2. we've got new data in the transmit buffer. + let remote_next_seq = self.local_seq_no + self.tx_buffer.len() as i32; + if self.remote_last_seq != remote_next_seq { + may_send = true; + } + if self.tx_buffer.len() > 0 && self.remote_win_len > 0 && may_send { // We can send something, so let's do that. + let offset = self.remote_last_seq - self.local_seq_no; let mut size = self.remote_win_len; // Clamp to MSS. Currently we only support the default MSS value. if size > 536 { size = 536 } // Extract data from the buffer. This may return less than what we want, // in case it's not possible to extract a contiguous slice. - let data = self.tx_buffer.peek(size); - - net_trace!("tcp:{}:{}: sending {} octets", + let data = self.tx_buffer.peek(offset as usize, size); + // Send the extracted data. + net_trace!("tcp:{}:{}: tx buffer: peeking at {} octets", self.local_endpoint, self.remote_endpoint, data.len()); repr.payload = data; + // Speculatively shrink the remote window. This will get updated the next + // time we receive a packet. + self.remote_win_len -= data.len(); + // Advance the in-flight sequence number. + self.remote_last_seq += data.len() as i32; } else if self.remote_last_ack != ack_number { // We don't have anything to send, or can't because the remote end does not // have any space to accept it, but we haven't yet acknowledged everything @@ -822,12 +857,14 @@ mod test { s.remote_endpoint = REMOTE_END; s.local_seq_no = LOCAL_SEQ + 1; s.remote_seq_no = REMOTE_SEQ + 1; + s.remote_last_seq = LOCAL_SEQ + 1; + s.remote_last_ack = REMOTE_SEQ + 1; s.remote_win_len = 128; s } #[test] - fn test_established_receive() { + fn test_established_recv() { let mut s = socket_established(); send!(s, [TcpRepr { seq_number: REMOTE_SEQ + 1, @@ -847,6 +884,7 @@ mod test { #[test] fn test_established_send() { let mut s = socket_established(); + // First roundtrip after establishing. s.tx_buffer.enqueue_slice(b"abcdef"); recv!(s, [TcpRepr { seq_number: LOCAL_SEQ + 1, @@ -861,6 +899,20 @@ mod test { ..SEND_TEMPL }]); assert_eq!(s.tx_buffer.len(), 0); + // Second roundtrip. + s.tx_buffer.enqueue_slice(b"foobar"); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"foobar"[..], + ..RECV_TEMPL + }]); + send!(s, [TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 6), + ..SEND_TEMPL + }]); + assert_eq!(s.tx_buffer.len(), 0); } #[test]