From f8cc1eacbe2c78c5bd03b17f796d7636df18ccc6 Mon Sep 17 00:00:00 2001 From: Dario Nieuwenhuis Date: Thu, 21 Oct 2021 03:49:54 +0200 Subject: [PATCH] socket: remove SocketRef. The intent was to run custom code after the user is done modifying the socket, for example to update a (not yet existing) port->socket map in SocketSet. However this wouldn't work, since the SocketRef would have to borrow the SocketSet at the same time as the Socket to be able to notify the SocketSet. I believe such indexing can be achieved by setting a "dirty" bit *before* giving the socket to the user, then on poll() reindexing all dirty sockets. This could even be faster: if user gets a socket multiple times between polls, it'd be reindexed only once. --- examples/benchmark.rs | 4 +- examples/client.rs | 4 +- examples/httpclient.rs | 2 +- examples/multicast.rs | 4 +- examples/ping.rs | 2 +- examples/server.rs | 10 ++--- examples/sixlowpan.rs | 2 +- src/iface/interface.rs | 30 +++++++-------- src/socket/mod.rs | 24 ++++-------- src/socket/ref_.rs | 87 ------------------------------------------ src/socket/set.rs | 15 ++++---- 11 files changed, 44 insertions(+), 140 deletions(-) delete mode 100644 src/socket/ref_.rs diff --git a/examples/benchmark.rs b/examples/benchmark.rs index a07b9a6..9ba3d29 100644 --- a/examples/benchmark.rs +++ b/examples/benchmark.rs @@ -120,7 +120,7 @@ fn main() { // tcp:1234: emit data { - let mut socket = iface.get_socket::(tcp1_handle); + let socket = iface.get_socket::(tcp1_handle); if !socket.is_open() { socket.listen(1234).unwrap(); } @@ -140,7 +140,7 @@ fn main() { // tcp:1235: sink data { - let mut socket = iface.get_socket::(tcp2_handle); + let socket = iface.get_socket::(tcp2_handle); if !socket.is_open() { socket.listen(1235).unwrap(); } diff --git a/examples/client.rs b/examples/client.rs index 5cfa501..f096174 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -55,7 +55,7 @@ fn main() { let tcp_handle = iface.add_socket(tcp_socket); { - let mut socket = iface.get_socket::(tcp_handle); + let socket = iface.get_socket::(tcp_handle); socket.connect((address, port), 49500).unwrap(); } @@ -70,7 +70,7 @@ fn main() { } { - let mut socket = iface.get_socket::(tcp_handle); + let socket = iface.get_socket::(tcp_handle); if socket.is_active() && !tcp_active { debug!("connected"); } else if !socket.is_active() && tcp_active { diff --git a/examples/httpclient.rs b/examples/httpclient.rs index 73f08c3..0ced70f 100644 --- a/examples/httpclient.rs +++ b/examples/httpclient.rs @@ -77,7 +77,7 @@ fn main() { } { - let mut socket = iface.get_socket::(tcp_handle); + let socket = iface.get_socket::(tcp_handle); state = match state { State::Connect if !socket.is_active() => { diff --git a/examples/multicast.rs b/examples/multicast.rs index 9bc90cd..e3f2fc6 100644 --- a/examples/multicast.rs +++ b/examples/multicast.rs @@ -78,7 +78,7 @@ fn main() { } { - let mut socket = iface.get_socket::(raw_handle); + let socket = iface.get_socket::(raw_handle); if socket.can_recv() { // For display purposes only - normally we wouldn't process incoming IGMP packets @@ -93,7 +93,7 @@ fn main() { } } { - let mut socket = iface.get_socket::(udp_handle); + let socket = iface.get_socket::(udp_handle); if !socket.is_open() { socket.bind(MDNS_PORT).unwrap() } diff --git a/examples/ping.rs b/examples/ping.rs index 4b14c14..b4e9038 100644 --- a/examples/ping.rs +++ b/examples/ping.rs @@ -157,7 +157,7 @@ fn main() { { let timestamp = Instant::now(); - let mut socket = iface.get_socket::(icmp_handle); + let socket = iface.get_socket::(icmp_handle); if !socket.is_open() { socket.bind(IcmpEndpoint::Ident(ident)).unwrap(); send_at = timestamp; diff --git a/examples/server.rs b/examples/server.rs index 7e1c32a..c2abb18 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -81,7 +81,7 @@ fn main() { // udp:6969: respond "hello" { - let mut socket = iface.get_socket::(udp_handle); + let socket = iface.get_socket::(udp_handle); if !socket.is_open() { socket.bind(6969).unwrap() } @@ -109,7 +109,7 @@ fn main() { // tcp:6969: respond "hello" { - let mut socket = iface.get_socket::(tcp1_handle); + let socket = iface.get_socket::(tcp1_handle); if !socket.is_open() { socket.listen(6969).unwrap(); } @@ -124,7 +124,7 @@ fn main() { // tcp:6970: echo with reverse { - let mut socket = iface.get_socket::(tcp2_handle); + let socket = iface.get_socket::(tcp2_handle); if !socket.is_open() { socket.listen(6970).unwrap() } @@ -168,7 +168,7 @@ fn main() { // tcp:6971: sinkhole { - let mut socket = iface.get_socket::(tcp3_handle); + let socket = iface.get_socket::(tcp3_handle); if !socket.is_open() { socket.listen(6971).unwrap(); socket.set_keep_alive(Some(Duration::from_millis(1000))); @@ -191,7 +191,7 @@ fn main() { // tcp:6972: fountain { - let mut socket = iface.get_socket::(tcp4_handle); + let socket = iface.get_socket::(tcp4_handle); if !socket.is_open() { socket.listen(6972).unwrap() } diff --git a/examples/sixlowpan.rs b/examples/sixlowpan.rs index 4773914..c653d66 100644 --- a/examples/sixlowpan.rs +++ b/examples/sixlowpan.rs @@ -101,7 +101,7 @@ fn main() { // udp:6969: respond "hello" { - let mut socket = iface.get_socket::(udp_handle); + let socket = iface.get_socket::(udp_handle); if !socket.is_open() { socket.bind(6969).unwrap() } diff --git a/src/iface/interface.rs b/src/iface/interface.rs index 7469e91..ac7c43b 100644 --- a/src/iface/interface.rs +++ b/src/iface/interface.rs @@ -485,7 +485,7 @@ where /// # Panics /// This function may panic if the handle does not belong to this socket set /// or the socket has the wrong type. - pub fn get_socket>(&mut self, handle: SocketHandle) -> SocketRef { + pub fn get_socket>(&mut self, handle: SocketHandle) -> &mut T { self.sockets.get(handle) } @@ -830,7 +830,7 @@ where let _caps = device.capabilities(); let mut emitted_any = false; - for mut socket in sockets.iter_mut() { + for socket in sockets.iter_mut() { if !socket .meta_mut() .egress_permitted(cx.now, |ip_addr| inner.has_neighbor(cx, &ip_addr)) @@ -1202,7 +1202,7 @@ impl<'a> InterfaceInner<'a> { // Look for UDP sockets that will accept the UDP packet. // If it does not accept the packet, then send an ICMP message. - for mut udp_socket in sockets.iter_mut().filter_map(UdpSocket::downcast) { + for udp_socket in sockets.iter_mut().filter_map(UdpSocket::downcast) { if !udp_socket.accepts(&IpRepr::Ipv6(ipv6_repr), &udp_repr) { continue; } @@ -1328,7 +1328,7 @@ impl<'a> InterfaceInner<'a> { let mut handled_by_raw_socket = false; // Pass every IP packet to all raw sockets we have registered. - for mut raw_socket in sockets.iter_mut().filter_map(RawSocket::downcast) { + for raw_socket in sockets.iter_mut().filter_map(RawSocket::downcast) { if !raw_socket.accepts(ip_repr) { continue; } @@ -1460,7 +1460,7 @@ impl<'a> InterfaceInner<'a> { if udp_packet.src_port() == DHCP_SERVER_PORT && udp_packet.dst_port() == DHCP_CLIENT_PORT { - if let Some(mut dhcp_socket) = + if let Some(dhcp_socket) = sockets.iter_mut().filter_map(Dhcpv4Socket::downcast).next() { let (src_addr, dst_addr) = (ip_repr.src_addr(), ip_repr.dst_addr()); @@ -1639,7 +1639,7 @@ impl<'a> InterfaceInner<'a> { let mut handled_by_icmp_socket = false; #[cfg(all(feature = "socket-icmp", feature = "proto-ipv6"))] - for mut icmp_socket in _sockets.iter_mut().filter_map(IcmpSocket::downcast) { + for icmp_socket in _sockets.iter_mut().filter_map(IcmpSocket::downcast) { if !icmp_socket.accepts(cx, &ip_repr, &icmp_repr.into()) { continue; } @@ -1825,7 +1825,7 @@ impl<'a> InterfaceInner<'a> { let mut handled_by_icmp_socket = false; #[cfg(all(feature = "socket-icmp", feature = "proto-ipv4"))] - for mut icmp_socket in _sockets.iter_mut().filter_map(IcmpSocket::downcast) { + for icmp_socket in _sockets.iter_mut().filter_map(IcmpSocket::downcast) { if !icmp_socket.accepts(cx, &ip_repr, &icmp_repr.into()) { continue; } @@ -1949,7 +1949,7 @@ impl<'a> InterfaceInner<'a> { let udp_repr = UdpRepr::parse(&udp_packet, &src_addr, &dst_addr, &cx.caps.checksum)?; let udp_payload = udp_packet.payload(); - for mut udp_socket in sockets.iter_mut().filter_map(UdpSocket::downcast) { + for udp_socket in sockets.iter_mut().filter_map(UdpSocket::downcast) { if !udp_socket.accepts(&ip_repr, &udp_repr) { continue; } @@ -2006,7 +2006,7 @@ impl<'a> InterfaceInner<'a> { let tcp_packet = TcpPacket::new_checked(ip_payload)?; let tcp_repr = TcpRepr::parse(&tcp_packet, &src_addr, &dst_addr, &cx.caps.checksum)?; - for mut tcp_socket in sockets.iter_mut().filter_map(TcpSocket::downcast) { + for tcp_socket in sockets.iter_mut().filter_map(TcpSocket::downcast) { if !tcp_socket.accepts(&ip_repr, &tcp_repr) { continue; } @@ -2944,7 +2944,7 @@ mod test { { // Bind the socket to port 68 - let mut socket = iface.get_socket::(socket_handle); + let socket = iface.get_socket::(socket_handle); assert_eq!(socket.bind(68), Ok(())); assert!(!socket.can_recv()); assert!(socket.can_send()); @@ -2971,7 +2971,7 @@ mod test { { // Make sure the payload to the UDP packet processed by process_udp is // appended to the bound sockets rx_buffer - let mut socket = iface.get_socket::(socket_handle); + let socket = iface.get_socket::(socket_handle); assert!(socket.can_recv()); assert_eq!( socket.recv(), @@ -3443,7 +3443,7 @@ mod test { let echo_data = &[0xff; 16]; { - let mut socket = iface.get_socket::(socket_handle); + let socket = iface.get_socket::(socket_handle); // Bind to the ID 0x1234 assert_eq!(socket.bind(IcmpEndpoint::Ident(ident)), Ok(())); } @@ -3494,7 +3494,7 @@ mod test { ); { - let mut socket = iface.get_socket::(socket_handle); + let socket = iface.get_socket::(socket_handle); assert!(socket.can_recv()); assert_eq!( socket.recv(), @@ -3856,7 +3856,7 @@ mod test { let udp_socket_handle = iface.add_socket(udp_socket); { // Bind the socket to port 68 - let mut socket = iface.get_socket::(udp_socket_handle); + let socket = iface.get_socket::(udp_socket_handle); assert_eq!(socket.bind(68), Ok(())); assert!(!socket.can_recv()); assert!(socket.can_send()); @@ -3929,7 +3929,7 @@ mod test { { // Make sure the UDP socket can still receive in presence of a Raw socket that handles UDP - let mut socket = iface.get_socket::(udp_socket_handle); + let socket = iface.get_socket::(udp_socket_handle); assert!(socket.can_recv()); assert_eq!( socket.recv(), diff --git a/src/socket/mod.rs b/src/socket/mod.rs index 9836c4b..30262b0 100644 --- a/src/socket/mod.rs +++ b/src/socket/mod.rs @@ -24,7 +24,6 @@ mod icmp; mod meta; #[cfg(feature = "socket-raw")] mod raw; -mod ref_; mod set; #[cfg(feature = "socket-tcp")] mod tcp; @@ -59,9 +58,6 @@ pub use self::dhcpv4::{Config as Dhcpv4Config, Dhcpv4Socket, Event as Dhcpv4Even pub use self::set::{Handle as SocketHandle, Item as SocketSetItem, Set as SocketSet}; pub use self::set::{Iter as SocketSetIter, IterMut as SocketSetIterMut}; -pub use self::ref_::Ref as SocketRef; -pub(crate) use self::ref_::Session as SocketSession; - /// Gives an indication on the next time the socket should be polled. #[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Clone, Copy)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] @@ -144,25 +140,19 @@ impl<'a> Socket<'a> { } } -impl<'a> SocketSession for Socket<'a> { - fn finish(&mut self) { - dispatch_socket!(mut self, |socket| socket.finish()) - } -} - /// A conversion trait for network sockets. -pub trait AnySocket<'a>: SocketSession + Sized { - fn downcast<'c>(socket_ref: SocketRef<'c, Socket<'a>>) -> Option>; +pub trait AnySocket<'a>: Sized { + fn downcast<'c>(socket: &'c mut Socket<'a>) -> Option<&'c mut Self>; } macro_rules! from_socket { ($socket:ty, $variant:ident) => { impl<'a> AnySocket<'a> for $socket { - fn downcast<'c>(ref_: SocketRef<'c, Socket<'a>>) -> Option> { - if let Socket::$variant(ref mut socket) = SocketRef::into_inner(ref_) { - Some(SocketRef::new(socket)) - } else { - None + fn downcast<'c>(socket: &'c mut Socket<'a>) -> Option<&'c mut Self> { + #[allow(unreachable_patterns)] + match socket { + Socket::$variant(socket) => Some(socket), + _ => None, } } } diff --git a/src/socket/ref_.rs b/src/socket/ref_.rs deleted file mode 100644 index 93992c9..0000000 --- a/src/socket/ref_.rs +++ /dev/null @@ -1,87 +0,0 @@ -use core::ops::{Deref, DerefMut}; - -/// A trait for tracking a socket usage session. -/// -/// Allows implementation of custom drop logic that runs only if the socket was changed -/// in specific ways. For example, drop logic for UDP would check if the local endpoint -/// has changed, and if yes, notify the socket set. -#[doc(hidden)] -pub trait Session { - fn finish(&mut self) {} -} - -#[cfg(feature = "socket-raw")] -impl<'a> Session for crate::socket::RawSocket<'a> {} -#[cfg(all( - feature = "socket-icmp", - any(feature = "proto-ipv4", feature = "proto-ipv6") -))] -impl<'a> Session for crate::socket::IcmpSocket<'a> {} -#[cfg(feature = "socket-udp")] -impl<'a> Session for crate::socket::UdpSocket<'a> {} -#[cfg(feature = "socket-tcp")] -impl<'a> Session for crate::socket::TcpSocket<'a> {} -#[cfg(feature = "socket-dhcpv4")] -impl Session for crate::socket::Dhcpv4Socket {} - -/// A smart pointer to a socket. -/// -/// Allows the network stack to efficiently determine if the socket state was changed in any way. -pub struct Ref<'a, T: Session + 'a> { - /// Reference to the socket. - /// - /// This is almost always `Some` except when dropped in `into_inner` which removes the socket - /// reference. This properly tracks the initialization state without any additional bytes as - /// the `None` variant occupies the `0` pattern which is invalid for the reference. - socket: Option<&'a mut T>, -} - -impl<'a, T: Session + 'a> Ref<'a, T> { - /// Wrap a pointer to a socket to make a smart pointer. - /// - /// Calling this function is only necessary if your code is using [into_inner]. - /// - /// [into_inner]: #method.into_inner - pub fn new(socket: &'a mut T) -> Self { - Ref { - socket: Some(socket), - } - } - - /// Unwrap a smart pointer to a socket. - /// - /// The finalization code is not run. Prompt operation of the network stack depends - /// on wrapping the returned pointer back and dropping it. - /// - /// Calling this function is only necessary to achieve composability if you *must* - /// map a `&mut SocketRef<'a, XSocket>` to a `&'a mut XSocket` (note the lifetimes); - /// be sure to call [new] afterwards. - /// - /// [new]: #method.new - pub fn into_inner(mut ref_: Self) -> &'a mut T { - ref_.socket.take().unwrap() - } -} - -impl<'a, T: Session> Deref for Ref<'a, T> { - type Target = T; - - fn deref(&self) -> &Self::Target { - // Deref is only used while the socket is still in place (into inner has not been called). - self.socket.as_ref().unwrap() - } -} - -impl<'a, T: Session> DerefMut for Ref<'a, T> { - fn deref_mut(&mut self) -> &mut Self::Target { - self.socket.as_mut().unwrap() - } -} - -impl<'a, T: Session> Drop for Ref<'a, T> { - fn drop(&mut self) { - if let Some(socket) = self.socket.take() { - Session::finish(socket); - } - } -} diff --git a/src/socket/set.rs b/src/socket/set.rs index 4c37efa..cb96dfe 100644 --- a/src/socket/set.rs +++ b/src/socket/set.rs @@ -3,7 +3,7 @@ use managed::ManagedSlice; #[cfg(feature = "socket-tcp")] use crate::socket::TcpState; -use crate::socket::{AnySocket, Socket, SocketRef}; +use crate::socket::{AnySocket, Socket}; /// An item of a socket set. /// @@ -84,10 +84,11 @@ impl<'a> Set<'a> { /// # Panics /// This function may panic if the handle does not belong to this socket set /// or the socket has the wrong type. - pub fn get>(&mut self, handle: Handle) -> SocketRef { + pub fn get>(&mut self, handle: Handle) -> &mut T { match self.sockets[handle.0].as_mut() { - Some(item) => T::downcast(SocketRef::new(&mut item.socket)) - .expect("handle refers to a socket of a wrong type"), + Some(item) => { + T::downcast(&mut item.socket).expect("handle refers to a socket of a wrong type") + } None => panic!("handle does not refer to a valid socket"), } } @@ -179,7 +180,7 @@ impl<'a> Set<'a> { } } - /// Iterate every socket in this set, as SocketRef. + /// Iterate every socket in this set. pub fn iter_mut<'d>(&'d mut self) -> IterMut<'d, 'a> { IterMut { lower: self.sockets.iter_mut(), @@ -217,12 +218,12 @@ pub struct IterMut<'a, 'b: 'a> { } impl<'a, 'b: 'a> Iterator for IterMut<'a, 'b> { - type Item = SocketRef<'a, Socket<'b>>; + type Item = &'a mut Socket<'b>; fn next(&mut self) -> Option { for item_opt in &mut self.lower { if let Some(item) = item_opt.as_mut() { - return Some(SocketRef::new(&mut item.socket)); + return Some(&mut item.socket); } } None