Implement a SocketRef smart pointer to detect state changes.

v0.7.x
Egor Karavaev 2017-09-10 23:18:12 +03:00 committed by whitequark
parent 52600cd521
commit 19b1b764ed
8 changed files with 146 additions and 88 deletions

View File

@ -12,8 +12,7 @@ use std::os::unix::io::AsRawFd;
use smoltcp::phy::wait as phy_wait;
use smoltcp::wire::{EthernetAddress, Ipv4Address, IpAddress, IpCidr};
use smoltcp::iface::{ArpCache, SliceArpCache, EthernetInterface};
use smoltcp::socket::{AsSocket, SocketSet};
use smoltcp::socket::{TcpSocket, TcpSocketBuffer};
use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer};
fn main() {
utils::setup_logging("");
@ -50,14 +49,14 @@ fn main() {
let tcp_handle = sockets.add(tcp_socket);
{
let socket: &mut TcpSocket = sockets.get_mut(tcp_handle).as_socket();
let mut socket = sockets.get::<TcpSocket>(tcp_handle);
socket.connect((address, port), 49500).unwrap();
}
let mut tcp_active = false;
loop {
{
let socket: &mut TcpSocket = sockets.get_mut(tcp_handle).as_socket();
let mut socket = sockets.get::<TcpSocket>(tcp_handle);
if socket.is_active() && !tcp_active {
debug!("connected");
} else if !socket.is_active() && tcp_active {

View File

@ -19,8 +19,7 @@ use core::str;
use smoltcp::phy::Loopback;
use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr};
use smoltcp::iface::{ArpCache, SliceArpCache, EthernetInterface};
use smoltcp::socket::{AsSocket, SocketSet};
use smoltcp::socket::{TcpSocket, TcpSocketBuffer};
use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer};
#[cfg(not(feature = "std"))]
mod mock {
@ -124,7 +123,7 @@ fn main() {
let mut done = false;
while !done && clock.elapsed() < 10_000 {
{
let socket: &mut TcpSocket = socket_set.get_mut(server_handle).as_socket();
let mut socket = socket_set.get::<TcpSocket>(server_handle);
if !socket.is_active() && !socket.is_listening() {
if !did_listen {
debug!("listening");
@ -141,7 +140,7 @@ fn main() {
}
{
let socket: &mut TcpSocket = socket_set.get_mut(client_handle).as_socket();
let mut socket = socket_set.get::<TcpSocket>(client_handle);
if !socket.is_open() {
if !did_connect {
debug!("connecting");

View File

@ -16,8 +16,7 @@ use smoltcp::wire::{EthernetAddress, IpVersion, IpProtocol, IpAddress, IpCidr,
Ipv4Address, Ipv4Packet, Ipv4Repr,
Icmpv4Repr, Icmpv4Packet};
use smoltcp::iface::{ArpCache, SliceArpCache, EthernetInterface};
use smoltcp::socket::{AsSocket, SocketSet};
use smoltcp::socket::{RawSocket, RawSocketBuffer, RawPacketBuffer};
use smoltcp::socket::{SocketSet, RawSocket, RawSocketBuffer, RawPacketBuffer};
use std::collections::HashMap;
use byteorder::{ByteOrder, NetworkEndian};
@ -75,7 +74,7 @@ fn main() {
loop {
{
let socket: &mut RawSocket = sockets.get_mut(raw_handle).as_socket();
let mut socket = sockets.get::<RawSocket>(raw_handle);
let timestamp = Instant::now().duration_since(startup_time);
let timestamp_us = (timestamp.as_secs() * 1000000) +

View File

@ -13,7 +13,7 @@ use std::os::unix::io::AsRawFd;
use smoltcp::phy::wait as phy_wait;
use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr};
use smoltcp::iface::{ArpCache, SliceArpCache, EthernetInterface};
use smoltcp::socket::{AsSocket, SocketSet};
use smoltcp::socket::SocketSet;
use smoltcp::socket::{UdpSocket, UdpSocketBuffer, UdpPacketBuffer};
use smoltcp::socket::{TcpSocket, TcpSocketBuffer};
@ -70,7 +70,7 @@ fn main() {
loop {
// udp:6969: respond "hello"
{
let socket: &mut UdpSocket = sockets.get_mut(udp_handle).as_socket();
let mut socket = sockets.get::<UdpSocket>(udp_handle);
if !socket.is_open() {
socket.bind(6969).unwrap()
}
@ -93,7 +93,7 @@ fn main() {
// tcp:6969: respond "hello"
{
let socket: &mut TcpSocket = sockets.get_mut(tcp1_handle).as_socket();
let mut socket = sockets.get::<TcpSocket>(tcp1_handle);
if !socket.is_open() {
socket.listen(6969).unwrap();
}
@ -108,7 +108,7 @@ fn main() {
// tcp:6970: echo with reverse
{
let socket: &mut TcpSocket = sockets.get_mut(tcp2_handle).as_socket();
let mut socket = sockets.get::<TcpSocket>(tcp2_handle);
if !socket.is_open() {
socket.listen(6970).unwrap()
}
@ -145,7 +145,7 @@ fn main() {
// tcp:6971: sinkhole
{
let socket: &mut TcpSocket = sockets.get_mut(tcp3_handle).as_socket();
let mut socket = sockets.get::<TcpSocket>(tcp3_handle);
if !socket.is_open() {
socket.listen(6971).unwrap();
socket.set_keep_alive(Some(1000));
@ -165,7 +165,7 @@ fn main() {
// tcp:6972: fountain
{
let socket: &mut TcpSocket = sockets.get_mut(tcp4_handle).as_socket();
let mut socket = sockets.get::<TcpSocket>(tcp4_handle);
if !socket.is_open() {
socket.listen(6972).unwrap()
}

View File

@ -13,7 +13,7 @@ use wire::{Ipv4Packet, Ipv4Repr};
use wire::{Icmpv4Packet, Icmpv4Repr, Icmpv4DstUnreachable};
#[cfg(feature = "socket-udp")] use wire::{UdpPacket, UdpRepr};
#[cfg(feature = "socket-tcp")] use wire::{TcpPacket, TcpRepr, TcpControl};
use socket::{Socket, SocketSet, AsSocket};
use socket::{Socket, SocketSet, AnySocket};
#[cfg(feature = "socket-raw")] use socket::RawSocket;
#[cfg(feature = "socket-udp")] use socket::UdpSocket;
#[cfg(feature = "socket-tcp")] use socket::TcpSocket;
@ -195,29 +195,29 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
let mut caps = self.device.capabilities();
caps.max_transmission_unit -= EthernetFrame::<&[u8]>::header_len();
for socket in sockets.iter_mut() {
for mut socket in sockets.iter_mut() {
let mut device_result = Ok(());
let socket_result =
match socket {
match *socket {
#[cfg(feature = "socket-raw")]
&mut Socket::Raw(ref mut socket) =>
Socket::Raw(ref mut socket) =>
socket.dispatch(|response| {
device_result = self.dispatch(timestamp, Packet::Raw(response));
device_result
}, &caps.checksum),
#[cfg(feature = "socket-udp")]
&mut Socket::Udp(ref mut socket) =>
Socket::Udp(ref mut socket) =>
socket.dispatch(|response| {
device_result = self.dispatch(timestamp, Packet::Udp(response));
device_result
}),
#[cfg(feature = "socket-tcp")]
&mut Socket::Tcp(ref mut socket) =>
Socket::Tcp(ref mut socket) =>
socket.dispatch(timestamp, &caps, |response| {
device_result = self.dispatch(timestamp, Packet::Tcp(response));
device_result
}),
&mut Socket::__Nonexhaustive(_) => unreachable!()
Socket::__Nonexhaustive(_) => unreachable!()
};
match (device_result, socket_result) {
(Err(Error::Unaddressable), _) => break, // no one to transmit to
@ -323,8 +323,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
// Pass every IP packet to all raw sockets we have registered.
#[cfg(feature = "socket-raw")]
for raw_socket in sockets.iter_mut().filter_map(
<Socket as AsSocket<RawSocket>>::try_as_socket) {
for mut raw_socket in sockets.iter_mut().filter_map(RawSocket::downcast) {
if !raw_socket.accepts(&ip_repr) { continue }
match raw_socket.process(&ip_repr, ip_payload, &checksum_caps) {
@ -415,8 +414,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
let checksum_caps = self.device.capabilities().checksum;
let udp_repr = UdpRepr::parse(&udp_packet, &src_addr, &dst_addr, &checksum_caps)?;
for udp_socket in sockets.iter_mut().filter_map(
<Socket as AsSocket<UdpSocket>>::try_as_socket) {
for mut udp_socket in sockets.iter_mut().filter_map(UdpSocket::downcast) {
if !udp_socket.accepts(&ip_repr, &udp_repr) { continue }
match udp_socket.process(&ip_repr, &udp_repr) {
@ -458,8 +456,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
let checksum_caps = self.device.capabilities().checksum;
let tcp_repr = TcpRepr::parse(&tcp_packet, &src_addr, &dst_addr, &checksum_caps)?;
for tcp_socket in sockets.iter_mut().filter_map(
<Socket as AsSocket<TcpSocket>>::try_as_socket) {
for mut tcp_socket in sockets.iter_mut().filter_map(TcpSocket::downcast) {
if !tcp_socket.accepts(&ip_repr, &tcp_repr) { continue }
match tcp_socket.process(timestamp, &ip_repr, &tcp_repr) {

View File

@ -17,6 +17,7 @@ use wire::IpRepr;
#[cfg(feature = "socket-udp")] mod udp;
#[cfg(feature = "socket-tcp")] mod tcp;
mod set;
mod ref_;
#[cfg(feature = "socket-raw")]
pub use self::raw::{PacketBuffer as RawPacketBuffer,
@ -36,19 +37,19 @@ pub use self::tcp::{SocketBuffer as TcpSocketBuffer,
pub use self::set::{Set as SocketSet, Item as SocketSetItem, Handle as SocketHandle};
pub use self::set::{Iter as SocketSetIter, IterMut as SocketSetIterMut};
pub use self::ref_::Ref as SocketRef;
pub(crate) use self::ref_::Session as SocketSession;
/// A network socket.
///
/// This enumeration abstracts the various types of sockets based on the IP protocol.
/// To downcast a `Socket` value down to a concrete socket, use
/// the [AsSocket](trait.AsSocket.html) trait, and call e.g. `socket.as_socket::<UdpSocket<_>>()`.
/// To downcast a `Socket` value to a concrete socket, use the [AnySocket] trait,
/// e.g. to get `UdpSocket`, call `UdpSocket::downcast(socket)`.
///
/// The `process` and `dispatch` functions are fundamentally asymmetric and thus differ in
/// their use of the [trait PacketRepr](trait.PacketRepr.html). When `process` is called,
/// the packet length is already known and no allocation is required; on the other hand,
/// `process` would have to downcast a `&PacketRepr` to e.g. an `&UdpRepr` through `Any`,
/// which is rather inelegant. Conversely, when `dispatch` is called, the packet length is
/// not yet known and the packet storage has to be allocated; but the `&PacketRepr` is sufficient
/// since the lower layers treat the packet as an opaque octet sequence.
/// It is usually more convenient to use [SocketSet::get] instead.
///
/// [AnySocket]: trait.AnySocket.html
/// [SocketSet::get]: struct.SocketSet.html#method.get
#[derive(Debug)]
pub enum Socket<'a, 'b: 'a> {
#[cfg(feature = "socket-raw")]
@ -90,40 +91,37 @@ impl<'a, 'b> Socket<'a, 'b> {
}
}
/// A conversion trait for network sockets.
///
/// This trait is used to concisely downcast [Socket](trait.Socket.html) values to their
/// concrete types.
pub trait AsSocket<T> {
fn as_socket(&mut self) -> &mut T;
fn try_as_socket(&mut self) -> Option<&mut T>;
impl<'a, 'b> SocketSession for Socket<'a, 'b> {
fn finish(&mut self) {
dispatch_socket!(self, |socket [mut]| socket.finish())
}
}
macro_rules! as_socket {
($socket:ty, $variant:ident) => {
impl<'a, 'b> AsSocket<$socket> for Socket<'a, 'b> {
fn as_socket(&mut self) -> &mut $socket {
match self {
&mut Socket::$variant(ref mut socket) => socket,
_ => panic!(concat!(".as_socket::<",
stringify!($socket),
"> called on wrong socket type"))
}
}
/// A conversion trait for network sockets.
pub trait AnySocket<'a, 'b>: SocketSession + Sized {
fn downcast<'c>(socket_ref: SocketRef<'c, Socket<'a, 'b>>) ->
Option<SocketRef<'c, Self>>;
}
fn try_as_socket(&mut self) -> Option<&mut $socket> {
match self {
&mut Socket::$variant(ref mut socket) => Some(socket),
_ => None,
}
macro_rules! from_socket {
($socket:ty, $variant:ident) => {
impl<'a, 'b> AnySocket<'a, 'b> for $socket {
fn downcast<'c>(ref_: SocketRef<'c, Socket<'a, 'b>>) ->
Option<SocketRef<'c, Self>> {
SocketRef::map(ref_, |socket| {
match *socket {
Socket::$variant(ref mut socket) => Some(socket),
_ => None,
}
})
}
}
}
}
#[cfg(feature = "socket-raw")]
as_socket!(RawSocket<'a, 'b>, Raw);
from_socket!(RawSocket<'a, 'b>, Raw);
#[cfg(feature = "socket-udp")]
as_socket!(UdpSocket<'a, 'b>, Udp);
from_socket!(UdpSocket<'a, 'b>, Udp);
#[cfg(feature = "socket-tcp")]
as_socket!(TcpSocket<'a>, Tcp);
from_socket!(TcpSocket<'a>, Tcp);

73
src/socket/ref_.rs Normal file
View File

@ -0,0 +1,73 @@
use core::ops::{Deref, DerefMut};
#[cfg(feature = "socket-raw")]
use socket::RawSocket;
#[cfg(feature = "socket-udp")]
use socket::UdpSocket;
#[cfg(feature = "socket-tcp")]
use socket::TcpSocket;
/// A trait for tracking a socket usage session.
///
/// Allows implementation of custom drop logic that runs only if the socket was changed
/// in specific ways. For example, drop logic for UDP would check if the local endpoint
/// has changed, and if yes, notify the socket set.
#[doc(hidden)]
pub trait Session {
fn finish(&mut self) {}
}
#[cfg(feature = "socket-raw")]
impl<'a, 'b> Session for RawSocket<'a, 'b> {}
#[cfg(feature = "socket-udp")]
impl<'a, 'b> Session for UdpSocket<'a, 'b> {}
#[cfg(feature = "socket-tcp")]
impl<'a> Session for TcpSocket<'a> {}
/// A smart pointer to a socket.
///
/// Allows the network stack to efficiently determine if the socket state was changed in any way.
pub struct Ref<'a, T: Session + 'a> {
socket: &'a mut T,
consumed: bool,
}
impl<'a, T: Session> Ref<'a, T> {
pub(crate) fn new(socket: &'a mut T) -> Self {
Ref { socket, consumed: false }
}
}
impl<'a, T: Session + 'a> Ref<'a, T> {
pub(crate) fn map<U, F>(mut ref_: Self, f: F) -> Option<Ref<'a, U>>
where U: Session + 'a, F: FnOnce(&'a mut T) -> Option<&'a mut U> {
if let Some(socket) = f(ref_.socket) {
ref_.consumed = true;
Some(Ref::new(socket))
} else {
None
}
}
}
impl<'a, T: Session> Deref for Ref<'a, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.socket
}
}
impl<'a, T: Session> DerefMut for Ref<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.socket
}
}
impl<'a, T: Session> Drop for Ref<'a, T> {
fn drop(&mut self) {
if !self.consumed {
Session::finish(self.socket);
}
}
}

View File

@ -1,7 +1,7 @@
use core::{fmt, slice};
use managed::ManagedSlice;
use super::Socket;
use super::{Socket, SocketRef, AnySocket};
#[cfg(feature = "socket-tcp")] use super::TcpState;
/// An item of a socket set.
@ -28,7 +28,7 @@ impl fmt::Display for Handle {
}
}
/// An extensible set of sockets, with stable numeric identifiers.
/// An extensible set of sockets.
///
/// The lifetimes `'b` and `'c` are used when storing a `Socket<'b, 'c>`.
#[derive(Debug)]
@ -79,26 +79,19 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Set<'a, 'b, 'c> {
}
}
/// Get a socket from the set by its handle.
///
/// # Panics
/// This function may panic if the handle does not belong to this socket set.
pub fn get(&self, handle: Handle) -> &Socket<'b, 'c> {
&self.sockets[handle.0]
.as_ref()
.expect("handle does not refer to a valid socket")
.socket
}
/// Get a socket from the set by its handle, as mutable.
///
/// # Panics
/// This function may panic if the handle does not belong to this socket set.
pub fn get_mut(&mut self, handle: Handle) -> &mut Socket<'b, 'c> {
&mut self.sockets[handle.0]
.as_mut()
.expect("handle does not refer to a valid socket")
.socket
/// This function may panic if the handle does not belong to this socket set
/// or the socket has the wrong type.
pub fn get<T: AnySocket<'b, 'c>>(&mut self, handle: Handle) -> SocketRef<T> {
match self.sockets[handle.0].as_mut() {
Some(item) => {
T::downcast(SocketRef::new(&mut item.socket))
.expect("handle refers to a socket of a wrong type")
}
None => panic!("handle does not refer to a valid socket")
}
}
/// Remove a socket from the set, without changing its state.
@ -175,7 +168,7 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Set<'a, 'b, 'c> {
Iter { lower: self.sockets.iter() }
}
/// Iterate every socket in this set, as mutable.
/// Iterate every socket in this set, as SocketRef.
pub fn iter_mut<'d>(&'d mut self) -> IterMut<'d, 'b, 'c> {
IterMut { lower: self.sockets.iter_mut() }
}
@ -207,16 +200,16 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Iterator for Iter<'a, 'b, 'c> {
/// This struct is created by the [iter_mut](struct.SocketSet.html#method.iter_mut)
/// on [socket sets](struct.SocketSet.html).
pub struct IterMut<'a, 'b: 'a, 'c: 'a + 'b> {
lower: slice::IterMut<'a, Option<Item<'b, 'c>>>
lower: slice::IterMut<'a, Option<Item<'b, 'c>>>,
}
impl<'a, 'b: 'a, 'c: 'a + 'b> Iterator for IterMut<'a, 'b, 'c> {
type Item = &'a mut Socket<'b, 'c>;
type Item = SocketRef<'a, Socket<'b, 'c>>;
fn next(&mut self) -> Option<Self::Item> {
while let Some(item_opt) = self.lower.next() {
if let Some(item) = item_opt.as_mut() {
return Some(&mut item.socket)
return Some(SocketRef::new(&mut item.socket))
}
}
None