socket: remove SocketRef.

The intent was to run custom code after the user is done modifying the socket,
for example to update a (not yet existing) port->socket map in SocketSet. However
this wouldn't work, since the SocketRef would have to borrow the SocketSet at
the same time as the Socket to be able to notify the SocketSet.

I believe such indexing can be achieved by setting a "dirty" bit *before* giving
the socket to the user, then on poll() reindexing all dirty sockets. This could
even be faster: if user gets a socket multiple times between polls, it'd be reindexed
only once.
master
Dario Nieuwenhuis 2021-10-21 03:49:54 +02:00
parent c08dd8dcf6
commit f8cc1eacbe
11 changed files with 44 additions and 140 deletions

View File

@ -120,7 +120,7 @@ fn main() {
// tcp:1234: emit data
{
let mut socket = iface.get_socket::<TcpSocket>(tcp1_handle);
let socket = iface.get_socket::<TcpSocket>(tcp1_handle);
if !socket.is_open() {
socket.listen(1234).unwrap();
}
@ -140,7 +140,7 @@ fn main() {
// tcp:1235: sink data
{
let mut socket = iface.get_socket::<TcpSocket>(tcp2_handle);
let socket = iface.get_socket::<TcpSocket>(tcp2_handle);
if !socket.is_open() {
socket.listen(1235).unwrap();
}

View File

@ -55,7 +55,7 @@ fn main() {
let tcp_handle = iface.add_socket(tcp_socket);
{
let mut socket = iface.get_socket::<TcpSocket>(tcp_handle);
let socket = iface.get_socket::<TcpSocket>(tcp_handle);
socket.connect((address, port), 49500).unwrap();
}
@ -70,7 +70,7 @@ fn main() {
}
{
let mut socket = iface.get_socket::<TcpSocket>(tcp_handle);
let socket = iface.get_socket::<TcpSocket>(tcp_handle);
if socket.is_active() && !tcp_active {
debug!("connected");
} else if !socket.is_active() && tcp_active {

View File

@ -77,7 +77,7 @@ fn main() {
}
{
let mut socket = iface.get_socket::<TcpSocket>(tcp_handle);
let socket = iface.get_socket::<TcpSocket>(tcp_handle);
state = match state {
State::Connect if !socket.is_active() => {

View File

@ -78,7 +78,7 @@ fn main() {
}
{
let mut socket = iface.get_socket::<RawSocket>(raw_handle);
let socket = iface.get_socket::<RawSocket>(raw_handle);
if socket.can_recv() {
// For display purposes only - normally we wouldn't process incoming IGMP packets
@ -93,7 +93,7 @@ fn main() {
}
}
{
let mut socket = iface.get_socket::<UdpSocket>(udp_handle);
let socket = iface.get_socket::<UdpSocket>(udp_handle);
if !socket.is_open() {
socket.bind(MDNS_PORT).unwrap()
}

View File

@ -157,7 +157,7 @@ fn main() {
{
let timestamp = Instant::now();
let mut socket = iface.get_socket::<IcmpSocket>(icmp_handle);
let socket = iface.get_socket::<IcmpSocket>(icmp_handle);
if !socket.is_open() {
socket.bind(IcmpEndpoint::Ident(ident)).unwrap();
send_at = timestamp;

View File

@ -81,7 +81,7 @@ fn main() {
// udp:6969: respond "hello"
{
let mut socket = iface.get_socket::<UdpSocket>(udp_handle);
let socket = iface.get_socket::<UdpSocket>(udp_handle);
if !socket.is_open() {
socket.bind(6969).unwrap()
}
@ -109,7 +109,7 @@ fn main() {
// tcp:6969: respond "hello"
{
let mut socket = iface.get_socket::<TcpSocket>(tcp1_handle);
let socket = iface.get_socket::<TcpSocket>(tcp1_handle);
if !socket.is_open() {
socket.listen(6969).unwrap();
}
@ -124,7 +124,7 @@ fn main() {
// tcp:6970: echo with reverse
{
let mut socket = iface.get_socket::<TcpSocket>(tcp2_handle);
let socket = iface.get_socket::<TcpSocket>(tcp2_handle);
if !socket.is_open() {
socket.listen(6970).unwrap()
}
@ -168,7 +168,7 @@ fn main() {
// tcp:6971: sinkhole
{
let mut socket = iface.get_socket::<TcpSocket>(tcp3_handle);
let socket = iface.get_socket::<TcpSocket>(tcp3_handle);
if !socket.is_open() {
socket.listen(6971).unwrap();
socket.set_keep_alive(Some(Duration::from_millis(1000)));
@ -191,7 +191,7 @@ fn main() {
// tcp:6972: fountain
{
let mut socket = iface.get_socket::<TcpSocket>(tcp4_handle);
let socket = iface.get_socket::<TcpSocket>(tcp4_handle);
if !socket.is_open() {
socket.listen(6972).unwrap()
}

View File

@ -101,7 +101,7 @@ fn main() {
// udp:6969: respond "hello"
{
let mut socket = iface.get_socket::<UdpSocket>(udp_handle);
let socket = iface.get_socket::<UdpSocket>(udp_handle);
if !socket.is_open() {
socket.bind(6969).unwrap()
}

View File

@ -485,7 +485,7 @@ where
/// # Panics
/// This function may panic if the handle does not belong to this socket set
/// or the socket has the wrong type.
pub fn get_socket<T: AnySocket<'a>>(&mut self, handle: SocketHandle) -> SocketRef<T> {
pub fn get_socket<T: AnySocket<'a>>(&mut self, handle: SocketHandle) -> &mut T {
self.sockets.get(handle)
}
@ -830,7 +830,7 @@ where
let _caps = device.capabilities();
let mut emitted_any = false;
for mut socket in sockets.iter_mut() {
for socket in sockets.iter_mut() {
if !socket
.meta_mut()
.egress_permitted(cx.now, |ip_addr| inner.has_neighbor(cx, &ip_addr))
@ -1202,7 +1202,7 @@ impl<'a> InterfaceInner<'a> {
// Look for UDP sockets that will accept the UDP packet.
// If it does not accept the packet, then send an ICMP message.
for mut udp_socket in sockets.iter_mut().filter_map(UdpSocket::downcast) {
for udp_socket in sockets.iter_mut().filter_map(UdpSocket::downcast) {
if !udp_socket.accepts(&IpRepr::Ipv6(ipv6_repr), &udp_repr) {
continue;
}
@ -1328,7 +1328,7 @@ impl<'a> InterfaceInner<'a> {
let mut handled_by_raw_socket = false;
// Pass every IP packet to all raw sockets we have registered.
for mut raw_socket in sockets.iter_mut().filter_map(RawSocket::downcast) {
for raw_socket in sockets.iter_mut().filter_map(RawSocket::downcast) {
if !raw_socket.accepts(ip_repr) {
continue;
}
@ -1460,7 +1460,7 @@ impl<'a> InterfaceInner<'a> {
if udp_packet.src_port() == DHCP_SERVER_PORT
&& udp_packet.dst_port() == DHCP_CLIENT_PORT
{
if let Some(mut dhcp_socket) =
if let Some(dhcp_socket) =
sockets.iter_mut().filter_map(Dhcpv4Socket::downcast).next()
{
let (src_addr, dst_addr) = (ip_repr.src_addr(), ip_repr.dst_addr());
@ -1639,7 +1639,7 @@ impl<'a> InterfaceInner<'a> {
let mut handled_by_icmp_socket = false;
#[cfg(all(feature = "socket-icmp", feature = "proto-ipv6"))]
for mut icmp_socket in _sockets.iter_mut().filter_map(IcmpSocket::downcast) {
for icmp_socket in _sockets.iter_mut().filter_map(IcmpSocket::downcast) {
if !icmp_socket.accepts(cx, &ip_repr, &icmp_repr.into()) {
continue;
}
@ -1825,7 +1825,7 @@ impl<'a> InterfaceInner<'a> {
let mut handled_by_icmp_socket = false;
#[cfg(all(feature = "socket-icmp", feature = "proto-ipv4"))]
for mut icmp_socket in _sockets.iter_mut().filter_map(IcmpSocket::downcast) {
for icmp_socket in _sockets.iter_mut().filter_map(IcmpSocket::downcast) {
if !icmp_socket.accepts(cx, &ip_repr, &icmp_repr.into()) {
continue;
}
@ -1949,7 +1949,7 @@ impl<'a> InterfaceInner<'a> {
let udp_repr = UdpRepr::parse(&udp_packet, &src_addr, &dst_addr, &cx.caps.checksum)?;
let udp_payload = udp_packet.payload();
for mut udp_socket in sockets.iter_mut().filter_map(UdpSocket::downcast) {
for udp_socket in sockets.iter_mut().filter_map(UdpSocket::downcast) {
if !udp_socket.accepts(&ip_repr, &udp_repr) {
continue;
}
@ -2006,7 +2006,7 @@ impl<'a> InterfaceInner<'a> {
let tcp_packet = TcpPacket::new_checked(ip_payload)?;
let tcp_repr = TcpRepr::parse(&tcp_packet, &src_addr, &dst_addr, &cx.caps.checksum)?;
for mut tcp_socket in sockets.iter_mut().filter_map(TcpSocket::downcast) {
for tcp_socket in sockets.iter_mut().filter_map(TcpSocket::downcast) {
if !tcp_socket.accepts(&ip_repr, &tcp_repr) {
continue;
}
@ -2944,7 +2944,7 @@ mod test {
{
// Bind the socket to port 68
let mut socket = iface.get_socket::<UdpSocket>(socket_handle);
let socket = iface.get_socket::<UdpSocket>(socket_handle);
assert_eq!(socket.bind(68), Ok(()));
assert!(!socket.can_recv());
assert!(socket.can_send());
@ -2971,7 +2971,7 @@ mod test {
{
// Make sure the payload to the UDP packet processed by process_udp is
// appended to the bound sockets rx_buffer
let mut socket = iface.get_socket::<UdpSocket>(socket_handle);
let socket = iface.get_socket::<UdpSocket>(socket_handle);
assert!(socket.can_recv());
assert_eq!(
socket.recv(),
@ -3443,7 +3443,7 @@ mod test {
let echo_data = &[0xff; 16];
{
let mut socket = iface.get_socket::<IcmpSocket>(socket_handle);
let socket = iface.get_socket::<IcmpSocket>(socket_handle);
// Bind to the ID 0x1234
assert_eq!(socket.bind(IcmpEndpoint::Ident(ident)), Ok(()));
}
@ -3494,7 +3494,7 @@ mod test {
);
{
let mut socket = iface.get_socket::<IcmpSocket>(socket_handle);
let socket = iface.get_socket::<IcmpSocket>(socket_handle);
assert!(socket.can_recv());
assert_eq!(
socket.recv(),
@ -3856,7 +3856,7 @@ mod test {
let udp_socket_handle = iface.add_socket(udp_socket);
{
// Bind the socket to port 68
let mut socket = iface.get_socket::<UdpSocket>(udp_socket_handle);
let socket = iface.get_socket::<UdpSocket>(udp_socket_handle);
assert_eq!(socket.bind(68), Ok(()));
assert!(!socket.can_recv());
assert!(socket.can_send());
@ -3929,7 +3929,7 @@ mod test {
{
// Make sure the UDP socket can still receive in presence of a Raw socket that handles UDP
let mut socket = iface.get_socket::<UdpSocket>(udp_socket_handle);
let socket = iface.get_socket::<UdpSocket>(udp_socket_handle);
assert!(socket.can_recv());
assert_eq!(
socket.recv(),

View File

@ -24,7 +24,6 @@ mod icmp;
mod meta;
#[cfg(feature = "socket-raw")]
mod raw;
mod ref_;
mod set;
#[cfg(feature = "socket-tcp")]
mod tcp;
@ -59,9 +58,6 @@ pub use self::dhcpv4::{Config as Dhcpv4Config, Dhcpv4Socket, Event as Dhcpv4Even
pub use self::set::{Handle as SocketHandle, Item as SocketSetItem, Set as SocketSet};
pub use self::set::{Iter as SocketSetIter, IterMut as SocketSetIterMut};
pub use self::ref_::Ref as SocketRef;
pub(crate) use self::ref_::Session as SocketSession;
/// Gives an indication on the next time the socket should be polled.
#[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
@ -144,25 +140,19 @@ impl<'a> Socket<'a> {
}
}
impl<'a> SocketSession for Socket<'a> {
fn finish(&mut self) {
dispatch_socket!(mut self, |socket| socket.finish())
}
}
/// A conversion trait for network sockets.
pub trait AnySocket<'a>: SocketSession + Sized {
fn downcast<'c>(socket_ref: SocketRef<'c, Socket<'a>>) -> Option<SocketRef<'c, Self>>;
pub trait AnySocket<'a>: Sized {
fn downcast<'c>(socket: &'c mut Socket<'a>) -> Option<&'c mut Self>;
}
macro_rules! from_socket {
($socket:ty, $variant:ident) => {
impl<'a> AnySocket<'a> for $socket {
fn downcast<'c>(ref_: SocketRef<'c, Socket<'a>>) -> Option<SocketRef<'c, Self>> {
if let Socket::$variant(ref mut socket) = SocketRef::into_inner(ref_) {
Some(SocketRef::new(socket))
} else {
None
fn downcast<'c>(socket: &'c mut Socket<'a>) -> Option<&'c mut Self> {
#[allow(unreachable_patterns)]
match socket {
Socket::$variant(socket) => Some(socket),
_ => None,
}
}
}

View File

@ -1,87 +0,0 @@
use core::ops::{Deref, DerefMut};
/// A trait for tracking a socket usage session.
///
/// Allows implementation of custom drop logic that runs only if the socket was changed
/// in specific ways. For example, drop logic for UDP would check if the local endpoint
/// has changed, and if yes, notify the socket set.
#[doc(hidden)]
pub trait Session {
fn finish(&mut self) {}
}
#[cfg(feature = "socket-raw")]
impl<'a> Session for crate::socket::RawSocket<'a> {}
#[cfg(all(
feature = "socket-icmp",
any(feature = "proto-ipv4", feature = "proto-ipv6")
))]
impl<'a> Session for crate::socket::IcmpSocket<'a> {}
#[cfg(feature = "socket-udp")]
impl<'a> Session for crate::socket::UdpSocket<'a> {}
#[cfg(feature = "socket-tcp")]
impl<'a> Session for crate::socket::TcpSocket<'a> {}
#[cfg(feature = "socket-dhcpv4")]
impl Session for crate::socket::Dhcpv4Socket {}
/// A smart pointer to a socket.
///
/// Allows the network stack to efficiently determine if the socket state was changed in any way.
pub struct Ref<'a, T: Session + 'a> {
/// Reference to the socket.
///
/// This is almost always `Some` except when dropped in `into_inner` which removes the socket
/// reference. This properly tracks the initialization state without any additional bytes as
/// the `None` variant occupies the `0` pattern which is invalid for the reference.
socket: Option<&'a mut T>,
}
impl<'a, T: Session + 'a> Ref<'a, T> {
/// Wrap a pointer to a socket to make a smart pointer.
///
/// Calling this function is only necessary if your code is using [into_inner].
///
/// [into_inner]: #method.into_inner
pub fn new(socket: &'a mut T) -> Self {
Ref {
socket: Some(socket),
}
}
/// Unwrap a smart pointer to a socket.
///
/// The finalization code is not run. Prompt operation of the network stack depends
/// on wrapping the returned pointer back and dropping it.
///
/// Calling this function is only necessary to achieve composability if you *must*
/// map a `&mut SocketRef<'a, XSocket>` to a `&'a mut XSocket` (note the lifetimes);
/// be sure to call [new] afterwards.
///
/// [new]: #method.new
pub fn into_inner(mut ref_: Self) -> &'a mut T {
ref_.socket.take().unwrap()
}
}
impl<'a, T: Session> Deref for Ref<'a, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
// Deref is only used while the socket is still in place (into inner has not been called).
self.socket.as_ref().unwrap()
}
}
impl<'a, T: Session> DerefMut for Ref<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.socket.as_mut().unwrap()
}
}
impl<'a, T: Session> Drop for Ref<'a, T> {
fn drop(&mut self) {
if let Some(socket) = self.socket.take() {
Session::finish(socket);
}
}
}

View File

@ -3,7 +3,7 @@ use managed::ManagedSlice;
#[cfg(feature = "socket-tcp")]
use crate::socket::TcpState;
use crate::socket::{AnySocket, Socket, SocketRef};
use crate::socket::{AnySocket, Socket};
/// An item of a socket set.
///
@ -84,10 +84,11 @@ impl<'a> Set<'a> {
/// # Panics
/// This function may panic if the handle does not belong to this socket set
/// or the socket has the wrong type.
pub fn get<T: AnySocket<'a>>(&mut self, handle: Handle) -> SocketRef<T> {
pub fn get<T: AnySocket<'a>>(&mut self, handle: Handle) -> &mut T {
match self.sockets[handle.0].as_mut() {
Some(item) => T::downcast(SocketRef::new(&mut item.socket))
.expect("handle refers to a socket of a wrong type"),
Some(item) => {
T::downcast(&mut item.socket).expect("handle refers to a socket of a wrong type")
}
None => panic!("handle does not refer to a valid socket"),
}
}
@ -179,7 +180,7 @@ impl<'a> Set<'a> {
}
}
/// Iterate every socket in this set, as SocketRef.
/// Iterate every socket in this set.
pub fn iter_mut<'d>(&'d mut self) -> IterMut<'d, 'a> {
IterMut {
lower: self.sockets.iter_mut(),
@ -217,12 +218,12 @@ pub struct IterMut<'a, 'b: 'a> {
}
impl<'a, 'b: 'a> Iterator for IterMut<'a, 'b> {
type Item = SocketRef<'a, Socket<'b>>;
type Item = &'a mut Socket<'b>;
fn next(&mut self) -> Option<Self::Item> {
for item_opt in &mut self.lower {
if let Some(item) = item_opt.as_mut() {
return Some(SocketRef::new(&mut item.socket));
return Some(&mut item.socket);
}
}
None