Distinguish sockets by debug identifiers (socket set indexes).

This commit is contained in:
whitequark 2017-01-16 23:35:21 +00:00
parent 40716a348d
commit f126eab193
4 changed files with 125 additions and 74 deletions

View File

@ -49,7 +49,30 @@ pub enum Socket<'a, 'b: 'a> {
__Nonexhaustive
}
macro_rules! dispatch_socket {
($self_:expr, |$socket:ident [$( $mut_:tt )*]| $code:expr) => ({
match $self_ {
&$( $mut_ )* Socket::Udp(ref $( $mut_ )* $socket) => $code,
&$( $mut_ )* Socket::Tcp(ref $( $mut_ )* $socket) => $code,
&$( $mut_ )* Socket::__Nonexhaustive => unreachable!()
}
})
}
impl<'a, 'b> Socket<'a, 'b> {
/// Return the debug identifier.
pub fn debug_id(&self) -> usize {
dispatch_socket!(self, |socket []| socket.debug_id())
}
/// Set the debug identifier.
///
/// The debug identifier is a number printed in socket trace messages.
/// It could as well be used by the user code.
pub fn set_debug_id(&mut self, id: usize) {
dispatch_socket!(self, |socket [mut]| socket.set_debug_id(id))
}
/// Process a packet received from a network interface.
///
/// This function checks if the packet contained in the payload matches the socket endpoint,
@ -59,13 +82,7 @@ impl<'a, 'b> Socket<'a, 'b> {
/// This function is used internally by the networking stack.
pub fn process(&mut self, timestamp: u64, ip_repr: &IpRepr,
payload: &[u8]) -> Result<(), Error> {
match self {
&mut Socket::Udp(ref mut socket) =>
socket.process(timestamp, ip_repr, payload),
&mut Socket::Tcp(ref mut socket) =>
socket.process(timestamp, ip_repr, payload),
&mut Socket::__Nonexhaustive => unreachable!()
}
dispatch_socket!(self, |socket [mut]| socket.process(timestamp, ip_repr, payload))
}
/// Prepare a packet to be transmitted to a network interface.
@ -77,13 +94,7 @@ impl<'a, 'b> Socket<'a, 'b> {
/// This function is used internally by the networking stack.
pub fn dispatch<F, R>(&mut self, timestamp: u64, emit: &mut F) -> Result<R, Error>
where F: FnMut(&IpRepr, &IpPayload) -> Result<R, Error> {
match self {
&mut Socket::Udp(ref mut socket) =>
socket.dispatch(timestamp, emit),
&mut Socket::Tcp(ref mut socket) =>
socket.dispatch(timestamp, emit),
&mut Socket::__Nonexhaustive => unreachable!()
}
dispatch_socket!(self, |socket [mut]| socket.dispatch(timestamp, emit))
}
}

View File

@ -29,9 +29,10 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Set<'a, 'b, 'c> {
///
/// # Panics
/// This function panics if the storage is fixed-size (not a `Vec`) and is full.
pub fn add(&mut self, socket: Socket<'b, 'c>) -> Handle {
pub fn add(&mut self, mut socket: Socket<'b, 'c>) -> Handle {
for (index, slot) in self.sockets.iter_mut().enumerate() {
if slot.is_none() {
socket.set_debug_id(index);
*slot = Some(socket);
return Handle { index: index }
}
@ -42,8 +43,10 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Set<'a, 'b, 'c> {
panic!("adding a socket to a full SocketSet")
}
ManagedSlice::Owned(ref mut sockets) => {
let index = sockets.len();
socket.set_debug_id(index);
sockets.push(Some(socket));
Handle { index: sockets.len() - 1 }
Handle { index: index }
}
}
}

View File

@ -247,7 +247,8 @@ pub struct TcpSocket<'a> {
remote_win_len: usize,
retransmit: Retransmit,
rx_buffer: SocketBuffer<'a>,
tx_buffer: SocketBuffer<'a>
tx_buffer: SocketBuffer<'a>,
debug_id: usize
}
impl<'a> TcpSocket<'a> {
@ -272,10 +273,24 @@ impl<'a> TcpSocket<'a> {
remote_win_len: 0,
retransmit: Retransmit::new(),
tx_buffer: tx_buffer.into(),
rx_buffer: rx_buffer.into()
rx_buffer: rx_buffer.into(),
debug_id: 0
})
}
/// Return the debug identifier.
pub fn debug_id(&self) -> usize {
self.debug_id
}
/// Set the debug identifier.
///
/// The debug identifier is a number printed in socket trace messages.
/// It could as well be used by the user code.
pub fn set_debug_id(&mut self, id: usize) {
self.debug_id = id
}
/// Return the local endpoint.
#[inline]
pub fn local_endpoint(&self) -> IpEndpoint {
@ -436,8 +451,8 @@ impl<'a> TcpSocket<'a> {
let old_length = self.tx_buffer.len();
let buffer = self.tx_buffer.enqueue(size);
if buffer.len() > 0 {
net_trace!("tcp:{}:{}: tx buffer: enqueueing {} octets (now {})",
self.local_endpoint, self.remote_endpoint,
net_trace!("[{}]{}:{}: tx buffer: enqueueing {} octets (now {})",
self.debug_id, self.local_endpoint, self.remote_endpoint,
buffer.len(), old_length + buffer.len());
self.retransmit.reset();
}
@ -471,8 +486,8 @@ impl<'a> TcpSocket<'a> {
let buffer = self.rx_buffer.dequeue(size);
self.remote_seq_no += buffer.len();
if buffer.len() > 0 {
net_trace!("tcp:{}:{}: rx buffer: dequeueing {} octets (now {})",
self.local_endpoint, self.remote_endpoint,
net_trace!("[{}]{}:{}: rx buffer: dequeueing {} octets (now {})",
self.debug_id, self.local_endpoint, self.remote_endpoint,
buffer.len(), old_length - buffer.len());
}
Ok(buffer)
@ -501,11 +516,13 @@ impl<'a> TcpSocket<'a> {
fn set_state(&mut self, state: State) {
if self.state != state {
if self.remote_endpoint.addr.is_unspecified() {
net_trace!("tcp:{}: state={}=>{}",
self.local_endpoint, self.state, state);
net_trace!("[{}]{}: state={}=>{}",
self.debug_id, self.local_endpoint,
self.state, state);
} else {
net_trace!("tcp:{}:{}: state={}=>{}",
self.local_endpoint, self.remote_endpoint, self.state, state);
net_trace!("[{}]{}:{}: state={}=>{}",
self.debug_id, self.local_endpoint, self.remote_endpoint,
self.state, state);
}
}
self.state = state
@ -534,25 +551,25 @@ impl<'a> TcpSocket<'a> {
match (self.state, repr) {
// The initial SYN (or whatever) cannot contain an acknowledgement.
(State::Listen, TcpRepr { ack_number: Some(_), .. }) => {
net_trace!("tcp:{}:{}: ACK received by a socket in LISTEN state",
self.local_endpoint, self.remote_endpoint);
net_trace!("[{}]{}:{}: ACK received by a socket in LISTEN state",
self.debug_id, self.local_endpoint, self.remote_endpoint);
return Err(Error::Malformed)
}
(State::Listen, TcpRepr { ack_number: None, .. }) => (),
// An RST received in response to initial SYN is acceptable if it acknowledges
// the initial SYN.
(State::SynSent, TcpRepr { control: TcpControl::Rst, ack_number: None, .. }) => {
net_trace!("tcp:{}:{}: unacceptable RST (expecting RST|ACK) \
net_trace!("[{}]{}:{}: unacceptable RST (expecting RST|ACK) \
in response to initial SYN",
self.local_endpoint, self.remote_endpoint);
self.debug_id, self.local_endpoint, self.remote_endpoint);
return Err(Error::Malformed)
}
(State::SynSent, TcpRepr {
control: TcpControl::Rst, ack_number: Some(ack_number), ..
}) => {
if ack_number != self.local_seq_no {
net_trace!("tcp:{}:{}: unacceptable RST|ACK in response to initial SYN",
self.local_endpoint, self.remote_endpoint);
net_trace!("[{}]{}:{}: unacceptable RST|ACK in response to initial SYN",
self.debug_id, self.local_endpoint, self.remote_endpoint);
return Err(Error::Malformed)
}
}
@ -560,8 +577,8 @@ impl<'a> TcpSocket<'a> {
(_, TcpRepr { control: TcpControl::Rst, .. }) => (),
// Every packet after the initial SYN must be an acknowledgement.
(_, TcpRepr { ack_number: None, .. }) => {
net_trace!("tcp:{}:{}: expecting an ACK",
self.local_endpoint, self.remote_endpoint);
net_trace!("[{}]{}:{}: expecting an ACK",
self.debug_id, self.local_endpoint, self.remote_endpoint);
return Err(Error::Malformed)
}
// Every acknowledgement must be for transmitted but unacknowledged data.
@ -578,8 +595,8 @@ impl<'a> TcpSocket<'a> {
let unacknowledged = self.tx_buffer.len() + control_len;
if !(ack_number >= self.local_seq_no &&
ack_number <= (self.local_seq_no + unacknowledged)) {
net_trace!("tcp:{}:{}: unacceptable ACK ({} not in {}...{})",
self.local_endpoint, self.remote_endpoint,
net_trace!("[{}]{}:{}: unacceptable ACK ({} not in {}...{})",
self.debug_id, self.local_endpoint, self.remote_endpoint,
ack_number, self.local_seq_no, self.local_seq_no + unacknowledged);
return Err(Error::Dropped)
}
@ -595,13 +612,13 @@ impl<'a> TcpSocket<'a> {
(_, TcpRepr { seq_number, .. }) => {
let next_remote_seq = self.remote_seq_no + self.rx_buffer.len();
if seq_number > next_remote_seq {
net_trace!("tcp:{}:{}: unacceptable SEQ ({} not in {}..)",
self.local_endpoint, self.remote_endpoint,
net_trace!("[{}]{}:{}: unacceptable SEQ ({} not in {}..)",
self.debug_id, self.local_endpoint, self.remote_endpoint,
seq_number, next_remote_seq);
return Err(Error::Dropped)
} else if seq_number != next_remote_seq {
net_trace!("tcp:{}:{}: duplicate SEQ ({} in ..{})",
self.local_endpoint, self.remote_endpoint,
net_trace!("[{}]{}:{}: duplicate SEQ ({} in ..{})",
self.debug_id, self.local_endpoint, self.remote_endpoint,
seq_number, next_remote_seq);
// If we've seen this sequence number already but the remote end is not aware
// of that, make sure we send the acknowledgement again.
@ -620,8 +637,8 @@ impl<'a> TcpSocket<'a> {
// RSTs in SYN-RECEIVED flip the socket back to the LISTEN state.
(State::SynReceived, TcpRepr { control: TcpControl::Rst, .. }) => {
net_trace!("tcp:{}:{}: received RST",
self.local_endpoint, self.remote_endpoint);
net_trace!("[{}]{}:{}: received RST",
self.debug_id, self.local_endpoint, self.remote_endpoint);
self.local_endpoint.addr = self.listen_address;
self.remote_endpoint = IpEndpoint::default();
self.set_state(State::Listen);
@ -630,8 +647,8 @@ impl<'a> TcpSocket<'a> {
// RSTs in any other state close the socket.
(_, TcpRepr { control: TcpControl::Rst, .. }) => {
net_trace!("tcp:{}:{}: received RST",
self.local_endpoint, self.remote_endpoint);
net_trace!("[{}]{}:{}: received RST",
self.debug_id, self.local_endpoint, self.remote_endpoint);
self.set_state(State::Closed);
self.local_endpoint = IpEndpoint::default();
self.remote_endpoint = IpEndpoint::default();
@ -642,8 +659,8 @@ impl<'a> TcpSocket<'a> {
(State::Listen, TcpRepr {
src_port, dst_port, control: TcpControl::Syn, seq_number, ack_number: None, ..
}) => {
net_trace!("tcp:{}: received SYN",
self.local_endpoint);
net_trace!("[{}]{}: received SYN",
self.debug_id, self.local_endpoint);
self.local_endpoint = IpEndpoint::new(ip_repr.dst_addr(), dst_port);
self.remote_endpoint = IpEndpoint::new(ip_repr.src_addr(), src_port);
// FIXME: use something more secure here
@ -710,8 +727,8 @@ impl<'a> TcpSocket<'a> {
}
_ => {
net_trace!("tcp:{}:{}: unexpected packet {}",
self.local_endpoint, self.remote_endpoint, repr);
net_trace!("[{}]{}:{}: unexpected packet {}",
self.debug_id, self.local_endpoint, self.remote_endpoint, repr);
return Err(Error::Malformed)
}
}
@ -720,8 +737,8 @@ impl<'a> TcpSocket<'a> {
if let Some(ack_number) = repr.ack_number {
let ack_length = ack_number - self.local_seq_no;
if ack_length > 0 {
net_trace!("tcp:{}:{}: tx buffer: dequeueing {} octets (now {})",
self.local_endpoint, self.remote_endpoint,
net_trace!("[{}]{}:{}: tx buffer: dequeueing {} octets (now {})",
self.debug_id, self.local_endpoint, self.remote_endpoint,
ack_length, self.tx_buffer.len() - ack_length);
}
self.tx_buffer.advance(ack_length);
@ -730,8 +747,8 @@ impl<'a> TcpSocket<'a> {
// Enqueue payload octets, which is guaranteed to be in order, unless we already did.
if repr.payload.len() > 0 {
net_trace!("tcp:{}:{}: rx buffer: enqueueing {} octets (now {})",
self.local_endpoint, self.remote_endpoint,
net_trace!("[{}]{}:{}: rx buffer: enqueueing {} octets (now {})",
self.debug_id, self.local_endpoint, self.remote_endpoint,
repr.payload.len(), self.rx_buffer.len() + repr.payload.len());
self.rx_buffer.enqueue_slice(repr.payload)
}
@ -782,8 +799,8 @@ impl<'a> TcpSocket<'a> {
// We transmit a SYN|ACK in the SYN-RECEIVED state.
State::SynReceived => {
repr.control = TcpControl::Syn;
net_trace!("tcp:{}:{}: sending SYN|ACK",
self.local_endpoint, self.remote_endpoint);
net_trace!("[{}]{}:{}: sending SYN|ACK",
self.debug_id, self.local_endpoint, self.remote_endpoint);
should_send = true;
}
@ -791,8 +808,8 @@ impl<'a> TcpSocket<'a> {
State::SynSent => {
repr.control = TcpControl::Syn;
repr.ack_number = None;
net_trace!("tcp:{}:{}: sending SYN",
self.local_endpoint, self.remote_endpoint);
net_trace!("[{}]{}:{}: sending SYN",
self.debug_id, self.local_endpoint, self.remote_endpoint);
should_send = true;
}
@ -816,8 +833,8 @@ impl<'a> TcpSocket<'a> {
let data = self.tx_buffer.peek(offset, size);
if data.len() > 0 {
// Send the extracted data.
net_trace!("tcp:{}:{}: tx buffer: peeking at {} octets (from {})",
self.local_endpoint, self.remote_endpoint,
net_trace!("[{}]{}:{}: tx buffer: peeking at {} octets (from {})",
self.debug_id, self.local_endpoint, self.remote_endpoint,
data.len(), offset);
repr.seq_number += offset;
repr.payload = data;
@ -832,8 +849,8 @@ impl<'a> TcpSocket<'a> {
State::FinWait1 | State::LastAck => {
// We should notify the other side that we've closed the transmit half
// of the connection.
net_trace!("tcp:{}:{}: sending FIN|ACK",
self.local_endpoint, self.remote_endpoint);
net_trace!("[{}]{}:{}: sending FIN|ACK",
self.debug_id, self.local_endpoint, self.remote_endpoint);
repr.control = TcpControl::Fin;
should_send = true;
}
@ -850,15 +867,16 @@ impl<'a> TcpSocket<'a> {
let ack_number = self.remote_seq_no + self.rx_buffer.len();
if !should_send && self.remote_last_ack != ack_number {
// Acknowledge all data we have received, since it is all in order.
net_trace!("tcp:{}:{}: sending ACK",
self.local_endpoint, self.remote_endpoint);
net_trace!("[{}]{}:{}: sending ACK",
self.debug_id, self.local_endpoint, self.remote_endpoint);
should_send = true;
}
if should_send {
if self.retransmit.commit(timestamp) {
net_trace!("tcp:{}:{}: retransmit after {}ms",
self.local_endpoint, self.remote_endpoint, self.retransmit.delay);
net_trace!("[{}]{}:{}: retransmit after {}ms",
self.debug_id, self.local_endpoint, self.remote_endpoint,
self.retransmit.delay);
}
repr.ack_number = Some(ack_number);

View File

@ -111,7 +111,8 @@ impl<'a, 'b> SocketBuffer<'a, 'b> {
pub struct UdpSocket<'a, 'b: 'a> {
endpoint: IpEndpoint,
rx_buffer: SocketBuffer<'a, 'b>,
tx_buffer: SocketBuffer<'a, 'b>
tx_buffer: SocketBuffer<'a, 'b>,
debug_id: usize
}
impl<'a, 'b> UdpSocket<'a, 'b> {
@ -121,10 +122,24 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
Socket::Udp(UdpSocket {
endpoint: IpEndpoint::default(),
rx_buffer: rx_buffer,
tx_buffer: tx_buffer
tx_buffer: tx_buffer,
debug_id: 0
})
}
/// Return the debug identifier.
pub fn debug_id(&self) -> usize {
self.debug_id
}
/// Set the debug identifier.
///
/// The debug identifier is a number printed in socket trace messages.
/// It could as well be used by the user code.
pub fn set_debug_id(&mut self, id: usize) {
self.debug_id = id
}
/// Return the bound endpoint.
#[inline]
pub fn endpoint(&self) -> IpEndpoint {
@ -155,8 +170,9 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
let packet_buf = try!(self.tx_buffer.enqueue());
packet_buf.endpoint = endpoint;
packet_buf.size = size;
net_trace!("udp:{}:{}: buffer to send {} octets",
self.endpoint, packet_buf.endpoint, packet_buf.size);
net_trace!("[{}]{}:{}: buffer to send {} octets",
self.debug_id, self.endpoint,
packet_buf.endpoint, packet_buf.size);
Ok(&mut packet_buf.as_mut()[..size])
}
@ -176,8 +192,9 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
/// This function returns `Err(())` if the receive buffer is empty.
pub fn recv(&mut self) -> Result<(&[u8], IpEndpoint), ()> {
let packet_buf = try!(self.rx_buffer.dequeue());
net_trace!("udp:{}:{}: receive {} buffered octets",
self.endpoint, packet_buf.endpoint, packet_buf.size);
net_trace!("[{}]{}:{}: receive {} buffered octets",
self.debug_id, self.endpoint,
packet_buf.endpoint, packet_buf.size);
Ok((&packet_buf.as_ref()[..packet_buf.size], packet_buf.endpoint))
}
@ -208,8 +225,9 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
packet_buf.endpoint = IpEndpoint { addr: ip_repr.src_addr(), port: repr.src_port };
packet_buf.size = repr.payload.len();
packet_buf.as_mut()[..repr.payload.len()].copy_from_slice(repr.payload);
net_trace!("udp:{}:{}: receiving {} octets",
self.endpoint, packet_buf.endpoint, packet_buf.size);
net_trace!("[{}]{}:{}: receiving {} octets",
self.debug_id, self.endpoint,
packet_buf.endpoint, packet_buf.size);
Ok(())
}
@ -217,8 +235,9 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
pub fn dispatch<F, R>(&mut self, _timestamp: u64, emit: &mut F) -> Result<R, Error>
where F: FnMut(&IpRepr, &IpPayload) -> Result<R, Error> {
let packet_buf = try!(self.tx_buffer.dequeue().map_err(|()| Error::Exhausted));
net_trace!("udp:{}:{}: sending {} octets",
self.endpoint, packet_buf.endpoint, packet_buf.size);
net_trace!("[{}]{}:{}: sending {} octets",
self.debug_id, self.endpoint,
packet_buf.endpoint, packet_buf.size);
let repr = UdpRepr {
src_port: self.endpoint.port,
dst_port: packet_buf.endpoint.port,