Rework TcpSocket::{send,recv} to remove need for precomputing size.

Now, these functions give you the largest contiguous slice they can
grab, and you return however much you took from it.
v0.7.x
whitequark 2017-10-31 19:24:54 +00:00
parent 0b22943ddd
commit fe6b04a29a
5 changed files with 121 additions and 78 deletions

View File

@ -66,8 +66,8 @@ fn main() {
tcp_active = socket.is_active();
if socket.may_recv() {
let data = {
let mut data = socket.recv(128).unwrap().to_owned();
let data = socket.recv(|data| {
let mut data = data.to_owned();
if data.len() > 0 {
debug!("recv data: {:?}",
str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)"));
@ -75,8 +75,8 @@ fn main() {
data.reverse();
data.extend(b"\n");
}
data
};
(data.len(), data)
}).unwrap();
if socket.can_send() && data.len() > 0 {
debug!("send data: {:?}",
str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)"));

View File

@ -133,7 +133,9 @@ fn main() {
}
if socket.can_recv() {
debug!("got {:?}", str::from_utf8(socket.recv(32).unwrap()).unwrap());
debug!("got {:?}", socket.recv(|buffer| {
(buffer.len(), str::from_utf8(buffer).unwrap())
}));
socket.close();
done = true;
}

View File

@ -121,8 +121,8 @@ fn main() {
tcp_6970_active = socket.is_active();
if socket.may_recv() {
let data = {
let mut data = socket.recv(128).unwrap().to_owned();
let data = socket.recv(|buffer| {
let mut data = buffer.to_owned();
if data.len() > 0 {
debug!("tcp:6970 recv data: {:?}",
str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)"));
@ -130,8 +130,8 @@ fn main() {
data.reverse();
data.extend(b"\n");
}
data
};
(data.len(), data)
}).unwrap();
if socket.can_send() && data.len() > 0 {
debug!("tcp:6970 send data: {:?}",
str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)"));
@ -153,11 +153,12 @@ fn main() {
}
if socket.may_recv() {
if let Ok(data) = socket.recv(65535) {
if data.len() > 0 {
debug!("tcp:6971 recv {:?} octets", data.len());
socket.recv(|buffer| {
if buffer.len() > 0 {
debug!("tcp:6971 recv {:?} octets", buffer.len());
}
}
(buffer.len(), ())
}).unwrap();
} else if socket.may_send() {
socket.close();
}
@ -171,14 +172,15 @@ fn main() {
}
if socket.may_send() {
if let Ok(data) = socket.send(65535) {
socket.send(|data| {
if data.len() > 0 {
debug!("tcp:6972 send {:?} octets", data.len());
for (i, b) in data.iter_mut().enumerate() {
*b = (i % 256) as u8;
}
}
}
(data.len(), ())
}).unwrap();
}
}

View File

@ -593,15 +593,8 @@ impl<'a> TcpSocket<'a> {
!self.rx_buffer.is_empty()
}
/// Enqueue a sequence of octets to be sent, and return a pointer to it.
///
/// This function may return a slice smaller than the requested size in case
/// there is not enough contiguous free space in the transmit buffer, down to
/// an empty slice.
///
/// This function returns `Err(Error::Illegal) if the transmit half of
/// the connection is not open; see [may_send](#method.may_send).
pub fn send(&mut self, size: usize) -> Result<&mut [u8]> {
fn send_impl<'b, F, R>(&'b mut self, f: F) -> Result<R>
where F: FnOnce(&'b mut SocketBuffer<'a>) -> (usize, R) {
if !self.may_send() { return Err(Error::Illegal) }
// The connection might have been idle for a long time, and so remote_last_ts
@ -610,14 +603,26 @@ impl<'a> TcpSocket<'a> {
if self.tx_buffer.is_empty() { self.remote_last_ts = None }
let _old_length = self.tx_buffer.len();
let buffer = self.tx_buffer.enqueue_many(size);
if buffer.len() > 0 {
let (size, result) = f(&mut self.tx_buffer);
if size > 0 {
#[cfg(any(test, feature = "verbose"))]
net_trace!("{}:{}:{}: tx buffer: enqueueing {} octets (now {})",
self.handle, self.local_endpoint, self.remote_endpoint,
buffer.len(), _old_length + buffer.len());
size, _old_length + size);
}
Ok(buffer)
Ok(result)
}
/// Call `f` with the largest contiguous slice of octets in the transmit buffer,
/// and enqueue the amount of elements returned by `f`.
///
/// This function returns `Err(Error::Illegal) if the transmit half of
/// the connection is not open; see [may_send](#method.may_send).
pub fn send<'b, F, R>(&'b mut self, f: F) -> Result<R>
where F: FnOnce(&'b mut [u8]) -> (usize, R) {
self.send_impl(|tx_buffer| {
tx_buffer.enqueue_many_with(f)
})
}
/// Enqueue a sequence of octets to be sent, and fill it from a slice.
@ -627,46 +632,42 @@ impl<'a> TcpSocket<'a> {
///
/// See also [send](#method.send).
pub fn send_slice(&mut self, data: &[u8]) -> Result<usize> {
if !self.may_send() { return Err(Error::Illegal) }
// See above.
if self.tx_buffer.is_empty() { self.remote_last_ts = None }
let _old_length = self.tx_buffer.len();
let enqueued = self.tx_buffer.enqueue_slice(data);
if enqueued != 0 {
#[cfg(any(test, feature = "verbose"))]
net_trace!("{}:{}:{}: tx buffer: enqueueing {} octets (now {})",
self.handle, self.local_endpoint, self.remote_endpoint,
enqueued, _old_length + enqueued);
}
Ok(enqueued)
self.send_impl(|tx_buffer| {
let size = tx_buffer.enqueue_slice(data);
(size, size)
})
}
/// Dequeue a sequence of received octets, and return a pointer to it.
///
/// This function may return a slice smaller than the requested size in case
/// there are not enough octets queued in the receive buffer, down to
/// an empty slice.
///
/// This function returns `Err(Error::Illegal) if the receive half of
/// the connection is not open; see [may_recv](#method.may_recv).
pub fn recv(&mut self, size: usize) -> Result<&[u8]> {
pub fn recv_impl<'b, F, R>(&'b mut self, f: F) -> Result<R>
where F: FnOnce(&'b mut SocketBuffer<'a>) -> (usize, R) {
// We may have received some data inside the initial SYN, but until the connection
// is fully open we must not dequeue any data, as it may be overwritten by e.g.
// another (stale) SYN.
// another (stale) SYN. (We do not support TCP Fast Open.)
if !self.may_recv() { return Err(Error::Illegal) }
let _old_length = self.rx_buffer.len();
let buffer = self.rx_buffer.dequeue_many(size);
self.remote_seq_no += buffer.len();
if buffer.len() > 0 {
let (size, result) = f(&mut self.rx_buffer);
self.remote_seq_no += size;
if size > 0 {
#[cfg(any(test, feature = "verbose"))]
net_trace!("{}:{}:{}: rx buffer: dequeueing {} octets (now {})",
self.handle, self.local_endpoint, self.remote_endpoint,
buffer.len(), _old_length - buffer.len());
size, _old_length - size);
}
Ok(buffer)
Ok(result)
}
/// Call `f` with the largest contiguous slice of octets in the receive buffer,
/// and dequeue the amount of elements returned by `f`.
///
/// This function returns `Err(Error::Illegal) if the receive half of
/// the connection is not open; see [may_recv](#method.may_recv).
pub fn recv<'b, F, R>(&'b mut self, f: F) -> Result<R>
where F: FnOnce(&'b mut [u8]) -> (usize, R) {
self.recv_impl(|rx_buffer| {
rx_buffer.dequeue_many_with(f)
})
}
/// Dequeue a sequence of received octets, and fill a slice from it.
@ -676,19 +677,10 @@ impl<'a> TcpSocket<'a> {
///
/// See also [recv](#method.recv).
pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<usize> {
// See recv() above.
if !self.may_recv() { return Err(Error::Illegal) }
let _old_length = self.rx_buffer.len();
let dequeued = self.rx_buffer.dequeue_slice(data);
self.remote_seq_no += dequeued;
if dequeued > 0 {
#[cfg(any(test, feature = "verbose"))]
net_trace!("{}:{}:{}: rx buffer: dequeueing {} octets (now {})",
self.handle, self.local_endpoint, self.remote_endpoint,
dequeued, _old_length - dequeued);
}
Ok(dequeued)
self.recv_impl(|rx_buffer| {
let size = rx_buffer.dequeue_slice(data);
(size, size)
})
}
/// Peek at a sequence of received octets without removing them from
@ -3145,7 +3137,10 @@ mod test {
..RECV_TEMPL
}]);
recv!(s, time 0, Err(Error::Exhausted));
assert_eq!(s.recv(3), Ok(&b"abc"[..]));
s.recv(|buffer| {
assert_eq!(&buffer[..3], b"abc");
(3, ())
}).unwrap();
recv!(s, time 0, Ok(TcpRepr {
seq_number: LOCAL_SEQ + 1,
ack_number: Some(REMOTE_SEQ + 1 + 6),
@ -3153,7 +3148,10 @@ mod test {
..RECV_TEMPL
}));
recv!(s, time 0, Err(Error::Exhausted));
assert_eq!(s.recv(3), Ok(&b"def"[..]));
s.recv(|buffer| {
assert_eq!(buffer, b"def");
(buffer.len(), ())
}).unwrap();
recv!(s, time 0, Ok(TcpRepr {
seq_number: LOCAL_SEQ + 1,
ack_number: Some(REMOTE_SEQ + 1 + 6),
@ -3457,7 +3455,10 @@ mod test {
ack_number: Some(REMOTE_SEQ + 1),
..RECV_TEMPL
})));
assert_eq!(s.recv(10), Ok(&b""[..]));
s.recv(|buffer| {
assert_eq!(buffer, b"");
(buffer.len(), ())
}).unwrap();
send!(s, TcpRepr {
seq_number: REMOTE_SEQ + 1,
ack_number: Some(LOCAL_SEQ + 1),
@ -3469,11 +3470,14 @@ mod test {
window_len: 58,
..RECV_TEMPL
})));
assert_eq!(s.recv(10), Ok(&b"abcdef"[..]));
s.recv(|buffer| {
assert_eq!(buffer, b"abcdef");
(buffer.len(), ())
}).unwrap();
}
#[test]
fn test_buffer_wraparound() {
fn test_buffer_wraparound_rx() {
let mut s = socket_established();
s.rx_buffer = SocketBuffer::new(vec![0; 6]);
s.assembler = Assembler::new(s.rx_buffer.capacity());
@ -3483,7 +3487,10 @@ mod test {
payload: &b"abc"[..],
..SEND_TEMPL
});
assert_eq!(s.recv(3), Ok(&b"abc"[..]));
s.recv(|buffer| {
assert_eq!(buffer, b"abc");
(buffer.len(), ())
}).unwrap();
send!(s, TcpRepr {
seq_number: REMOTE_SEQ + 1 + 3,
ack_number: Some(LOCAL_SEQ + 1),
@ -3495,6 +3502,38 @@ mod test {
assert_eq!(data, &b"defghi"[..]);
}
#[test]
fn test_buffer_wraparound_tx() {
let mut s = socket_established();
s.tx_buffer = SocketBuffer::new(vec![0; 6]);
assert_eq!(s.send_slice(b"abc"), Ok(3));
recv!(s, Ok(TcpRepr {
seq_number: LOCAL_SEQ + 1,
ack_number: Some(REMOTE_SEQ + 1),
payload: &b"abc"[..],
..RECV_TEMPL
}));
send!(s, TcpRepr {
seq_number: REMOTE_SEQ + 1,
ack_number: Some(LOCAL_SEQ + 1 + 3),
..SEND_TEMPL
});
assert_eq!(s.send_slice(b"defghi"), Ok(6));
recv!(s, Ok(TcpRepr {
seq_number: LOCAL_SEQ + 1 + 3,
ack_number: Some(REMOTE_SEQ + 1),
payload: &b"def"[..],
..RECV_TEMPL
}));
// "defghi" not contiguous in tx buffer
recv!(s, Ok(TcpRepr {
seq_number: LOCAL_SEQ + 1 + 3 + 3,
ack_number: Some(REMOTE_SEQ + 1),
payload: &b"ghi"[..],
..RECV_TEMPL
}));
}
// =========================================================================================//
// Tests for packet filtering.
// =========================================================================================//

View File

@ -195,8 +195,8 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
Ok((&packet_buf.as_ref(), packet_buf.endpoint))
}
/// Dequeue a packet received from a remote endpoint, and return the endpoint as well
/// as copy the payload into the given slice.
/// Dequeue a packet received from a remote endpoint, copy the payload into the given slice,
/// and return the amount of octets copied as well as the endpoint.
///
/// See also [recv](#method.recv).
pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, IpEndpoint)> {