diff --git a/src/socket/mod.rs b/src/socket/mod.rs index 0ca3c27..be307c9 100644 --- a/src/socket/mod.rs +++ b/src/socket/mod.rs @@ -49,7 +49,30 @@ pub enum Socket<'a, 'b: 'a> { __Nonexhaustive } +macro_rules! dispatch_socket { + ($self_:expr, |$socket:ident [$( $mut_:tt )*]| $code:expr) => ({ + match $self_ { + &$( $mut_ )* Socket::Udp(ref $( $mut_ )* $socket) => $code, + &$( $mut_ )* Socket::Tcp(ref $( $mut_ )* $socket) => $code, + &$( $mut_ )* Socket::__Nonexhaustive => unreachable!() + } + }) +} + impl<'a, 'b> Socket<'a, 'b> { + /// Return the debug identifier. + pub fn debug_id(&self) -> usize { + dispatch_socket!(self, |socket []| socket.debug_id()) + } + + /// Set the debug identifier. + /// + /// The debug identifier is a number printed in socket trace messages. + /// It could as well be used by the user code. + pub fn set_debug_id(&mut self, id: usize) { + dispatch_socket!(self, |socket [mut]| socket.set_debug_id(id)) + } + /// Process a packet received from a network interface. /// /// This function checks if the packet contained in the payload matches the socket endpoint, @@ -59,13 +82,7 @@ impl<'a, 'b> Socket<'a, 'b> { /// This function is used internally by the networking stack. pub fn process(&mut self, timestamp: u64, ip_repr: &IpRepr, payload: &[u8]) -> Result<(), Error> { - match self { - &mut Socket::Udp(ref mut socket) => - socket.process(timestamp, ip_repr, payload), - &mut Socket::Tcp(ref mut socket) => - socket.process(timestamp, ip_repr, payload), - &mut Socket::__Nonexhaustive => unreachable!() - } + dispatch_socket!(self, |socket [mut]| socket.process(timestamp, ip_repr, payload)) } /// Prepare a packet to be transmitted to a network interface. @@ -77,13 +94,7 @@ impl<'a, 'b> Socket<'a, 'b> { /// This function is used internally by the networking stack. pub fn dispatch(&mut self, timestamp: u64, emit: &mut F) -> Result where F: FnMut(&IpRepr, &IpPayload) -> Result { - match self { - &mut Socket::Udp(ref mut socket) => - socket.dispatch(timestamp, emit), - &mut Socket::Tcp(ref mut socket) => - socket.dispatch(timestamp, emit), - &mut Socket::__Nonexhaustive => unreachable!() - } + dispatch_socket!(self, |socket [mut]| socket.dispatch(timestamp, emit)) } } diff --git a/src/socket/set.rs b/src/socket/set.rs index 4c8c43c..1cb3151 100644 --- a/src/socket/set.rs +++ b/src/socket/set.rs @@ -29,9 +29,10 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Set<'a, 'b, 'c> { /// /// # Panics /// This function panics if the storage is fixed-size (not a `Vec`) and is full. - pub fn add(&mut self, socket: Socket<'b, 'c>) -> Handle { + pub fn add(&mut self, mut socket: Socket<'b, 'c>) -> Handle { for (index, slot) in self.sockets.iter_mut().enumerate() { if slot.is_none() { + socket.set_debug_id(index); *slot = Some(socket); return Handle { index: index } } @@ -42,8 +43,10 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Set<'a, 'b, 'c> { panic!("adding a socket to a full SocketSet") } ManagedSlice::Owned(ref mut sockets) => { + let index = sockets.len(); + socket.set_debug_id(index); sockets.push(Some(socket)); - Handle { index: sockets.len() - 1 } + Handle { index: index } } } } diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index 29803dd..d06ab64 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -247,7 +247,8 @@ pub struct TcpSocket<'a> { remote_win_len: usize, retransmit: Retransmit, rx_buffer: SocketBuffer<'a>, - tx_buffer: SocketBuffer<'a> + tx_buffer: SocketBuffer<'a>, + debug_id: usize } impl<'a> TcpSocket<'a> { @@ -272,10 +273,24 @@ impl<'a> TcpSocket<'a> { remote_win_len: 0, retransmit: Retransmit::new(), tx_buffer: tx_buffer.into(), - rx_buffer: rx_buffer.into() + rx_buffer: rx_buffer.into(), + debug_id: 0 }) } + /// Return the debug identifier. + pub fn debug_id(&self) -> usize { + self.debug_id + } + + /// Set the debug identifier. + /// + /// The debug identifier is a number printed in socket trace messages. + /// It could as well be used by the user code. + pub fn set_debug_id(&mut self, id: usize) { + self.debug_id = id + } + /// Return the local endpoint. #[inline] pub fn local_endpoint(&self) -> IpEndpoint { @@ -436,8 +451,8 @@ impl<'a> TcpSocket<'a> { let old_length = self.tx_buffer.len(); let buffer = self.tx_buffer.enqueue(size); if buffer.len() > 0 { - net_trace!("tcp:{}:{}: tx buffer: enqueueing {} octets (now {})", - self.local_endpoint, self.remote_endpoint, + net_trace!("[{}]{}:{}: tx buffer: enqueueing {} octets (now {})", + self.debug_id, self.local_endpoint, self.remote_endpoint, buffer.len(), old_length + buffer.len()); self.retransmit.reset(); } @@ -471,8 +486,8 @@ impl<'a> TcpSocket<'a> { let buffer = self.rx_buffer.dequeue(size); self.remote_seq_no += buffer.len(); if buffer.len() > 0 { - net_trace!("tcp:{}:{}: rx buffer: dequeueing {} octets (now {})", - self.local_endpoint, self.remote_endpoint, + net_trace!("[{}]{}:{}: rx buffer: dequeueing {} octets (now {})", + self.debug_id, self.local_endpoint, self.remote_endpoint, buffer.len(), old_length - buffer.len()); } Ok(buffer) @@ -501,11 +516,13 @@ impl<'a> TcpSocket<'a> { fn set_state(&mut self, state: State) { if self.state != state { if self.remote_endpoint.addr.is_unspecified() { - net_trace!("tcp:{}: state={}=>{}", - self.local_endpoint, self.state, state); + net_trace!("[{}]{}: state={}=>{}", + self.debug_id, self.local_endpoint, + self.state, state); } else { - net_trace!("tcp:{}:{}: state={}=>{}", - self.local_endpoint, self.remote_endpoint, self.state, state); + net_trace!("[{}]{}:{}: state={}=>{}", + self.debug_id, self.local_endpoint, self.remote_endpoint, + self.state, state); } } self.state = state @@ -534,25 +551,25 @@ impl<'a> TcpSocket<'a> { match (self.state, repr) { // The initial SYN (or whatever) cannot contain an acknowledgement. (State::Listen, TcpRepr { ack_number: Some(_), .. }) => { - net_trace!("tcp:{}:{}: ACK received by a socket in LISTEN state", - self.local_endpoint, self.remote_endpoint); + net_trace!("[{}]{}:{}: ACK received by a socket in LISTEN state", + self.debug_id, self.local_endpoint, self.remote_endpoint); return Err(Error::Malformed) } (State::Listen, TcpRepr { ack_number: None, .. }) => (), // An RST received in response to initial SYN is acceptable if it acknowledges // the initial SYN. (State::SynSent, TcpRepr { control: TcpControl::Rst, ack_number: None, .. }) => { - net_trace!("tcp:{}:{}: unacceptable RST (expecting RST|ACK) \ + net_trace!("[{}]{}:{}: unacceptable RST (expecting RST|ACK) \ in response to initial SYN", - self.local_endpoint, self.remote_endpoint); + self.debug_id, self.local_endpoint, self.remote_endpoint); return Err(Error::Malformed) } (State::SynSent, TcpRepr { control: TcpControl::Rst, ack_number: Some(ack_number), .. }) => { if ack_number != self.local_seq_no { - net_trace!("tcp:{}:{}: unacceptable RST|ACK in response to initial SYN", - self.local_endpoint, self.remote_endpoint); + net_trace!("[{}]{}:{}: unacceptable RST|ACK in response to initial SYN", + self.debug_id, self.local_endpoint, self.remote_endpoint); return Err(Error::Malformed) } } @@ -560,8 +577,8 @@ impl<'a> TcpSocket<'a> { (_, TcpRepr { control: TcpControl::Rst, .. }) => (), // Every packet after the initial SYN must be an acknowledgement. (_, TcpRepr { ack_number: None, .. }) => { - net_trace!("tcp:{}:{}: expecting an ACK", - self.local_endpoint, self.remote_endpoint); + net_trace!("[{}]{}:{}: expecting an ACK", + self.debug_id, self.local_endpoint, self.remote_endpoint); return Err(Error::Malformed) } // Every acknowledgement must be for transmitted but unacknowledged data. @@ -578,8 +595,8 @@ impl<'a> TcpSocket<'a> { let unacknowledged = self.tx_buffer.len() + control_len; if !(ack_number >= self.local_seq_no && ack_number <= (self.local_seq_no + unacknowledged)) { - net_trace!("tcp:{}:{}: unacceptable ACK ({} not in {}...{})", - self.local_endpoint, self.remote_endpoint, + net_trace!("[{}]{}:{}: unacceptable ACK ({} not in {}...{})", + self.debug_id, self.local_endpoint, self.remote_endpoint, ack_number, self.local_seq_no, self.local_seq_no + unacknowledged); return Err(Error::Dropped) } @@ -595,13 +612,13 @@ impl<'a> TcpSocket<'a> { (_, TcpRepr { seq_number, .. }) => { let next_remote_seq = self.remote_seq_no + self.rx_buffer.len(); if seq_number > next_remote_seq { - net_trace!("tcp:{}:{}: unacceptable SEQ ({} not in {}..)", - self.local_endpoint, self.remote_endpoint, + net_trace!("[{}]{}:{}: unacceptable SEQ ({} not in {}..)", + self.debug_id, self.local_endpoint, self.remote_endpoint, seq_number, next_remote_seq); return Err(Error::Dropped) } else if seq_number != next_remote_seq { - net_trace!("tcp:{}:{}: duplicate SEQ ({} in ..{})", - self.local_endpoint, self.remote_endpoint, + net_trace!("[{}]{}:{}: duplicate SEQ ({} in ..{})", + self.debug_id, self.local_endpoint, self.remote_endpoint, seq_number, next_remote_seq); // If we've seen this sequence number already but the remote end is not aware // of that, make sure we send the acknowledgement again. @@ -620,8 +637,8 @@ impl<'a> TcpSocket<'a> { // RSTs in SYN-RECEIVED flip the socket back to the LISTEN state. (State::SynReceived, TcpRepr { control: TcpControl::Rst, .. }) => { - net_trace!("tcp:{}:{}: received RST", - self.local_endpoint, self.remote_endpoint); + net_trace!("[{}]{}:{}: received RST", + self.debug_id, self.local_endpoint, self.remote_endpoint); self.local_endpoint.addr = self.listen_address; self.remote_endpoint = IpEndpoint::default(); self.set_state(State::Listen); @@ -630,8 +647,8 @@ impl<'a> TcpSocket<'a> { // RSTs in any other state close the socket. (_, TcpRepr { control: TcpControl::Rst, .. }) => { - net_trace!("tcp:{}:{}: received RST", - self.local_endpoint, self.remote_endpoint); + net_trace!("[{}]{}:{}: received RST", + self.debug_id, self.local_endpoint, self.remote_endpoint); self.set_state(State::Closed); self.local_endpoint = IpEndpoint::default(); self.remote_endpoint = IpEndpoint::default(); @@ -642,8 +659,8 @@ impl<'a> TcpSocket<'a> { (State::Listen, TcpRepr { src_port, dst_port, control: TcpControl::Syn, seq_number, ack_number: None, .. }) => { - net_trace!("tcp:{}: received SYN", - self.local_endpoint); + net_trace!("[{}]{}: received SYN", + self.debug_id, self.local_endpoint); self.local_endpoint = IpEndpoint::new(ip_repr.dst_addr(), dst_port); self.remote_endpoint = IpEndpoint::new(ip_repr.src_addr(), src_port); // FIXME: use something more secure here @@ -710,8 +727,8 @@ impl<'a> TcpSocket<'a> { } _ => { - net_trace!("tcp:{}:{}: unexpected packet {}", - self.local_endpoint, self.remote_endpoint, repr); + net_trace!("[{}]{}:{}: unexpected packet {}", + self.debug_id, self.local_endpoint, self.remote_endpoint, repr); return Err(Error::Malformed) } } @@ -720,8 +737,8 @@ impl<'a> TcpSocket<'a> { if let Some(ack_number) = repr.ack_number { let ack_length = ack_number - self.local_seq_no; if ack_length > 0 { - net_trace!("tcp:{}:{}: tx buffer: dequeueing {} octets (now {})", - self.local_endpoint, self.remote_endpoint, + net_trace!("[{}]{}:{}: tx buffer: dequeueing {} octets (now {})", + self.debug_id, self.local_endpoint, self.remote_endpoint, ack_length, self.tx_buffer.len() - ack_length); } self.tx_buffer.advance(ack_length); @@ -730,8 +747,8 @@ impl<'a> TcpSocket<'a> { // Enqueue payload octets, which is guaranteed to be in order, unless we already did. if repr.payload.len() > 0 { - net_trace!("tcp:{}:{}: rx buffer: enqueueing {} octets (now {})", - self.local_endpoint, self.remote_endpoint, + net_trace!("[{}]{}:{}: rx buffer: enqueueing {} octets (now {})", + self.debug_id, self.local_endpoint, self.remote_endpoint, repr.payload.len(), self.rx_buffer.len() + repr.payload.len()); self.rx_buffer.enqueue_slice(repr.payload) } @@ -782,8 +799,8 @@ impl<'a> TcpSocket<'a> { // We transmit a SYN|ACK in the SYN-RECEIVED state. State::SynReceived => { repr.control = TcpControl::Syn; - net_trace!("tcp:{}:{}: sending SYN|ACK", - self.local_endpoint, self.remote_endpoint); + net_trace!("[{}]{}:{}: sending SYN|ACK", + self.debug_id, self.local_endpoint, self.remote_endpoint); should_send = true; } @@ -791,8 +808,8 @@ impl<'a> TcpSocket<'a> { State::SynSent => { repr.control = TcpControl::Syn; repr.ack_number = None; - net_trace!("tcp:{}:{}: sending SYN", - self.local_endpoint, self.remote_endpoint); + net_trace!("[{}]{}:{}: sending SYN", + self.debug_id, self.local_endpoint, self.remote_endpoint); should_send = true; } @@ -816,8 +833,8 @@ impl<'a> TcpSocket<'a> { let data = self.tx_buffer.peek(offset, size); if data.len() > 0 { // Send the extracted data. - net_trace!("tcp:{}:{}: tx buffer: peeking at {} octets (from {})", - self.local_endpoint, self.remote_endpoint, + net_trace!("[{}]{}:{}: tx buffer: peeking at {} octets (from {})", + self.debug_id, self.local_endpoint, self.remote_endpoint, data.len(), offset); repr.seq_number += offset; repr.payload = data; @@ -832,8 +849,8 @@ impl<'a> TcpSocket<'a> { State::FinWait1 | State::LastAck => { // We should notify the other side that we've closed the transmit half // of the connection. - net_trace!("tcp:{}:{}: sending FIN|ACK", - self.local_endpoint, self.remote_endpoint); + net_trace!("[{}]{}:{}: sending FIN|ACK", + self.debug_id, self.local_endpoint, self.remote_endpoint); repr.control = TcpControl::Fin; should_send = true; } @@ -850,15 +867,16 @@ impl<'a> TcpSocket<'a> { let ack_number = self.remote_seq_no + self.rx_buffer.len(); if !should_send && self.remote_last_ack != ack_number { // Acknowledge all data we have received, since it is all in order. - net_trace!("tcp:{}:{}: sending ACK", - self.local_endpoint, self.remote_endpoint); + net_trace!("[{}]{}:{}: sending ACK", + self.debug_id, self.local_endpoint, self.remote_endpoint); should_send = true; } if should_send { if self.retransmit.commit(timestamp) { - net_trace!("tcp:{}:{}: retransmit after {}ms", - self.local_endpoint, self.remote_endpoint, self.retransmit.delay); + net_trace!("[{}]{}:{}: retransmit after {}ms", + self.debug_id, self.local_endpoint, self.remote_endpoint, + self.retransmit.delay); } repr.ack_number = Some(ack_number); diff --git a/src/socket/udp.rs b/src/socket/udp.rs index ef39158..beb8ddf 100644 --- a/src/socket/udp.rs +++ b/src/socket/udp.rs @@ -111,7 +111,8 @@ impl<'a, 'b> SocketBuffer<'a, 'b> { pub struct UdpSocket<'a, 'b: 'a> { endpoint: IpEndpoint, rx_buffer: SocketBuffer<'a, 'b>, - tx_buffer: SocketBuffer<'a, 'b> + tx_buffer: SocketBuffer<'a, 'b>, + debug_id: usize } impl<'a, 'b> UdpSocket<'a, 'b> { @@ -121,10 +122,24 @@ impl<'a, 'b> UdpSocket<'a, 'b> { Socket::Udp(UdpSocket { endpoint: IpEndpoint::default(), rx_buffer: rx_buffer, - tx_buffer: tx_buffer + tx_buffer: tx_buffer, + debug_id: 0 }) } + /// Return the debug identifier. + pub fn debug_id(&self) -> usize { + self.debug_id + } + + /// Set the debug identifier. + /// + /// The debug identifier is a number printed in socket trace messages. + /// It could as well be used by the user code. + pub fn set_debug_id(&mut self, id: usize) { + self.debug_id = id + } + /// Return the bound endpoint. #[inline] pub fn endpoint(&self) -> IpEndpoint { @@ -155,8 +170,9 @@ impl<'a, 'b> UdpSocket<'a, 'b> { let packet_buf = try!(self.tx_buffer.enqueue()); packet_buf.endpoint = endpoint; packet_buf.size = size; - net_trace!("udp:{}:{}: buffer to send {} octets", - self.endpoint, packet_buf.endpoint, packet_buf.size); + net_trace!("[{}]{}:{}: buffer to send {} octets", + self.debug_id, self.endpoint, + packet_buf.endpoint, packet_buf.size); Ok(&mut packet_buf.as_mut()[..size]) } @@ -176,8 +192,9 @@ impl<'a, 'b> UdpSocket<'a, 'b> { /// This function returns `Err(())` if the receive buffer is empty. pub fn recv(&mut self) -> Result<(&[u8], IpEndpoint), ()> { let packet_buf = try!(self.rx_buffer.dequeue()); - net_trace!("udp:{}:{}: receive {} buffered octets", - self.endpoint, packet_buf.endpoint, packet_buf.size); + net_trace!("[{}]{}:{}: receive {} buffered octets", + self.debug_id, self.endpoint, + packet_buf.endpoint, packet_buf.size); Ok((&packet_buf.as_ref()[..packet_buf.size], packet_buf.endpoint)) } @@ -208,8 +225,9 @@ impl<'a, 'b> UdpSocket<'a, 'b> { packet_buf.endpoint = IpEndpoint { addr: ip_repr.src_addr(), port: repr.src_port }; packet_buf.size = repr.payload.len(); packet_buf.as_mut()[..repr.payload.len()].copy_from_slice(repr.payload); - net_trace!("udp:{}:{}: receiving {} octets", - self.endpoint, packet_buf.endpoint, packet_buf.size); + net_trace!("[{}]{}:{}: receiving {} octets", + self.debug_id, self.endpoint, + packet_buf.endpoint, packet_buf.size); Ok(()) } @@ -217,8 +235,9 @@ impl<'a, 'b> UdpSocket<'a, 'b> { pub fn dispatch(&mut self, _timestamp: u64, emit: &mut F) -> Result where F: FnMut(&IpRepr, &IpPayload) -> Result { let packet_buf = try!(self.tx_buffer.dequeue().map_err(|()| Error::Exhausted)); - net_trace!("udp:{}:{}: sending {} octets", - self.endpoint, packet_buf.endpoint, packet_buf.size); + net_trace!("[{}]{}:{}: sending {} octets", + self.debug_id, self.endpoint, + packet_buf.endpoint, packet_buf.size); let repr = UdpRepr { src_port: self.endpoint.port, dst_port: packet_buf.endpoint.port,