diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 91eccf2..dafbe24 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -43,8 +43,6 @@ jobs: # Test alloc feature which requires nightly. - rust: nightly features: alloc medium-ethernet proto-ipv4 proto-ipv6 socket-raw socket-udp socket-tcp socket-icmp - - rust: nightly - features: alloc proto-ipv4 proto-ipv6 socket-raw socket-udp socket-tcp socket-icmp steps: - uses: actions/checkout@v2 - uses: actions-rs/toolchain@v1 diff --git a/src/iface/interface.rs b/src/iface/interface.rs index eb6e186..f4be495 100644 --- a/src/iface/interface.rs +++ b/src/iface/interface.rs @@ -45,7 +45,6 @@ struct InterfaceInner<'a> { /// When to report for (all or) the next multicast group membership via IGMP #[cfg(feature = "proto-igmp")] igmp_report_state: IgmpReportState, - device_capabilities: DeviceCapabilities, } /// A builder structure used for creating a network interface. @@ -230,7 +229,6 @@ let iface = InterfaceBuilder::new(device) #[cfg(feature = "proto-ipv4")] any_ip: self.any_ip, routes: self.routes, - device_capabilities, #[cfg(feature = "medium-ethernet")] neighbor_cache, #[cfg(feature = "proto-igmp")] @@ -420,6 +418,7 @@ impl<'a, DeviceT> Interface<'a, DeviceT> /// Returns `Ok(announce_sent)` if the address was added successfully, where `annouce_sent` /// indicates whether an initial immediate announcement has been sent. pub fn join_multicast_group>(&mut self, addr: T, _timestamp: Instant) -> Result { + match addr.into() { #[cfg(feature = "proto-igmp")] IpAddress::Ipv4(addr) => { @@ -430,9 +429,10 @@ impl<'a, DeviceT> Interface<'a, DeviceT> Ok(false) } else if let Some(pkt) = self.inner.igmp_report_packet(IgmpVersion::Version2, addr) { + let cx = self.context(_timestamp); // Send initial membership report let tx_token = self.device.transmit().ok_or(Error::Exhausted)?; - self.inner.dispatch_ip(tx_token, _timestamp, pkt)?; + self.inner.dispatch_ip(&cx, tx_token, pkt)?; Ok(true) } else { Ok(false) @@ -456,9 +456,10 @@ impl<'a, DeviceT> Interface<'a, DeviceT> if was_not_present { Ok(false) } else if let Some(pkt) = self.inner.igmp_leave_packet(addr) { + let cx = self.context(_timestamp); // Send group leave packet let tx_token = self.device.transmit().ok_or(Error::Exhausted)?; - self.inner.dispatch_ip(tx_token, _timestamp, pkt)?; + self.inner.dispatch_ip(&cx, tx_token, pkt)?; Ok(true) } else { Ok(false) @@ -535,13 +536,15 @@ impl<'a, DeviceT> Interface<'a, DeviceT> /// a very common occurrence and on a production system it should not even /// be logged. pub fn poll(&mut self, sockets: &mut SocketSet, timestamp: Instant) -> Result { + let cx = self.context(timestamp); + let mut readiness_may_have_changed = false; loop { - let processed_any = self.socket_ingress(sockets, timestamp); - let emitted_any = self.socket_egress(sockets, timestamp)?; + let processed_any = self.socket_ingress(&cx, sockets); + let emitted_any = self.socket_egress(&cx, sockets)?; #[cfg(feature = "proto-igmp")] - self.igmp_egress(timestamp)?; + self.igmp_egress(&cx, timestamp)?; if processed_any || emitted_any { readiness_may_have_changed = true; @@ -561,10 +564,12 @@ impl<'a, DeviceT> Interface<'a, DeviceT> /// [poll]: #method.poll /// [Instant]: struct.Instant.html pub fn poll_at(&self, sockets: &SocketSet, timestamp: Instant) -> Option { + let cx = self.context(timestamp); + sockets.iter().filter_map(|socket| { - let socket_poll_at = socket.poll_at(); + let socket_poll_at = socket.poll_at(&cx); match socket.meta().poll_at(socket_poll_at, |ip_addr| - self.inner.has_neighbor(&ip_addr, timestamp)) { + self.inner.has_neighbor(&cx, &ip_addr)) { PollAt::Ingress => None, PollAt::Time(instant) => Some(instant), PollAt::Now => Some(Instant::from_millis(0)), @@ -592,19 +597,19 @@ impl<'a, DeviceT> Interface<'a, DeviceT> } } - fn socket_ingress(&mut self, sockets: &mut SocketSet, timestamp: Instant) -> bool { + fn socket_ingress(&mut self, cx: &Context, sockets: &mut SocketSet) -> bool { let mut processed_any = false; let &mut Self { ref mut device, ref mut inner } = self; while let Some((rx_token, tx_token)) = device.receive() { - if let Err(err) = rx_token.consume(timestamp, |frame| { - match inner.device_capabilities.medium { + if let Err(err) = rx_token.consume(cx.now, |frame| { + match cx.caps.medium { #[cfg(feature = "medium-ethernet")] Medium::Ethernet => { - match inner.process_ethernet(sockets, timestamp, &frame) { + match inner.process_ethernet(cx, sockets, &frame) { Ok(response) => { processed_any = true; if let Some(packet) = response { - if let Err(err) = inner.dispatch(tx_token, timestamp, packet) { + if let Err(err) = inner.dispatch(cx, tx_token, packet) { net_debug!("Failed to send response: {}", err); } } @@ -619,11 +624,11 @@ impl<'a, DeviceT> Interface<'a, DeviceT> } #[cfg(feature = "medium-ip")] Medium::Ip => { - match inner.process_ip(sockets, timestamp, &frame) { + match inner.process_ip(cx, sockets, &frame) { Ok(response) => { processed_any = true; if let Some(packet) = response { - if let Err(err) = inner.dispatch_ip(tx_token, timestamp, packet) { + if let Err(err) = inner.dispatch_ip(cx, tx_token, packet) { net_debug!("Failed to send response: {}", err); } } @@ -642,13 +647,13 @@ impl<'a, DeviceT> Interface<'a, DeviceT> processed_any } - fn socket_egress(&mut self, sockets: &mut SocketSet, timestamp: Instant) -> Result { + fn socket_egress(&mut self, cx: &Context, sockets: &mut SocketSet) -> Result { let _caps = self.device.capabilities(); let mut emitted_any = false; for mut socket in sockets.iter_mut() { - if !socket.meta_mut().egress_permitted(timestamp, |ip_addr| - self.inner.has_neighbor(&ip_addr, timestamp)) { + if !socket.meta_mut().egress_permitted(cx.now, |ip_addr| + self.inner.has_neighbor(cx, &ip_addr)) { continue } @@ -661,28 +666,21 @@ impl<'a, DeviceT> Interface<'a, DeviceT> let response = $response; neighbor_addr = Some(response.ip_repr().dst_addr()); let tx_token = device.transmit().ok_or(Error::Exhausted)?; - device_result = inner.dispatch_ip(tx_token, timestamp, response); + device_result = inner.dispatch_ip(cx, tx_token, response); device_result }) } - let _ip_mtu = match _caps.medium { - #[cfg(feature = "medium-ethernet")] - Medium::Ethernet => _caps.max_transmission_unit - EthernetFrame::<&[u8]>::header_len(), - #[cfg(feature = "medium-ip")] - Medium::Ip => _caps.max_transmission_unit, - }; - let socket_result = match *socket { #[cfg(feature = "socket-raw")] Socket::Raw(ref mut socket) => - socket.dispatch(&_caps.checksum, |response| + socket.dispatch(cx, |response| respond!(IpPacket::Raw(response))), #[cfg(all(feature = "socket-icmp", any(feature = "proto-ipv4", feature = "proto-ipv6")))] Socket::Icmp(ref mut socket) => - socket.dispatch(|response| { + socket.dispatch(cx, |response| { match response { #[cfg(feature = "proto-ipv4")] (IpRepr::Ipv4(ipv4_repr), IcmpRepr::Ipv4(icmpv4_repr)) => @@ -695,17 +693,17 @@ impl<'a, DeviceT> Interface<'a, DeviceT> }), #[cfg(feature = "socket-udp")] Socket::Udp(ref mut socket) => - socket.dispatch(|response| + socket.dispatch(cx, |response| respond!(IpPacket::Udp(response))), #[cfg(feature = "socket-tcp")] Socket::Tcp(ref mut socket) => { - socket.dispatch(timestamp, _ip_mtu, |response| + socket.dispatch(cx, |response| respond!(IpPacket::Tcp(response))) } #[cfg(feature = "socket-dhcpv4")] Socket::Dhcpv4(ref mut socket) => // todo don't unwrap - socket.dispatch(timestamp, inner.ethernet_addr.unwrap(), _ip_mtu, |response| + socket.dispatch(cx, |response| respond!(IpPacket::Dhcpv4(response))), }; @@ -717,7 +715,7 @@ impl<'a, DeviceT> Interface<'a, DeviceT> // requests from the socket. However, without an additional rate limiting // mechanism, we would spin on every socket that has yet to discover its // neighboor. - socket.meta_mut().neighbor_missing(timestamp, + socket.meta_mut().neighbor_missing(cx.now, neighbor_addr.expect("non-IP response packet")); break } @@ -735,14 +733,14 @@ impl<'a, DeviceT> Interface<'a, DeviceT> /// Depending on `igmp_report_state` and the therein contained /// timeouts, send IGMP membership reports. #[cfg(feature = "proto-igmp")] - fn igmp_egress(&mut self, timestamp: Instant) -> Result { + fn igmp_egress(&mut self, cx: &Context, timestamp: Instant) -> Result { match self.inner.igmp_report_state { IgmpReportState::ToSpecificQuery { version, timeout, group } if timestamp >= timeout => { if let Some(pkt) = self.inner.igmp_report_packet(version, group) { // Send initial membership report let tx_token = self.device.transmit().ok_or(Error::Exhausted)?; - self.inner.dispatch_ip(tx_token, timestamp, pkt)?; + self.inner.dispatch_ip(cx, tx_token, pkt)?; } self.inner.igmp_report_state = IgmpReportState::Inactive; @@ -760,7 +758,7 @@ impl<'a, DeviceT> Interface<'a, DeviceT> if let Some(pkt) = self.inner.igmp_report_packet(version, addr) { // Send initial membership report let tx_token = self.device.transmit().ok_or(Error::Exhausted)?; - self.inner.dispatch_ip(tx_token, timestamp, pkt)?; + self.inner.dispatch_ip(cx, tx_token, pkt)?; } let next_timeout = (timeout + interval).max(timestamp); @@ -779,6 +777,15 @@ impl<'a, DeviceT> Interface<'a, DeviceT> _ => Ok(false) } } + + fn context(&self, now: Instant) -> Context { + Context { + now, + caps: self.device.capabilities(), + #[cfg(feature = "medium-ethernet")] + ethernet_address: self.inner.ethernet_addr, + } + } } impl<'a> InterfaceInner<'a> { @@ -852,7 +859,7 @@ impl<'a> InterfaceInner<'a> { #[cfg(feature = "medium-ethernet")] fn process_ethernet<'frame, T: AsRef<[u8]>> - (&mut self, sockets: &mut SocketSet, timestamp: Instant, frame: &'frame T) -> + (&mut self, cx: &Context, sockets: &mut SocketSet, frame: &'frame T) -> Result>> { let eth_frame = EthernetFrame::new_checked(frame)?; @@ -868,7 +875,7 @@ impl<'a> InterfaceInner<'a> { match eth_frame.ethertype() { #[cfg(feature = "proto-ipv4")] EthernetProtocol::Arp => - self.process_arp(timestamp, ð_frame), + self.process_arp(cx.now, ð_frame), #[cfg(feature = "proto-ipv4")] EthernetProtocol::Ipv4 => { let ipv4_packet = Ipv4Packet::new_checked(eth_frame.payload())?; @@ -876,11 +883,11 @@ impl<'a> InterfaceInner<'a> { // Fill the neighbor cache from IP header of unicast frames. let ip_addr = IpAddress::Ipv4(ipv4_packet.src_addr()); if self.in_same_network(&ip_addr) { - self.neighbor_cache.as_mut().unwrap().fill(ip_addr, eth_frame.src_addr(), timestamp); + self.neighbor_cache.as_mut().unwrap().fill(ip_addr, eth_frame.src_addr(), cx.now); } } - self.process_ipv4(sockets, timestamp, &ipv4_packet).map(|o| o.map(EthernetPacket::Ip)) + self.process_ipv4(cx, sockets, &ipv4_packet).map(|o| o.map(EthernetPacket::Ip)) } #[cfg(feature = "proto-ipv6")] EthernetProtocol::Ipv6 => { @@ -889,12 +896,12 @@ impl<'a> InterfaceInner<'a> { // Fill the neighbor cache from IP header of unicast frames. let ip_addr = IpAddress::Ipv6(ipv6_packet.src_addr()); if self.in_same_network(&ip_addr) && - self.neighbor_cache.as_mut().unwrap().lookup(&ip_addr, timestamp).found() { - self.neighbor_cache.as_mut().unwrap().fill(ip_addr, eth_frame.src_addr(), timestamp); + self.neighbor_cache.as_mut().unwrap().lookup(&ip_addr, cx.now).found() { + self.neighbor_cache.as_mut().unwrap().fill(ip_addr, eth_frame.src_addr(), cx.now); } } - self.process_ipv6(sockets, timestamp, &ipv6_packet).map(|o| o.map(EthernetPacket::Ip)) + self.process_ipv6(cx, sockets, &ipv6_packet).map(|o| o.map(EthernetPacket::Ip)) } // Drop all other traffic. _ => Err(Error::Unrecognized), @@ -903,19 +910,19 @@ impl<'a> InterfaceInner<'a> { #[cfg(feature = "medium-ip")] fn process_ip<'frame, T: AsRef<[u8]>> - (&mut self, sockets: &mut SocketSet, timestamp: Instant, ip_payload: &'frame T) -> + (&mut self, cx: &Context, sockets: &mut SocketSet, ip_payload: &'frame T) -> Result>> { match IpVersion::of_packet(ip_payload.as_ref()) { #[cfg(feature = "proto-ipv4")] Ok(IpVersion::Ipv4) => { let ipv4_packet = Ipv4Packet::new_checked(ip_payload)?; - self.process_ipv4(sockets, timestamp, &ipv4_packet) + self.process_ipv4(cx, sockets, &ipv4_packet) } #[cfg(feature = "proto-ipv6")] Ok(IpVersion::Ipv6) => { let ipv6_packet = Ipv6Packet::new_checked(ip_payload)?; - self.process_ipv6(sockets, timestamp, &ipv6_packet) + self.process_ipv6(cx, sockets, &ipv6_packet) } // Drop all other traffic. _ => Err(Error::Unrecognized), @@ -963,16 +970,15 @@ impl<'a> InterfaceInner<'a> { } #[cfg(all(any(feature = "proto-ipv4", feature = "proto-ipv6"), feature = "socket-raw"))] - fn raw_socket_filter<'frame>(&mut self, sockets: &mut SocketSet, ip_repr: &IpRepr, + fn raw_socket_filter<'frame>(&mut self, cx: &Context, sockets: &mut SocketSet, ip_repr: &IpRepr, ip_payload: &'frame [u8]) -> bool { - let checksum_caps = self.device_capabilities.checksum.clone(); let mut handled_by_raw_socket = false; // Pass every IP packet to all raw sockets we have registered. for mut raw_socket in sockets.iter_mut().filter_map(RawSocket::downcast) { if !raw_socket.accepts(&ip_repr) { continue } - match raw_socket.process(&ip_repr, ip_payload, &checksum_caps) { + match raw_socket.process(cx, &ip_repr, ip_payload) { // The packet is valid and handled by socket. Ok(()) => handled_by_raw_socket = true, // The socket buffer is full or the packet was truncated @@ -986,7 +992,7 @@ impl<'a> InterfaceInner<'a> { #[cfg(feature = "proto-ipv6")] fn process_ipv6<'frame, T: AsRef<[u8]> + ?Sized> - (&mut self, sockets: &mut SocketSet, timestamp: Instant, + (&mut self, cx: &Context, sockets: &mut SocketSet, ipv6_packet: &Ipv6Packet<&'frame T>) -> Result>> { let ipv6_repr = Ipv6Repr::parse(&ipv6_packet)?; @@ -1000,11 +1006,11 @@ impl<'a> InterfaceInner<'a> { let ip_payload = ipv6_packet.payload(); #[cfg(feature = "socket-raw")] - let handled_by_raw_socket = self.raw_socket_filter(sockets, &ipv6_repr.into(), ip_payload); + let handled_by_raw_socket = self.raw_socket_filter(cx, sockets, &ipv6_repr.into(), ip_payload); #[cfg(not(feature = "socket-raw"))] let handled_by_raw_socket = false; - self.process_nxt_hdr(sockets, timestamp, ipv6_repr, ipv6_repr.next_header, + self.process_nxt_hdr(cx, sockets, ipv6_repr, ipv6_repr.next_header, handled_by_raw_socket, ip_payload) } @@ -1012,24 +1018,24 @@ impl<'a> InterfaceInner<'a> { /// function. #[cfg(feature = "proto-ipv6")] fn process_nxt_hdr<'frame> - (&mut self, sockets: &mut SocketSet, timestamp: Instant, ipv6_repr: Ipv6Repr, + (&mut self, cx: &Context, sockets: &mut SocketSet, ipv6_repr: Ipv6Repr, nxt_hdr: IpProtocol, handled_by_raw_socket: bool, ip_payload: &'frame [u8]) -> Result>> { match nxt_hdr { IpProtocol::Icmpv6 => - self.process_icmpv6(sockets, timestamp, ipv6_repr.into(), ip_payload), + self.process_icmpv6(cx, sockets, ipv6_repr.into(), ip_payload), #[cfg(feature = "socket-udp")] IpProtocol::Udp => - self.process_udp(sockets, ipv6_repr.into(), handled_by_raw_socket, ip_payload), + self.process_udp(cx, sockets, ipv6_repr.into(), handled_by_raw_socket, ip_payload), #[cfg(feature = "socket-tcp")] IpProtocol::Tcp => - self.process_tcp(sockets, timestamp, ipv6_repr.into(), ip_payload), + self.process_tcp(cx, sockets, ipv6_repr.into(), ip_payload), IpProtocol::HopByHop => - self.process_hopbyhop(sockets, timestamp, ipv6_repr, handled_by_raw_socket, ip_payload), + self.process_hopbyhop(cx, sockets, ipv6_repr, handled_by_raw_socket, ip_payload), #[cfg(feature = "socket-raw")] _ if handled_by_raw_socket => @@ -1053,12 +1059,11 @@ impl<'a> InterfaceInner<'a> { #[cfg(feature = "proto-ipv4")] fn process_ipv4<'frame, T: AsRef<[u8]> + ?Sized> - (&mut self, sockets: &mut SocketSet, timestamp: Instant, + (&mut self, cx: &Context, sockets: &mut SocketSet, ipv4_packet: &Ipv4Packet<&'frame T>) -> Result>> { - let checksum_caps = self.device_capabilities.checksum.clone(); - let ipv4_repr = Ipv4Repr::parse(&ipv4_packet, &checksum_caps)?; + let ipv4_repr = Ipv4Repr::parse(&ipv4_packet, &cx.caps.checksum)?; if !self.is_unicast_v4(ipv4_repr.src_addr) { // Discard packets with non-unicast source addresses. @@ -1070,7 +1075,7 @@ impl<'a> InterfaceInner<'a> { let ip_payload = ipv4_packet.payload(); #[cfg(feature = "socket-raw")] - let handled_by_raw_socket = self.raw_socket_filter(sockets, &ip_repr, ip_payload); + let handled_by_raw_socket = self.raw_socket_filter(cx, sockets, &ip_repr, ip_payload); #[cfg(not(feature = "socket-raw"))] let handled_by_raw_socket = false; @@ -1084,14 +1089,10 @@ impl<'a> InterfaceInner<'a> { if udp_packet.src_port() == DHCP_SERVER_PORT && udp_packet.dst_port() == DHCP_CLIENT_PORT { if let Some(mut dhcp_socket) = sockets.iter_mut().filter_map(Dhcpv4Socket::downcast).next() { let (src_addr, dst_addr) = (ip_repr.src_addr(), ip_repr.dst_addr()); - let checksum_caps = self.device_capabilities.checksum.clone(); - let udp_repr = UdpRepr::parse(&udp_packet, &src_addr, &dst_addr, &checksum_caps)?; + let udp_repr = UdpRepr::parse(&udp_packet, &src_addr, &dst_addr, &cx.caps.checksum)?; let udp_payload = udp_packet.payload(); - // NOTE(unwrap): we checked for is_some above. - let ethernet_addr = self.ethernet_addr.unwrap(); - - match dhcp_socket.process(timestamp, ethernet_addr, &ipv4_repr, &udp_repr, udp_payload) { + match dhcp_socket.process(cx, &ipv4_repr, &udp_repr, udp_payload) { // The packet is valid and handled by socket. Ok(()) => return Ok(None), // The packet is malformed, or the socket buffer is full. @@ -1109,7 +1110,7 @@ impl<'a> InterfaceInner<'a> { // Ignore IP packets not directed at us, or broadcast, or any of the multicast groups. // If AnyIP is enabled, also check if the packet is routed locally. if !self.any_ip || - self.routes.lookup(&IpAddress::Ipv4(ipv4_repr.dst_addr), timestamp) + self.routes.lookup(&IpAddress::Ipv4(ipv4_repr.dst_addr), cx.now) .map_or(true, |router_addr| !self.has_ip_addr(router_addr)) { return Ok(None); } @@ -1117,19 +1118,19 @@ impl<'a> InterfaceInner<'a> { match ipv4_repr.protocol { IpProtocol::Icmp => - self.process_icmpv4(sockets, ip_repr, ip_payload), + self.process_icmpv4(cx, sockets, ip_repr, ip_payload), #[cfg(feature = "proto-igmp")] IpProtocol::Igmp => - self.process_igmp(timestamp, ipv4_repr, ip_payload), + self.process_igmp(cx, ipv4_repr, ip_payload), #[cfg(feature = "socket-udp")] IpProtocol::Udp => - self.process_udp(sockets, ip_repr, handled_by_raw_socket, ip_payload), + self.process_udp(cx, sockets, ip_repr, handled_by_raw_socket, ip_payload), #[cfg(feature = "socket-tcp")] IpProtocol::Tcp => - self.process_tcp(sockets, timestamp, ip_repr, ip_payload), + self.process_tcp(cx, sockets, ip_repr, ip_payload), _ if handled_by_raw_socket => Ok(None), @@ -1179,7 +1180,7 @@ impl<'a> InterfaceInner<'a> { /// Membership must not be reported immediately in order to avoid flooding the network /// after a query is broadcasted by a router; this is not currently done. #[cfg(feature = "proto-igmp")] - fn process_igmp<'frame>(&mut self, timestamp: Instant, ipv4_repr: Ipv4Repr, + fn process_igmp<'frame>(&mut self, cx: &Context, ipv4_repr: Ipv4Repr, ip_payload: &'frame [u8]) -> Result>> { let igmp_packet = IgmpPacket::new_checked(ip_payload)?; let igmp_repr = IgmpRepr::parse(&igmp_packet)?; @@ -1204,7 +1205,7 @@ impl<'a> InterfaceInner<'a> { } }; self.igmp_report_state = IgmpReportState::ToGeneralQuery { - version, timeout: timestamp + interval, interval, next_index: 0 + version, timeout: cx.now + interval, interval, next_index: 0 }; } } else { @@ -1213,7 +1214,7 @@ impl<'a> InterfaceInner<'a> { // Don't respond immediately let timeout = max_resp_time / 4; self.igmp_report_state = IgmpReportState::ToSpecificQuery { - version, timeout: timestamp + timeout, group: group_addr + version, timeout: cx.now + timeout, group: group_addr }; } } @@ -1228,22 +1229,21 @@ impl<'a> InterfaceInner<'a> { } #[cfg(feature = "proto-ipv6")] - fn process_icmpv6<'frame>(&mut self, _sockets: &mut SocketSet, _timestamp: Instant, + fn process_icmpv6<'frame>(&mut self, cx: &Context, _sockets: &mut SocketSet, ip_repr: IpRepr, ip_payload: &'frame [u8]) -> Result>> { let icmp_packet = Icmpv6Packet::new_checked(ip_payload)?; - let checksum_caps = self.device_capabilities.checksum.clone(); let icmp_repr = Icmpv6Repr::parse(&ip_repr.src_addr(), &ip_repr.dst_addr(), - &icmp_packet, &checksum_caps)?; + &icmp_packet, &cx.caps.checksum)?; #[cfg(feature = "socket-icmp")] let mut handled_by_icmp_socket = false; #[cfg(all(feature = "socket-icmp", feature = "proto-ipv6"))] for mut icmp_socket in _sockets.iter_mut().filter_map(IcmpSocket::downcast) { - if !icmp_socket.accepts(&ip_repr, &icmp_repr.into(), &checksum_caps) { continue } + if !icmp_socket.accepts(cx, &ip_repr, &icmp_repr.into()) { continue } - match icmp_socket.process(&ip_repr, &icmp_repr.into(), &checksum_caps) { + match icmp_socket.process(cx, &ip_repr, &icmp_repr.into()) { // The packet is valid and handled by socket. Ok(()) => handled_by_icmp_socket = true, // The socket buffer is full. @@ -1271,7 +1271,7 @@ impl<'a> InterfaceInner<'a> { // Forward any NDISC packets to the ndisc packet handler #[cfg(feature = "medium-ethernet")] Icmpv6Repr::Ndisc(repr) if ip_repr.hop_limit() == 0xff => match ip_repr { - IpRepr::Ipv6(ipv6_repr) => self.process_ndisc(_timestamp, ipv6_repr, repr), + IpRepr::Ipv6(ipv6_repr) => self.process_ndisc(cx.now, ipv6_repr, repr), _ => Ok(None) }, @@ -1332,7 +1332,7 @@ impl<'a> InterfaceInner<'a> { } #[cfg(feature = "proto-ipv6")] - fn process_hopbyhop<'frame>(&mut self, sockets: &mut SocketSet, timestamp: Instant, + fn process_hopbyhop<'frame>(&mut self, cx: &Context, sockets: &mut SocketSet, ipv6_repr: Ipv6Repr, handled_by_raw_socket: bool, ip_payload: &'frame [u8]) -> Result>> { @@ -1357,26 +1357,25 @@ impl<'a> InterfaceInner<'a> { } } } - self.process_nxt_hdr(sockets, timestamp, ipv6_repr, hbh_repr.next_header, + self.process_nxt_hdr(cx, sockets, ipv6_repr, hbh_repr.next_header, handled_by_raw_socket, &ip_payload[hbh_repr.buffer_len()..]) } #[cfg(feature = "proto-ipv4")] - fn process_icmpv4<'frame>(&self, _sockets: &mut SocketSet, ip_repr: IpRepr, + fn process_icmpv4<'frame>(&self, cx: &Context, _sockets: &mut SocketSet, ip_repr: IpRepr, ip_payload: &'frame [u8]) -> Result>> { let icmp_packet = Icmpv4Packet::new_checked(ip_payload)?; - let checksum_caps = self.device_capabilities.checksum.clone(); - let icmp_repr = Icmpv4Repr::parse(&icmp_packet, &checksum_caps)?; + let icmp_repr = Icmpv4Repr::parse(&icmp_packet, &cx.caps.checksum)?; #[cfg(feature = "socket-icmp")] let mut handled_by_icmp_socket = false; #[cfg(all(feature = "socket-icmp", feature = "proto-ipv4"))] for mut icmp_socket in _sockets.iter_mut().filter_map(IcmpSocket::downcast) { - if !icmp_socket.accepts(&ip_repr, &icmp_repr.into(), &checksum_caps) { continue } + if !icmp_socket.accepts(cx, &ip_repr, &icmp_repr.into()) { continue } - match icmp_socket.process(&ip_repr, &icmp_repr.into(), &checksum_caps) { + match icmp_socket.process(cx, &ip_repr, &icmp_repr.into()) { // The packet is valid and handled by socket. Ok(()) => handled_by_icmp_socket = true, // The socket buffer is full. @@ -1472,20 +1471,19 @@ impl<'a> InterfaceInner<'a> { } #[cfg(feature = "socket-udp")] - fn process_udp<'frame>(&self, sockets: &mut SocketSet, + fn process_udp<'frame>(&self, cx: &Context, sockets: &mut SocketSet, ip_repr: IpRepr, handled_by_raw_socket: bool, ip_payload: &'frame [u8]) -> Result>> { let (src_addr, dst_addr) = (ip_repr.src_addr(), ip_repr.dst_addr()); let udp_packet = UdpPacket::new_checked(ip_payload)?; - let checksum_caps = self.device_capabilities.checksum.clone(); - let udp_repr = UdpRepr::parse(&udp_packet, &src_addr, &dst_addr, &checksum_caps)?; + let udp_repr = UdpRepr::parse(&udp_packet, &src_addr, &dst_addr, &cx.caps.checksum)?; let udp_payload = udp_packet.payload(); for mut udp_socket in sockets.iter_mut().filter_map(UdpSocket::downcast) { if !udp_socket.accepts(&ip_repr, &udp_repr) { continue } - match udp_socket.process(&ip_repr, &udp_repr, udp_payload) { + match udp_socket.process(cx, &ip_repr, &udp_repr, udp_payload) { // The packet is valid and handled by socket. Ok(()) => return Ok(None), // The packet is malformed, or the socket buffer is full. @@ -1528,19 +1526,18 @@ impl<'a> InterfaceInner<'a> { } #[cfg(feature = "socket-tcp")] - fn process_tcp<'frame>(&self, sockets: &mut SocketSet, timestamp: Instant, + fn process_tcp<'frame>(&self, cx: &Context, sockets: &mut SocketSet, ip_repr: IpRepr, ip_payload: &'frame [u8]) -> Result>> { let (src_addr, dst_addr) = (ip_repr.src_addr(), ip_repr.dst_addr()); let tcp_packet = TcpPacket::new_checked(ip_payload)?; - let checksum_caps = self.device_capabilities.checksum.clone(); - let tcp_repr = TcpRepr::parse(&tcp_packet, &src_addr, &dst_addr, &checksum_caps)?; + let tcp_repr = TcpRepr::parse(&tcp_packet, &src_addr, &dst_addr, &cx.caps.checksum)?; for mut tcp_socket in sockets.iter_mut().filter_map(TcpSocket::downcast) { if !tcp_socket.accepts(&ip_repr, &tcp_repr) { continue } - match tcp_socket.process(timestamp, &ip_repr, &tcp_repr) { + match tcp_socket.process(cx, &ip_repr, &tcp_repr) { // The packet is valid and handled by socket. Ok(reply) => return Ok(reply.map(IpPacket::Tcp)), // The packet is malformed, or doesn't match the socket state, @@ -1559,7 +1556,7 @@ impl<'a> InterfaceInner<'a> { } #[cfg(feature = "medium-ethernet")] - fn dispatch(&mut self, tx_token: Tx, timestamp: Instant, + fn dispatch(&mut self, cx: &Context, tx_token: Tx, packet: EthernetPacket) -> Result<()> where Tx: TxToken { @@ -1571,7 +1568,7 @@ impl<'a> InterfaceInner<'a> { ArpRepr::EthernetIpv4 { target_hardware_addr, .. } => target_hardware_addr, }; - self.dispatch_ethernet(tx_token, timestamp, arp_repr.buffer_len(), |mut frame| { + self.dispatch_ethernet(cx, tx_token, arp_repr.buffer_len(), |mut frame| { frame.set_dst_addr(dst_hardware_addr); frame.set_ethertype(EthernetProtocol::Arp); @@ -1580,18 +1577,18 @@ impl<'a> InterfaceInner<'a> { }) }, EthernetPacket::Ip(packet) => { - self.dispatch_ip(tx_token, timestamp, packet) + self.dispatch_ip(cx, tx_token, packet) }, } } #[cfg(feature = "medium-ethernet")] - fn dispatch_ethernet(&mut self, tx_token: Tx, timestamp: Instant, + fn dispatch_ethernet(&mut self, cx: &Context, tx_token: Tx, buffer_len: usize, f: F) -> Result<()> where Tx: TxToken, F: FnOnce(EthernetFrame<&mut [u8]>) { let tx_len = EthernetFrame::<&[u8]>::buffer_len(buffer_len); - tx_token.consume(timestamp, tx_len, |tx_buffer| { + tx_token.consume(cx.now, tx_len, |tx_buffer| { debug_assert!(tx_buffer.as_ref().len() == tx_len); let mut frame = EthernetFrame::new_unchecked(tx_buffer); frame.set_src_addr(self.ethernet_addr.unwrap()); @@ -1621,13 +1618,13 @@ impl<'a> InterfaceInner<'a> { } } - fn has_neighbor(&self, addr: &IpAddress, timestamp: Instant) -> bool { - match self.route(addr, timestamp) { + fn has_neighbor(&self, cx: &Context, addr: &IpAddress) -> bool { + match self.route(addr, cx.now) { Ok(_routed_addr) => { - match self.device_capabilities.medium { + match cx.caps.medium { #[cfg(feature = "medium-ethernet")] Medium::Ethernet => self.neighbor_cache.as_ref().unwrap() - .lookup(&_routed_addr, timestamp) + .lookup(&_routed_addr, cx.now) .found(), #[cfg(feature = "medium-ip")] Medium::Ip => true, @@ -1638,7 +1635,7 @@ impl<'a> InterfaceInner<'a> { } #[cfg(feature = "medium-ethernet")] - fn lookup_hardware_addr(&mut self, tx_token: Tx, timestamp: Instant, + fn lookup_hardware_addr(&mut self, cx: &Context, tx_token: Tx, src_addr: &IpAddress, dst_addr: &IpAddress) -> Result<(EthernetAddress, Tx)> where Tx: TxToken @@ -1669,9 +1666,9 @@ impl<'a> InterfaceInner<'a> { } } - let dst_addr = self.route(dst_addr, timestamp)?; + let dst_addr = self.route(dst_addr, cx.now)?; - match self.neighbor_cache.as_mut().unwrap().lookup(&dst_addr, timestamp) { + match self.neighbor_cache.as_mut().unwrap().lookup(&dst_addr, cx.now) { NeighborAnswer::Found(hardware_addr) => return Ok((hardware_addr, tx_token)), NeighborAnswer::RateLimited => @@ -1693,7 +1690,7 @@ impl<'a> InterfaceInner<'a> { target_protocol_addr: dst_addr, }; - self.dispatch_ethernet(tx_token, timestamp, arp_repr.buffer_len(), |mut frame| { + self.dispatch_ethernet(cx, tx_token, arp_repr.buffer_len(), |mut frame| { frame.set_dst_addr(EthernetAddress::BROADCAST); frame.set_ethertype(EthernetProtocol::Arp); @@ -1722,29 +1719,28 @@ impl<'a> InterfaceInner<'a> { solicit, )); - self.dispatch_ip(tx_token, timestamp, packet)?; + self.dispatch_ip(cx, tx_token, packet)?; } _ => () } // The request got dispatched, limit the rate on the cache. - self.neighbor_cache.as_mut().unwrap().limit_rate(timestamp); + self.neighbor_cache.as_mut().unwrap().limit_rate(cx.now); Err(Error::Unaddressable) } - fn dispatch_ip(&mut self, tx_token: Tx, timestamp: Instant, + fn dispatch_ip(&mut self, cx: &Context, tx_token: Tx, packet: IpPacket) -> Result<()> { let ip_repr = packet.ip_repr().lower(&self.ip_addrs)?; - let caps = self.device_capabilities.clone(); - match self.device_capabilities.medium { + match cx.caps.medium { #[cfg(feature = "medium-ethernet")] Medium::Ethernet => { let (dst_hardware_addr, tx_token) = - self.lookup_hardware_addr(tx_token, timestamp, + self.lookup_hardware_addr(cx, tx_token, &ip_repr.src_addr(), &ip_repr.dst_addr())?; - self.dispatch_ethernet(tx_token, timestamp, ip_repr.total_len(), |mut frame| { + self.dispatch_ethernet(cx, tx_token, ip_repr.total_len(), |mut frame| { frame.set_dst_addr(dst_hardware_addr); match ip_repr { #[cfg(feature = "proto-ipv4")] @@ -1754,22 +1750,22 @@ impl<'a> InterfaceInner<'a> { _ => return } - ip_repr.emit(frame.payload_mut(), &caps.checksum); + ip_repr.emit(frame.payload_mut(), &cx.caps.checksum); let payload = &mut frame.payload_mut()[ip_repr.buffer_len()..]; - packet.emit_payload(ip_repr, payload, &caps); + packet.emit_payload(ip_repr, payload, &cx.caps); }) } #[cfg(feature = "medium-ip")] Medium::Ip => { let tx_len = ip_repr.total_len(); - tx_token.consume(timestamp, tx_len, |mut tx_buffer| { + tx_token.consume(cx.now, tx_len, |mut tx_buffer| { debug_assert!(tx_buffer.as_ref().len() == tx_len); - ip_repr.emit(&mut tx_buffer, &caps.checksum); + ip_repr.emit(&mut tx_buffer, &cx.caps.checksum); let payload = &mut tx_buffer[ip_repr.buffer_len()..]; - packet.emit_payload(ip_repr, payload, &caps); + packet.emit_payload(ip_repr, payload, &cx.caps); Ok(()) }) @@ -1949,7 +1945,8 @@ mod test { // Ensure that the unknown protocol frame does not trigger an // ICMP error response when the destination address is a // broadcast address - assert_eq!(iface.inner.process_ipv4(&mut socket_set, Instant::from_millis(0), &frame), + let cx = iface.context(Instant::from_secs(0)); + assert_eq!(iface.inner.process_ipv4(&cx, &mut socket_set, &frame), Ok(None)); } @@ -1978,7 +1975,8 @@ mod test { // Ensure that the unknown protocol frame does not trigger an // ICMP error response when the destination address is a // broadcast address - assert_eq!(iface.inner.process_ipv6(&mut socket_set, Instant::from_millis(0), &frame), + let cx = iface.context(Instant::from_secs(0)); + assert_eq!(iface.inner.process_ipv6(&cx, &mut socket_set, &frame), Ok(None)); } @@ -2028,7 +2026,8 @@ mod test { // Ensure that the unknown protocol triggers an error response. // And we correctly handle no payload. - assert_eq!(iface.inner.process_ipv4(&mut socket_set, Instant::from_millis(0), &frame), + let cx = iface.context(Instant::from_secs(0)); + assert_eq!(iface.inner.process_ipv4(&cx, &mut socket_set, &frame), Ok(Some(expected_repr))); } @@ -2128,7 +2127,8 @@ mod test { // Ensure that the unknown protocol triggers an error response. // And we correctly handle no payload. - assert_eq!(iface.inner.process_udp(&mut socket_set, ip_repr, false, data), + let cx = iface.context(Instant::from_secs(0)); + assert_eq!(iface.inner.process_udp(&cx, &mut socket_set, ip_repr, false, data), Ok(Some(expected_repr))); let ip_repr = IpRepr::Ipv4(Ipv4Repr { @@ -2148,7 +2148,7 @@ mod test { // Ensure that the port unreachable error does not trigger an // ICMP error response when the destination address is a // broadcast address and no socket is bound to the port. - assert_eq!(iface.inner.process_udp(&mut socket_set, ip_repr, + assert_eq!(iface.inner.process_udp(&cx, &mut socket_set, ip_repr, false, packet_broadcast.into_inner()), Ok(None)); } @@ -2212,7 +2212,8 @@ mod test { &ChecksumCapabilities::default()); // Packet should be handled by bound UDP socket - assert_eq!(iface.inner.process_udp(&mut socket_set, ip_repr, false, packet.into_inner()), + let cx = iface.context(Instant::from_secs(0)); + assert_eq!(iface.inner.process_udp(&cx, &mut socket_set, ip_repr, false, packet.into_inner()), Ok(None)); { @@ -2276,7 +2277,8 @@ mod test { }; let expected_packet = IpPacket::Icmpv4((expected_ipv4_repr, expected_icmpv4_repr)); - assert_eq!(iface.inner.process_ipv4(&mut socket_set, Instant::from_millis(0), &frame), + let cx = iface.context(Instant::from_secs(0)); + assert_eq!(iface.inner.process_ipv4(&cx, &mut socket_set, &frame), Ok(Some(expected_packet))); } @@ -2365,14 +2367,16 @@ mod test { payload_len: expected_icmp_repr.buffer_len() }; + let cx = iface.context(Instant::from_secs(0)); + // The expected packet does not exceed the IPV4_MIN_MTU assert_eq!(expected_ip_repr.buffer_len() + expected_icmp_repr.buffer_len(), MIN_MTU); // The expected packet and the generated packet are equal #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] - assert_eq!(iface.inner.process_udp(&mut socket_set, ip_repr.into(), false, payload), + assert_eq!(iface.inner.process_udp(&cx, &mut socket_set, ip_repr.into(), false, payload), Ok(Some(IpPacket::Icmpv4((expected_ip_repr, expected_icmp_repr))))); #[cfg(feature = "proto-ipv6")] - assert_eq!(iface.inner.process_udp(&mut socket_set, ip_repr.into(), false, payload), + assert_eq!(iface.inner.process_udp(&cx, &mut socket_set, ip_repr.into(), false, payload), Ok(Some(IpPacket::Icmpv6((expected_ip_repr, expected_icmp_repr))))); } @@ -2405,8 +2409,10 @@ mod test { repr.emit(&mut packet); } + let cx = iface.context(Instant::from_secs(0)); + // Ensure an ARP Request for us triggers an ARP Reply - assert_eq!(iface.inner.process_ethernet(&mut socket_set, Instant::from_millis(0), frame.into_inner()), + assert_eq!(iface.inner.process_ethernet(&cx, &mut socket_set, frame.into_inner()), Ok(Some(EthernetPacket::Arp(ArpRepr::EthernetIpv4 { operation: ArpOperation::Reply, source_hardware_addr: local_hw_addr, @@ -2416,7 +2422,7 @@ mod test { })))); // Ensure the address of the requestor was entered in the cache - assert_eq!(iface.inner.lookup_hardware_addr(MockTxToken, Instant::from_secs(0), + assert_eq!(iface.inner.lookup_hardware_addr(&cx, MockTxToken, &IpAddress::Ipv4(local_ip_addr), &IpAddress::Ipv4(remote_ip_addr)), Ok((remote_hw_addr, MockTxToken))); } @@ -2471,12 +2477,14 @@ mod test { payload_len: icmpv6_expected.buffer_len() }; + let cx = iface.context(Instant::from_secs(0)); + // Ensure an Neighbor Solicitation triggers a Neighbor Advertisement - assert_eq!(iface.inner.process_ethernet(&mut socket_set, Instant::from_millis(0), frame.into_inner()), + assert_eq!(iface.inner.process_ethernet(&cx, &mut socket_set, frame.into_inner()), Ok(Some(EthernetPacket::Ip(IpPacket::Icmpv6((ipv6_expected, icmpv6_expected)))))); // Ensure the address of the requestor was entered in the cache - assert_eq!(iface.inner.lookup_hardware_addr(MockTxToken, Instant::from_secs(0), + assert_eq!(iface.inner.lookup_hardware_addr(&cx, MockTxToken, &IpAddress::Ipv6(local_ip_addr), &IpAddress::Ipv6(remote_ip_addr)), Ok((remote_hw_addr, MockTxToken))); } @@ -2508,12 +2516,14 @@ mod test { repr.emit(&mut packet); } + let cx = iface.context(Instant::from_secs(0)); + // Ensure an ARP Request for someone else does not trigger an ARP Reply - assert_eq!(iface.inner.process_ethernet(&mut socket_set, Instant::from_millis(0), frame.into_inner()), + assert_eq!(iface.inner.process_ethernet(&cx, &mut socket_set, frame.into_inner()), Ok(None)); // Ensure the address of the requestor was entered in the cache - assert_eq!(iface.inner.lookup_hardware_addr(MockTxToken, Instant::from_secs(0), + assert_eq!(iface.inner.lookup_hardware_addr(&cx, MockTxToken, &IpAddress::Ipv4(Ipv4Address([0x7f, 0x00, 0x00, 0x01])), &IpAddress::Ipv4(remote_ip_addr)), Ok((remote_hw_addr, MockTxToken))); @@ -2573,7 +2583,8 @@ mod test { dst_addr: ipv4_repr.src_addr, ..ipv4_repr }; - assert_eq!(iface.inner.process_icmpv4(&mut socket_set, ip_repr, icmp_data), + let cx = iface.context(Instant::from_secs(0)); + assert_eq!(iface.inner.process_icmpv4(&cx, &mut socket_set, ip_repr, icmp_data), Ok(Some(IpPacket::Icmpv4((ipv4_reply, echo_reply))))); { @@ -2656,9 +2667,11 @@ mod test { hop_limit: 0x40, }; + let cx = iface.context(Instant::from_secs(0)); + // Ensure the unknown next header causes a ICMPv6 Parameter Problem // error message to be sent to the sender. - assert_eq!(iface.inner.process_ipv6(&mut socket_set, Instant::from_millis(0), &frame), + assert_eq!(iface.inner.process_ipv6(&cx, &mut socket_set, &frame), Ok(Some(IpPacket::Icmpv6((reply_ipv6_repr, reply_icmp_repr))))); } @@ -2739,7 +2752,8 @@ mod test { // loopback have been processed, including responses to // GENERAL_QUERY_BYTES. Therefore `recv_all()` would return 0 // pkts that could be checked. - iface.socket_ingress(&mut socket_set, timestamp); + let cx = iface.context(timestamp); + iface.socket_ingress(&cx, &mut socket_set); // Leave multicast groups let timestamp = Instant::now(); @@ -2811,7 +2825,8 @@ mod test { Ipv4Packet::new_unchecked(&bytes) }; - assert_eq!(iface.inner.process_ipv4(&mut socket_set, Instant::from_millis(0), &frame), + let cx = iface.context(Instant::from_millis(0)); + assert_eq!(iface.inner.process_ipv4(&cx, &mut socket_set, &frame), Ok(None)); } @@ -2869,7 +2884,8 @@ mod test { Ipv4Packet::new_unchecked(&bytes) }; - let frame = iface.inner.process_ipv4(&mut socket_set, Instant::from_millis(0), &frame); + let cx = iface.context(Instant::from_millis(0)); + let frame = iface.inner.process_ipv4(&cx, &mut socket_set, &frame); // because the packet could not be handled we should send an Icmp message assert!(match frame { @@ -2945,7 +2961,8 @@ mod test { Ipv4Packet::new_unchecked(&bytes) }; - assert_eq!(iface.inner.process_ipv4(&mut socket_set, Instant::from_millis(0), &frame), + let cx = iface.context(Instant::from_millis(0)); + assert_eq!(iface.inner.process_ipv4(&cx, &mut socket_set, &frame), Ok(None)); { diff --git a/src/lib.rs b/src/lib.rs index dac1a06..4c8e93f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -105,6 +105,15 @@ compile_error!("You must enable at least one of the following features: proto-ip ))] compile_error!("If you enable the socket feature, you must enable at least one of the following features: socket-raw, socket-udp, socket-tcp, socket-icmp"); +#[cfg(all( + feature = "socket", + not(any( + feature = "medium-ethernet", + feature = "medium-ip", + )) +))] +compile_error!("If you enable the socket feature, you must enable at least one of the following features: medium-ip, medium-ethernet"); + #[cfg(all(feature = "defmt", feature = "log"))] compile_error!("You must enable at most one of the following features: defmt, log"); diff --git a/src/phy/mod.rs b/src/phy/mod.rs index c048677..50cd774 100644 --- a/src/phy/mod.rs +++ b/src/phy/mod.rs @@ -230,6 +230,17 @@ pub struct DeviceCapabilities { pub checksum: ChecksumCapabilities, } +impl DeviceCapabilities { + pub fn ip_mtu(&self) -> usize { + match self.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => self.max_transmission_unit - crate::wire::EthernetFrame::<&[u8]>::header_len(), + #[cfg(feature = "medium-ip")] + Medium::Ip => self.max_transmission_unit, + } + } +} + /// Type of medium of a device. #[derive(Debug, Eq, PartialEq, Copy, Clone)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] diff --git a/src/socket/dhcpv4.rs b/src/socket/dhcpv4.rs index c60f988..7caa16e 100644 --- a/src/socket/dhcpv4.rs +++ b/src/socket/dhcpv4.rs @@ -1,10 +1,10 @@ use crate::{Error, Result}; -use crate::wire::{EthernetAddress, IpProtocol, IpAddress, +use crate::wire::{IpProtocol, IpAddress, Ipv4Cidr, Ipv4Address, Ipv4Repr, UdpRepr, UDP_HEADER_LEN, DhcpPacket, DhcpRepr, DhcpMessageType, DHCP_CLIENT_PORT, DHCP_SERVER_PORT, DHCP_MAX_DNS_SERVER_COUNT}; use crate::wire::dhcpv4::{field as dhcpv4_field}; -use crate::socket::SocketMeta; +use crate::socket::{SocketMeta, Context}; use crate::time::{Instant, Duration}; use crate::socket::SocketHandle; @@ -150,7 +150,7 @@ impl Dhcpv4Socket { self.max_lease_duration = max_lease_duration; } - pub(crate) fn poll_at(&self) -> PollAt { + pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt { let t = match &self.state { ClientState::Discovering(state) => state.retry_at, ClientState::Requesting(state) => state.retry_at, @@ -159,7 +159,7 @@ impl Dhcpv4Socket { PollAt::Time(t) } - pub(crate) fn process(&mut self, now: Instant, ethernet_addr: EthernetAddress, ip_repr: &Ipv4Repr, repr: &UdpRepr, payload: &[u8]) -> Result<()> { + pub(crate) fn process(&mut self, cx: &Context, ip_repr: &Ipv4Repr, repr: &UdpRepr, payload: &[u8]) -> Result<()> { let src_ip = ip_repr.src_addr; // This is enforced in interface.rs. @@ -179,7 +179,7 @@ impl Dhcpv4Socket { return Ok(()); } }; - if dhcp_repr.client_hardware_address != ethernet_addr { return Ok(()) } + if dhcp_repr.client_hardware_address != cx.ethernet_address.unwrap() { return Ok(()) } if dhcp_repr.transaction_id != self.transaction_id { return Ok(()) } let server_identifier = match dhcp_repr.server_identifier { Some(server_identifier) => server_identifier, @@ -199,7 +199,7 @@ impl Dhcpv4Socket { } self.state = ClientState::Requesting(RequestState { - retry_at: now, + retry_at: cx.now, retry: 0, server: ServerInfo { address: src_ip, @@ -209,7 +209,7 @@ impl Dhcpv4Socket { }); } (ClientState::Requesting(state), DhcpMessageType::Ack) => { - if let Some((config, renew_at, expires_at)) = Self::parse_ack(now, &dhcp_repr, self.max_lease_duration) { + if let Some((config, renew_at, expires_at)) = Self::parse_ack(cx.now, &dhcp_repr, self.max_lease_duration) { self.config_changed = true; self.state = ClientState::Renewing(RenewState{ server: state.server, @@ -223,7 +223,7 @@ impl Dhcpv4Socket { self.reset(); } (ClientState::Renewing(state), DhcpMessageType::Ack) => { - if let Some((config, renew_at, expires_at)) = Self::parse_ack(now, &dhcp_repr, self.max_lease_duration) { + if let Some((config, renew_at, expires_at)) = Self::parse_ack(cx.now, &dhcp_repr, self.max_lease_duration) { state.renew_at = renew_at; state.expires_at = expires_at; if state.config != config { @@ -298,9 +298,13 @@ impl Dhcpv4Socket { Some((config, renew_at, expires_at)) } - pub(crate) fn dispatch(&mut self, now: Instant, ethernet_addr: EthernetAddress, ip_mtu: usize, emit: F) -> Result<()> + pub(crate) fn dispatch(&mut self, cx: &Context, emit: F) -> Result<()> where F: FnOnce((Ipv4Repr, UdpRepr, DhcpRepr)) -> Result<()> { + // note: Dhcpv4Socket is only usable in ethernet mediums, so the + // unwrap can never fail. + let ethernet_addr = cx.ethernet_address.unwrap(); + // Worst case biggest IPv4 header length. // 0x0f * 4 = 60 bytes. const MAX_IPV4_HEADER_LEN: usize = 60; @@ -324,7 +328,7 @@ impl Dhcpv4Socket { client_identifier: Some(ethernet_addr), server_identifier: None, parameter_request_list: Some(PARAMETER_REQUEST_LIST), - max_size: Some((ip_mtu - MAX_IPV4_HEADER_LEN - UDP_HEADER_LEN) as u16), + max_size: Some((cx.caps.ip_mtu() - MAX_IPV4_HEADER_LEN - UDP_HEADER_LEN) as u16), lease_duration: None, dns_servers: None, }; @@ -344,7 +348,7 @@ impl Dhcpv4Socket { match &mut self.state { ClientState::Discovering(state) => { - if now < state.retry_at { + if cx.now < state.retry_at { return Err(Error::Exhausted) } @@ -354,12 +358,12 @@ impl Dhcpv4Socket { emit((ipv4_repr, udp_repr, dhcp_repr))?; // Update state AFTER the packet has been successfully sent. - state.retry_at = now + DISCOVER_TIMEOUT; + state.retry_at = cx.now + DISCOVER_TIMEOUT; self.transaction_id = next_transaction_id; Ok(()) } ClientState::Requesting(state) => { - if now < state.retry_at { + if cx.now < state.retry_at { return Err(Error::Exhausted) } @@ -380,21 +384,21 @@ impl Dhcpv4Socket { emit((ipv4_repr, udp_repr, dhcp_repr))?; // Exponential backoff: Double every 2 retries. - state.retry_at = now + (REQUEST_TIMEOUT << (state.retry as u32 / 2)); + state.retry_at = cx.now + (REQUEST_TIMEOUT << (state.retry as u32 / 2)); state.retry += 1; self.transaction_id = next_transaction_id; Ok(()) } ClientState::Renewing(state) => { - if state.expires_at <= now { + if state.expires_at <= cx.now { net_debug!("DHCP lease expired"); self.reset(); // return Ok so we get polled again return Ok(()) } - if now < state.renew_at { + if cx.now < state.renew_at { return Err(Error::Exhausted) } @@ -413,7 +417,7 @@ impl Dhcpv4Socket { // of the remaining time until T2 (in RENEWING state) and one-half of // the remaining lease time (in REBINDING state), down to a minimum of // 60 seconds, before retransmitting the DHCPREQUEST message. - state.renew_at = now + MIN_RENEW_TIMEOUT.max((state.expires_at - now) / 2); + state.renew_at = cx.now + MIN_RENEW_TIMEOUT.max((state.expires_at - cx.now) / 2); self.transaction_id = next_transaction_id; Ok(()) diff --git a/src/socket/icmp.rs b/src/socket/icmp.rs index 9208ca7..39c8b58 100644 --- a/src/socket/icmp.rs +++ b/src/socket/icmp.rs @@ -4,7 +4,7 @@ use core::task::Waker; use crate::{Error, Result}; use crate::phy::ChecksumCapabilities; -use crate::socket::{Socket, SocketMeta, SocketHandle, PollAt}; +use crate::socket::{Socket, SocketMeta, SocketHandle, PollAt, Context}; use crate::storage::{PacketBuffer, PacketMetadata}; #[cfg(feature = "async")] use crate::socket::WakerRegistration; @@ -326,8 +326,7 @@ impl<'a> IcmpSocket<'a> { /// Filter determining which packets received by the interface are appended to /// the given sockets received buffer. - pub(crate) fn accepts(&self, ip_repr: &IpRepr, icmp_repr: &IcmpRepr, - cksum: &ChecksumCapabilities) -> bool { + pub(crate) fn accepts(&self, cx: &Context, ip_repr: &IpRepr, icmp_repr: &IcmpRepr) -> bool { match (&self.endpoint, icmp_repr) { // If we are bound to ICMP errors associated to a UDP port, only // accept Destination Unreachable messages with the data containing @@ -336,7 +335,7 @@ impl<'a> IcmpSocket<'a> { (&Endpoint::Udp(endpoint), &IcmpRepr::Ipv4(Icmpv4Repr::DstUnreachable { data, .. })) if endpoint.addr.is_unspecified() || endpoint.addr == ip_repr.dst_addr() => { let packet = UdpPacket::new_unchecked(data); - match UdpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr(), cksum) { + match UdpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr(), &cx.caps.checksum) { Ok(repr) => endpoint.port == repr.src_port, Err(_) => false, } @@ -345,7 +344,7 @@ impl<'a> IcmpSocket<'a> { (&Endpoint::Udp(endpoint), &IcmpRepr::Ipv6(Icmpv6Repr::DstUnreachable { data, .. })) if endpoint.addr.is_unspecified() || endpoint.addr == ip_repr.dst_addr() => { let packet = UdpPacket::new_unchecked(data); - match UdpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr(), cksum) { + match UdpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr(), &cx.caps.checksum) { Ok(repr) => endpoint.port == repr.src_port, Err(_) => false, } @@ -369,8 +368,7 @@ impl<'a> IcmpSocket<'a> { } } - pub(crate) fn process(&mut self, ip_repr: &IpRepr, icmp_repr: &IcmpRepr, - _cksum: &ChecksumCapabilities) -> Result<()> { + pub(crate) fn process(&mut self, _cx: &Context, ip_repr: &IpRepr, icmp_repr: &IcmpRepr) -> Result<()> { match *icmp_repr { #[cfg(feature = "proto-ipv4")] IcmpRepr::Ipv4(ref icmp_repr) => { @@ -401,7 +399,7 @@ impl<'a> IcmpSocket<'a> { Ok(()) } - pub(crate) fn dispatch(&mut self, emit: F) -> Result<()> + pub(crate) fn dispatch(&mut self, _cx: &Context, emit: F) -> Result<()> where F: FnOnce((IpRepr, IcmpRepr)) -> Result<()> { let handle = self.meta.handle; @@ -447,7 +445,7 @@ impl<'a> IcmpSocket<'a> { Ok(()) } - pub(crate) fn poll_at(&self) -> PollAt { + pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt { if self.tx_buffer.is_empty() { PollAt::Ingress } else { @@ -532,7 +530,7 @@ mod test_ipv4 { let mut socket = socket(buffer(0), buffer(1)); let checksum = ChecksumCapabilities::default(); - assert_eq!(socket.dispatch(|_| unreachable!()), + assert_eq!(socket.dispatch(&Context::DUMMY, |_| unreachable!()), Err(Error::Exhausted)); // This buffer is too long @@ -547,7 +545,7 @@ mod test_ipv4 { assert_eq!(socket.send_slice(b"123456", REMOTE_IPV4.into()), Err(Error::Exhausted)); assert!(!socket.can_send()); - assert_eq!(socket.dispatch(|(ip_repr, icmp_repr)| { + assert_eq!(socket.dispatch(&Context::DUMMY, |(ip_repr, icmp_repr)| { assert_eq!(ip_repr, LOCAL_IPV4_REPR); assert_eq!(icmp_repr, ECHOV4_REPR.into()); Err(Error::Unaddressable) @@ -555,7 +553,7 @@ mod test_ipv4 { // buffer is not taken off of the tx queue due to the error assert!(!socket.can_send()); - assert_eq!(socket.dispatch(|(ip_repr, icmp_repr)| { + assert_eq!(socket.dispatch(&Context::DUMMY, |(ip_repr, icmp_repr)| { assert_eq!(ip_repr, LOCAL_IPV4_REPR); assert_eq!(icmp_repr, ECHOV4_REPR.into()); Ok(()) @@ -576,7 +574,7 @@ mod test_ipv4 { s.set_hop_limit(Some(0x2a)); assert_eq!(s.send_slice(&packet.into_inner()[..], REMOTE_IPV4.into()), Ok(())); - assert_eq!(s.dispatch(|(ip_repr, _)| { + assert_eq!(s.dispatch(&Context::DUMMY, |(ip_repr, _)| { assert_eq!(ip_repr, IpRepr::Ipv4(Ipv4Repr { src_addr: Ipv4Address::UNSPECIFIED, dst_addr: REMOTE_IPV4, @@ -603,13 +601,13 @@ mod test_ipv4 { ECHOV4_REPR.emit(&mut packet, &checksum); let data = &packet.into_inner()[..]; - assert!(socket.accepts(&REMOTE_IPV4_REPR, &ECHOV4_REPR.into(), &checksum)); - assert_eq!(socket.process(&REMOTE_IPV4_REPR, &ECHOV4_REPR.into(), &checksum), + assert!(socket.accepts(&Context::DUMMY, &REMOTE_IPV4_REPR, &ECHOV4_REPR.into())); + assert_eq!(socket.process(&Context::DUMMY, &REMOTE_IPV4_REPR, &ECHOV4_REPR.into()), Ok(())); assert!(socket.can_recv()); - assert!(socket.accepts(&REMOTE_IPV4_REPR, &ECHOV4_REPR.into(), &checksum)); - assert_eq!(socket.process(&REMOTE_IPV4_REPR, &ECHOV4_REPR.into(), &checksum), + assert!(socket.accepts(&Context::DUMMY, &REMOTE_IPV4_REPR, &ECHOV4_REPR.into())); + assert_eq!(socket.process(&Context::DUMMY, &REMOTE_IPV4_REPR, &ECHOV4_REPR.into()), Err(Error::Exhausted)); assert_eq!(socket.recv(), Ok((&data[..], REMOTE_IPV4.into()))); @@ -633,7 +631,7 @@ mod test_ipv4 { // Ensure that a packet with an identifier that isn't the bound // ID is not accepted - assert!(!socket.accepts(&REMOTE_IPV4_REPR, &icmp_repr.into(), &checksum)); + assert!(!socket.accepts(&Context::DUMMY, &REMOTE_IPV4_REPR, &icmp_repr.into())); } #[test] @@ -678,8 +676,8 @@ mod test_ipv4 { // Ensure we can accept ICMP error response to the bound // UDP port - assert!(socket.accepts(&ip_repr, &icmp_repr.into(), &checksum)); - assert_eq!(socket.process(&ip_repr, &icmp_repr.into(), &checksum), + assert!(socket.accepts(&Context::DUMMY, &ip_repr, &icmp_repr.into())); + assert_eq!(socket.process(&Context::DUMMY, &ip_repr, &icmp_repr.into()), Ok(())); assert!(socket.can_recv()); @@ -737,7 +735,7 @@ mod test_ipv6 { let mut socket = socket(buffer(0), buffer(1)); let checksum = ChecksumCapabilities::default(); - assert_eq!(socket.dispatch(|_| unreachable!()), + assert_eq!(socket.dispatch(&Context::DUMMY, |_| unreachable!()), Err(Error::Exhausted)); // This buffer is too long @@ -752,7 +750,7 @@ mod test_ipv6 { assert_eq!(socket.send_slice(b"123456", REMOTE_IPV6.into()), Err(Error::Exhausted)); assert!(!socket.can_send()); - assert_eq!(socket.dispatch(|(ip_repr, icmp_repr)| { + assert_eq!(socket.dispatch(&Context::DUMMY, |(ip_repr, icmp_repr)| { assert_eq!(ip_repr, LOCAL_IPV6_REPR); assert_eq!(icmp_repr, ECHOV6_REPR.into()); Err(Error::Unaddressable) @@ -760,7 +758,7 @@ mod test_ipv6 { // buffer is not taken off of the tx queue due to the error assert!(!socket.can_send()); - assert_eq!(socket.dispatch(|(ip_repr, icmp_repr)| { + assert_eq!(socket.dispatch(&Context::DUMMY, |(ip_repr, icmp_repr)| { assert_eq!(ip_repr, LOCAL_IPV6_REPR); assert_eq!(icmp_repr, ECHOV6_REPR.into()); Ok(()) @@ -781,7 +779,7 @@ mod test_ipv6 { s.set_hop_limit(Some(0x2a)); assert_eq!(s.send_slice(&packet.into_inner()[..], REMOTE_IPV6.into()), Ok(())); - assert_eq!(s.dispatch(|(ip_repr, _)| { + assert_eq!(s.dispatch(&Context::DUMMY, |(ip_repr, _)| { assert_eq!(ip_repr, IpRepr::Ipv6(Ipv6Repr { src_addr: Ipv6Address::UNSPECIFIED, dst_addr: REMOTE_IPV6, @@ -808,13 +806,13 @@ mod test_ipv6 { ECHOV6_REPR.emit(&LOCAL_IPV6.into(), &REMOTE_IPV6.into(), &mut packet, &checksum); let data = &packet.into_inner()[..]; - assert!(socket.accepts(&REMOTE_IPV6_REPR, &ECHOV6_REPR.into(), &checksum)); - assert_eq!(socket.process(&REMOTE_IPV6_REPR, &ECHOV6_REPR.into(), &checksum), + assert!(socket.accepts(&Context::DUMMY, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into())); + assert_eq!(socket.process(&Context::DUMMY, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into()), Ok(())); assert!(socket.can_recv()); - assert!(socket.accepts(&REMOTE_IPV6_REPR, &ECHOV6_REPR.into(), &checksum)); - assert_eq!(socket.process(&REMOTE_IPV6_REPR, &ECHOV6_REPR.into(), &checksum), + assert!(socket.accepts(&Context::DUMMY, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into())); + assert_eq!(socket.process(&Context::DUMMY, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into()), Err(Error::Exhausted)); assert_eq!(socket.recv(), Ok((&data[..], REMOTE_IPV6.into()))); @@ -838,7 +836,7 @@ mod test_ipv6 { // Ensure that a packet with an identifier that isn't the bound // ID is not accepted - assert!(!socket.accepts(&REMOTE_IPV6_REPR, &icmp_repr.into(), &checksum)); + assert!(!socket.accepts(&Context::DUMMY, &REMOTE_IPV6_REPR, &icmp_repr.into())); } #[test] @@ -883,8 +881,8 @@ mod test_ipv6 { // Ensure we can accept ICMP error response to the bound // UDP port - assert!(socket.accepts(&ip_repr, &icmp_repr.into(), &checksum)); - assert_eq!(socket.process(&ip_repr, &icmp_repr.into(), &checksum), + assert!(socket.accepts(&Context::DUMMY, &ip_repr, &icmp_repr.into())); + assert_eq!(socket.process(&Context::DUMMY, &ip_repr, &icmp_repr.into()), Ok(())); assert!(socket.can_recv()); diff --git a/src/socket/mod.rs b/src/socket/mod.rs index ebd0c42..1c01efb 100644 --- a/src/socket/mod.rs +++ b/src/socket/mod.rs @@ -11,6 +11,7 @@ The interface implemented by this module uses explicit buffering: you decide on size for a buffer, allocate it, and let the networking stack use it. */ +use crate::phy::DeviceCapabilities; use crate::time::Instant; mod meta; @@ -138,8 +139,8 @@ impl<'a> Socket<'a> { dispatch_socket!(mut self, |socket| &mut socket.meta) } - pub(crate) fn poll_at(&self) -> PollAt { - dispatch_socket!(self, |socket| socket.poll_at()) + pub(crate) fn poll_at(&self, cx: &Context) -> PollAt { + dispatch_socket!(self, |socket| socket.poll_at(cx)) } } @@ -180,3 +181,42 @@ from_socket!(UdpSocket<'a>, Udp); from_socket!(TcpSocket<'a>, Tcp); #[cfg(feature = "socket-dhcpv4")] from_socket!(Dhcpv4Socket, Dhcpv4); + +/// Data passed to sockets when processing. +#[derive(Clone, Debug)] +pub(crate) struct Context { + pub now: Instant, + #[cfg(feature = "medium-ethernet")] + pub ethernet_address: Option, + pub caps: DeviceCapabilities, +} + +#[cfg(test)] +impl Context { + + pub(crate) const DUMMY: Context = Context { + caps: DeviceCapabilities { + #[cfg(feature = "medium-ethernet")] + medium: crate::phy::Medium::Ethernet, + #[cfg(not(feature = "medium-ethernet"))] + medium: crate::phy::Medium::Ip, + checksum: crate::phy::ChecksumCapabilities{ + #[cfg(feature = "proto-ipv4")] + icmpv4: crate::phy::Checksum::Both, + #[cfg(feature = "proto-ipv6")] + icmpv6: crate::phy::Checksum::Both, + ipv4: crate::phy::Checksum::Both, + tcp: crate::phy::Checksum::Both, + udp: crate::phy::Checksum::Both, + }, + max_burst_size: None, + #[cfg(feature = "medium-ethernet")] + max_transmission_unit: 1514, + #[cfg(not(feature = "medium-ethernet"))] + max_transmission_unit: 1500, + }, + ethernet_address: None, + now: Instant{millis: 0}, + }; + +} \ No newline at end of file diff --git a/src/socket/raw.rs b/src/socket/raw.rs index e4e60ed..0baa702 100644 --- a/src/socket/raw.rs +++ b/src/socket/raw.rs @@ -4,7 +4,7 @@ use core::task::Waker; use crate::{Error, Result}; use crate::phy::ChecksumCapabilities; -use crate::socket::{Socket, SocketMeta, SocketHandle, PollAt}; +use crate::socket::{Socket, SocketMeta, SocketHandle, PollAt, Context}; use crate::storage::{PacketBuffer, PacketMetadata}; #[cfg(feature = "async")] use crate::socket::WakerRegistration; @@ -206,14 +206,13 @@ impl<'a> RawSocket<'a> { true } - pub(crate) fn process(&mut self, ip_repr: &IpRepr, payload: &[u8], - checksum_caps: &ChecksumCapabilities) -> Result<()> { + pub(crate) fn process(&mut self, cx: &Context, ip_repr: &IpRepr, payload: &[u8]) -> Result<()> { debug_assert!(self.accepts(ip_repr)); let header_len = ip_repr.buffer_len(); let total_len = header_len + payload.len(); let packet_buf = self.rx_buffer.enqueue(total_len, ())?; - ip_repr.emit(&mut packet_buf[..header_len], &checksum_caps); + ip_repr.emit(&mut packet_buf[..header_len], &cx.caps.checksum); packet_buf[header_len..].copy_from_slice(payload); net_trace!("{}:{}:{}: receiving {} octets", @@ -226,8 +225,7 @@ impl<'a> RawSocket<'a> { Ok(()) } - pub(crate) fn dispatch(&mut self, checksum_caps: &ChecksumCapabilities, emit: F) -> - Result<()> + pub(crate) fn dispatch(&mut self, cx: &Context, emit: F) -> Result<()> where F: FnOnce((IpRepr, &[u8])) -> Result<()> { fn prepare<'a>(protocol: IpProtocol, buffer: &'a mut [u8], _checksum_caps: &ChecksumCapabilities) -> Result<(IpRepr, &'a [u8])> { @@ -264,7 +262,7 @@ impl<'a> RawSocket<'a> { let ip_protocol = self.ip_protocol; let ip_version = self.ip_version; self.tx_buffer.dequeue_with(|&mut (), packet_buf| { - match prepare(ip_protocol, packet_buf, &checksum_caps) { + match prepare(ip_protocol, packet_buf, &cx.caps.checksum) { Ok((ip_repr, raw_packet)) => { net_trace!("{}:{}:{}: sending {} octets", handle, ip_version, ip_protocol, @@ -287,7 +285,7 @@ impl<'a> RawSocket<'a> { Ok(()) } - pub(crate) fn poll_at(&self) -> PollAt { + pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt { if self.tx_buffer.is_empty() { PollAt::Ingress } else { @@ -401,25 +399,24 @@ mod test { #[test] fn test_send_dispatch() { - let checksum_caps = &ChecksumCapabilities::default(); let mut socket = $socket(buffer(0), buffer(1)); assert!(socket.can_send()); - assert_eq!(socket.dispatch(&checksum_caps, |_| unreachable!()), + assert_eq!(socket.dispatch(&Context::DUMMY, |_| unreachable!()), Err(Error::Exhausted)); assert_eq!(socket.send_slice(&$packet[..]), Ok(())); assert_eq!(socket.send_slice(b""), Err(Error::Exhausted)); assert!(!socket.can_send()); - assert_eq!(socket.dispatch(&checksum_caps, |(ip_repr, ip_payload)| { + assert_eq!(socket.dispatch(&Context::DUMMY, |(ip_repr, ip_payload)| { assert_eq!(ip_repr, $hdr); assert_eq!(ip_payload, &$payload); Err(Error::Unaddressable) }), Err(Error::Unaddressable)); assert!(!socket.can_send()); - assert_eq!(socket.dispatch(&checksum_caps, |(ip_repr, ip_payload)| { + assert_eq!(socket.dispatch(&Context::DUMMY, |(ip_repr, ip_payload)| { assert_eq!(ip_repr, $hdr); assert_eq!(ip_payload, &$payload); Ok(()) @@ -432,8 +429,7 @@ mod test { let mut socket = $socket(buffer(1), buffer(0)); assert!(socket.accepts(&$hdr)); - assert_eq!(socket.process(&$hdr, &$payload, - &ChecksumCapabilities::default()), Ok(())); + assert_eq!(socket.process(&Context::DUMMY, &$hdr, &$payload), Ok(())); let mut slice = [0; 4]; assert_eq!(socket.recv_slice(&mut slice[..]), Ok(4)); @@ -448,8 +444,7 @@ mod test { buffer[..$packet.len()].copy_from_slice(&$packet[..]); assert!(socket.accepts(&$hdr)); - assert_eq!(socket.process(&$hdr, &buffer, &ChecksumCapabilities::default()), - Err(Error::Truncated)); + assert_eq!(socket.process(&Context::DUMMY, &$hdr, &buffer), Err(Error::Truncated)); } } } @@ -467,7 +462,6 @@ mod test { #[test] #[cfg(feature = "proto-ipv4")] fn test_send_illegal() { - let checksum_caps = &ChecksumCapabilities::default(); #[cfg(feature = "proto-ipv4")] { let mut socket = ipv4_locals::socket(buffer(0), buffer(2)); @@ -476,14 +470,14 @@ mod test { Ipv4Packet::new_unchecked(&mut wrong_version).set_version(6); assert_eq!(socket.send_slice(&wrong_version[..]), Ok(())); - assert_eq!(socket.dispatch(&checksum_caps, |_| unreachable!()), + assert_eq!(socket.dispatch(&Context::DUMMY, |_| unreachable!()), Ok(())); let mut wrong_protocol = ipv4_locals::PACKET_BYTES; Ipv4Packet::new_unchecked(&mut wrong_protocol).set_protocol(IpProtocol::Tcp); assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(())); - assert_eq!(socket.dispatch(&checksum_caps, |_| unreachable!()), + assert_eq!(socket.dispatch(&Context::DUMMY, |_| unreachable!()), Ok(())); } #[cfg(feature = "proto-ipv6")] @@ -494,14 +488,14 @@ mod test { Ipv6Packet::new_unchecked(&mut wrong_version[..]).set_version(4); assert_eq!(socket.send_slice(&wrong_version[..]), Ok(())); - assert_eq!(socket.dispatch(&checksum_caps, |_| unreachable!()), + assert_eq!(socket.dispatch(&Context::DUMMY, |_| unreachable!()), Ok(())); let mut wrong_protocol = ipv6_locals::PACKET_BYTES; Ipv6Packet::new_unchecked(&mut wrong_protocol[..]).set_next_header(IpProtocol::Tcp); assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(())); - assert_eq!(socket.dispatch(&checksum_caps, |_| unreachable!()), + assert_eq!(socket.dispatch(&Context::DUMMY, |_| unreachable!()), Ok(())); } } @@ -518,14 +512,12 @@ mod test { assert_eq!(socket.recv(), Err(Error::Exhausted)); assert!(socket.accepts(&ipv4_locals::HEADER_REPR)); - assert_eq!(socket.process(&ipv4_locals::HEADER_REPR, &ipv4_locals::PACKET_PAYLOAD, - &ChecksumCapabilities::default()), + assert_eq!(socket.process(&Context::DUMMY, &ipv4_locals::HEADER_REPR, &ipv4_locals::PACKET_PAYLOAD), Ok(())); assert!(socket.can_recv()); assert!(socket.accepts(&ipv4_locals::HEADER_REPR)); - assert_eq!(socket.process(&ipv4_locals::HEADER_REPR, &ipv4_locals::PACKET_PAYLOAD, - &ChecksumCapabilities::default()), + assert_eq!(socket.process(&Context::DUMMY, &ipv4_locals::HEADER_REPR, &ipv4_locals::PACKET_PAYLOAD), Err(Error::Exhausted)); assert_eq!(socket.recv(), Ok(&cksumd_packet[..])); assert!(!socket.can_recv()); @@ -537,14 +529,12 @@ mod test { assert_eq!(socket.recv(), Err(Error::Exhausted)); assert!(socket.accepts(&ipv6_locals::HEADER_REPR)); - assert_eq!(socket.process(&ipv6_locals::HEADER_REPR, &ipv6_locals::PACKET_PAYLOAD, - &ChecksumCapabilities::default()), + assert_eq!(socket.process(&Context::DUMMY, &ipv6_locals::HEADER_REPR, &ipv6_locals::PACKET_PAYLOAD), Ok(())); assert!(socket.can_recv()); assert!(socket.accepts(&ipv6_locals::HEADER_REPR)); - assert_eq!(socket.process(&ipv6_locals::HEADER_REPR, &ipv6_locals::PACKET_PAYLOAD, - &ChecksumCapabilities::default()), + assert_eq!(socket.process(&Context::DUMMY, &ipv6_locals::HEADER_REPR, &ipv6_locals::PACKET_PAYLOAD), Err(Error::Exhausted)); assert_eq!(socket.recv(), Ok(&ipv6_locals::PACKET_BYTES[..])); assert!(!socket.can_recv()); diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index b1b00ee..35e60eb 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -8,11 +8,11 @@ use core::task::Waker; use crate::{Error, Result}; use crate::time::{Duration, Instant}; -use crate::socket::{Socket, SocketMeta, SocketHandle, PollAt}; +use crate::socket::{Socket, SocketMeta, SocketHandle, PollAt, Context}; use crate::storage::{Assembler, RingBuffer}; #[cfg(feature = "async")] use crate::socket::WakerRegistration; -use crate::wire::{IpProtocol, IpRepr, IpAddress, IpEndpoint, TcpSeqNumber, TcpRepr, TcpControl}; +use crate::wire::{IpProtocol, IpRepr, IpAddress, IpEndpoint, TcpSeqNumber, TcpRepr, TcpControl, TCP_HEADER_LEN}; /// A TCP socket ring buffer. pub type SocketBuffer<'a> = RingBuffer<'a, u8>; @@ -354,6 +354,9 @@ pub struct TcpSocket<'a> { /// ACK or window updates (ie, no data) won't be sent until expiry. ack_delay_until: Option, + /// Nagle's Algorithm enabled. + nagle: bool, + #[cfg(feature = "async")] rx_waker: WakerRegistration, #[cfg(feature = "async")] @@ -412,6 +415,7 @@ impl<'a> TcpSocket<'a> { local_rx_dup_acks: 0, ack_delay: Some(ACK_DELAY_DEFAULT), ack_delay_until: None, + nagle: true, #[cfg(feature = "async")] rx_waker: WakerRegistration::new(), @@ -475,6 +479,13 @@ impl<'a> TcpSocket<'a> { self.ack_delay } + /// Return whether Nagle's Algorithm is enabled. + /// + /// See also the [set_nagle_enabled](#method.set_nagle_enabled) method. + pub fn nagle_enabled(&self) -> Option { + self.ack_delay + } + /// Return the current window field value, including scaling according to RFC 1323. /// /// Used in internal calculations as well as packet generation. @@ -507,6 +518,22 @@ impl<'a> TcpSocket<'a> { self.ack_delay = duration } + /// Enable or disable Nagle's Algorithm. + /// + /// Also known as "tinygram prevention". By default, it is enabled. + /// Disabling it is equivalent to Linux's TCP_NODELAY flag. + /// + /// When enabled, Nagle's Algorithm prevents sending segments smaller than MSS if + /// there is data in flight (sent but not acknowledged). In other words, it ensures + /// at most only one segment smaller than MSS is in flight at a time. + /// + /// It ensures better network utilization by preventing sending many very small packets, + /// at the cost of increased latency in some situations, particularly when the remote peer + /// has ACK delay enabled. + pub fn set_nagle_enabled(&mut self, enabled: bool) { + self.nagle = enabled + } + /// Return the keep-alive interval. /// /// See also the [set_keep_alive](#method.set_keep_alive) method. @@ -609,6 +636,7 @@ impl<'a> TcpSocket<'a> { self.remote_last_ts = None; self.ack_delay = Some(ACK_DELAY_DEFAULT); self.ack_delay_until = None; + self.nagle = true; #[cfg(feature = "async")] { @@ -1132,7 +1160,7 @@ impl<'a> TcpSocket<'a> { true } - pub(crate) fn process(&mut self, timestamp: Instant, ip_repr: &IpRepr, repr: &TcpRepr) -> + pub(crate) fn process(&mut self, cx: &Context, ip_repr: &IpRepr, repr: &TcpRepr) -> Result)>> { debug_assert!(self.accepts(ip_repr, repr)); @@ -1267,7 +1295,7 @@ impl<'a> TcpSocket<'a> { // If we're in the TIME-WAIT state, restart the TIME-WAIT timeout, since // the remote end may not have realized we've closed the connection. if self.state == State::TimeWait { - self.timer.set_for_close(timestamp); + self.timer.set_for_close(cx.now); } return Ok(Some(self.ack_reply(ip_repr, &repr))) @@ -1296,7 +1324,7 @@ impl<'a> TcpSocket<'a> { ack_of_fin = true; } - self.rtte.on_ack(timestamp, ack_number); + self.rtte.on_ack(cx.now, ack_number); } } @@ -1351,18 +1379,18 @@ impl<'a> TcpSocket<'a> { self.remote_mss = max_seg_size as usize } self.remote_win_scale = repr.window_scale; - // No window scaling means don't do any window shifting + // Remote doesn't support window scaling, don't do it. if self.remote_win_scale.is_none() { self.remote_win_shift = 0; } self.set_state(State::SynReceived); - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now, self.keep_alive); } // ACK packets in the SYN-RECEIVED state change it to ESTABLISHED. (State::SynReceived, TcpControl::None) => { self.set_state(State::Established); - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now, self.keep_alive); } // FIN packets in the SYN-RECEIVED state change it to CLOSE-WAIT. @@ -1372,7 +1400,7 @@ impl<'a> TcpSocket<'a> { self.remote_seq_no += 1; self.rx_fin_received = true; self.set_state(State::CloseWait); - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now, self.keep_alive); } // SYN|ACK packets in the SYN-SENT state change it to ESTABLISHED. @@ -1383,18 +1411,24 @@ impl<'a> TcpSocket<'a> { self.remote_seq_no = repr.seq_number + 1; self.remote_last_seq = self.local_seq_no + 1; self.remote_last_ack = Some(repr.seq_number); + self.remote_win_scale = repr.window_scale; + // Remote doesn't support window scaling, don't do it. + if self.remote_win_scale.is_none() { + self.remote_win_shift = 0; + } + if let Some(max_seg_size) = repr.max_seg_size { self.remote_mss = max_seg_size as usize; } self.set_state(State::Established); - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now, self.keep_alive); } // ACK packets in ESTABLISHED state reset the retransmit timer, // except for duplicate ACK packets which preserve it. (State::Established, TcpControl::None) => { if !self.timer.is_retransmit() || ack_len != 0 { - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now, self.keep_alive); } }, @@ -1403,7 +1437,7 @@ impl<'a> TcpSocket<'a> { self.remote_seq_no += 1; self.rx_fin_received = true; self.set_state(State::CloseWait); - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now, self.keep_alive); } // ACK packets in FIN-WAIT-1 state change it to FIN-WAIT-2, if we've already @@ -1412,7 +1446,7 @@ impl<'a> TcpSocket<'a> { if ack_of_fin { self.set_state(State::FinWait2); } - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now, self.keep_alive); } // FIN packets in FIN-WAIT-1 state change it to CLOSING, or to TIME-WAIT @@ -1422,16 +1456,16 @@ impl<'a> TcpSocket<'a> { self.rx_fin_received = true; if ack_of_fin { self.set_state(State::TimeWait); - self.timer.set_for_close(timestamp); + self.timer.set_for_close(cx.now); } else { self.set_state(State::Closing); - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now, self.keep_alive); } } // Data packets in FIN-WAIT-2 reset the idle timer. (State::FinWait2, TcpControl::None) => { - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now, self.keep_alive); } // FIN packets in FIN-WAIT-2 state change it to TIME-WAIT. @@ -1439,22 +1473,22 @@ impl<'a> TcpSocket<'a> { self.remote_seq_no += 1; self.rx_fin_received = true; self.set_state(State::TimeWait); - self.timer.set_for_close(timestamp); + self.timer.set_for_close(cx.now); } // ACK packets in CLOSING state change it to TIME-WAIT. (State::Closing, TcpControl::None) => { if ack_of_fin { self.set_state(State::TimeWait); - self.timer.set_for_close(timestamp); + self.timer.set_for_close(cx.now); } else { - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now, self.keep_alive); } } // ACK packets in CLOSE-WAIT state reset the retransmit timer. (State::CloseWait, TcpControl::None) => { - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now, self.keep_alive); } // ACK packets in LAST-ACK state change it to CLOSED. @@ -1465,7 +1499,7 @@ impl<'a> TcpSocket<'a> { self.local_endpoint = IpEndpoint::default(); self.remote_endpoint = IpEndpoint::default(); } else { - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now, self.keep_alive); } } @@ -1477,11 +1511,15 @@ impl<'a> TcpSocket<'a> { } // Update remote state. - self.remote_last_ts = Some(timestamp); + self.remote_last_ts = Some(cx.now); // RFC 1323: The window field (SEG.WND) in the header of every incoming segment, with the // exception of SYN segments, is left-shifted by Snd.Wind.Scale bits before updating SND.WND. - self.remote_win_len = (repr.window_len as usize) << (self.remote_win_scale.unwrap_or(0) as usize); + let scale = match repr.control { + TcpControl::Syn => 0, + _ => self.remote_win_scale.unwrap_or(0), + }; + self.remote_win_len = (repr.window_len as usize) << (scale as usize); if ack_len > 0 { // Dequeue acknowledged octets. @@ -1600,7 +1638,7 @@ impl<'a> TcpSocket<'a> { self.meta.handle, self.local_endpoint, self.remote_endpoint ); - Some(timestamp + ack_delay) + Some(cx.now + ack_delay) } // RFC1122 says "in a stream of full-sized segments there SHOULD be an ACK // for at least every second segment". @@ -1639,12 +1677,38 @@ impl<'a> TcpSocket<'a> { } } - fn seq_to_transmit(&self) -> bool { - // We can send data if we have data that: - // - hasn't been sent before - // - fits in the remote window - let can_data = self.remote_last_seq - < self.local_seq_no + core::cmp::min(self.remote_win_len, self.tx_buffer.len()); + fn seq_to_transmit(&self, cx: &Context) -> bool { + let ip_header_len = match self.local_endpoint.addr { + #[cfg(feature = "proto-ipv4")] + IpAddress::Ipv4(_) => crate::wire::IPV4_HEADER_LEN, + #[cfg(feature = "proto-ipv6")] + IpAddress::Ipv6(_) => crate::wire::IPV6_HEADER_LEN, + IpAddress::Unspecified => unreachable!(), + }; + + // Max segment size we're able to send due to MTU limitations. + let local_mss = cx.caps.ip_mtu() - ip_header_len - TCP_HEADER_LEN; + + // The effective max segment size, taking into account our and remote's limits. + let effective_mss = local_mss.min(self.remote_mss); + + // Have we sent data that hasn't been ACKed yet? + let data_in_flight = self.remote_last_seq != self.local_seq_no; + + // max sequence number we can send. + let max_send_seq = self.local_seq_no + core::cmp::min(self.remote_win_len, self.tx_buffer.len()); + + // Max amount of octets we can send. + let max_send = if max_send_seq >= self.remote_last_seq { + max_send_seq - self.remote_last_seq + } else { + 0 + }; + + // Can we send at least 1 octet? + let mut can_send = max_send != 0; + // Can we send at least 1 full segment? + let can_send_full = max_send >= effective_mss; // Do we have to send a FIN? let want_fin = match self.state { @@ -1654,6 +1718,10 @@ impl<'a> TcpSocket<'a> { _ => false, }; + if self.nagle && data_in_flight && !can_send_full { + can_send = false; + } + // Can we actually send the FIN? We can send it if: // 1. We have unsent data that fits in the remote window. // 2. We have no unsent data. @@ -1661,7 +1729,7 @@ impl<'a> TcpSocket<'a> { let can_fin = want_fin && self.remote_last_seq == self.local_seq_no + self.tx_buffer.len(); - can_data || can_fin + can_send || can_fin } fn delayed_ack_expired(&self, timestamp: Instant) -> bool { @@ -1682,13 +1750,12 @@ impl<'a> TcpSocket<'a> { fn window_to_update(&self) -> bool { match self.state { State::SynSent | State::SynReceived | State::Established | State::FinWait1 | State::FinWait2 => - (self.rx_buffer.window() >> self.remote_win_shift) as u16 > self.remote_last_win, + self.scaled_window() > self.remote_last_win, _ => false, } } - pub(crate) fn dispatch(&mut self, timestamp: Instant, ip_mtu: usize, - emit: F) -> Result<()> + pub(crate) fn dispatch(&mut self, cx: &Context, emit: F) -> Result<()> where F: FnOnce((IpRepr, TcpRepr)) -> Result<()> { if !self.remote_endpoint.is_specified() { return Err(Error::Exhausted) } @@ -1700,17 +1767,17 @@ impl<'a> TcpSocket<'a> { // period of time, it isn't anymore, and the local endpoint is talking. // So, we start counting the timeout not from the last received packet // but from the first transmitted one. - self.remote_last_ts = Some(timestamp); + self.remote_last_ts = Some(cx.now); } // Check if any state needs to be changed because of a timer. - if self.timed_out(timestamp) { + if self.timed_out(cx.now) { // If a timeout expires, we should abort the connection. net_debug!("{}:{}:{}: timeout exceeded", self.meta.handle, self.local_endpoint, self.remote_endpoint); self.set_state(State::Closed); - } else if !self.seq_to_transmit() { - if let Some(retransmit_delta) = self.timer.should_retransmit(timestamp) { + } else if !self.seq_to_transmit(cx) { + if let Some(retransmit_delta) = self.timer.should_retransmit(cx.now) { // If a retransmit timer expired, we should resend data starting at the last ACK. net_debug!("{}:{}:{}: retransmitting at t+{}", self.meta.handle, self.local_endpoint, self.remote_endpoint, @@ -1721,15 +1788,15 @@ impl<'a> TcpSocket<'a> { } // Decide whether we're sending a packet. - if self.seq_to_transmit() { + if self.seq_to_transmit(cx) { // If we have data to transmit and it fits into partner's window, do it. net_trace!("{}:{}:{}: outgoing segment will send data or flags", self.meta.handle, self.local_endpoint, self.remote_endpoint); - } else if self.ack_to_transmit() && self.delayed_ack_expired(timestamp) { + } else if self.ack_to_transmit() && self.delayed_ack_expired(cx.now) { // If we have data to acknowledge, do it. net_trace!("{}:{}:{}: outgoing segment will acknowledge", self.meta.handle, self.local_endpoint, self.remote_endpoint); - } else if self.window_to_update() && self.delayed_ack_expired(timestamp) { + } else if self.window_to_update() && self.delayed_ack_expired(cx.now) { // If we have window length increase to advertise, do it. net_trace!("{}:{}:{}: outgoing segment will update window", self.meta.handle, self.local_endpoint, self.remote_endpoint); @@ -1737,15 +1804,15 @@ impl<'a> TcpSocket<'a> { // If we need to abort the connection, do it. net_trace!("{}:{}:{}: outgoing segment will abort connection", self.meta.handle, self.local_endpoint, self.remote_endpoint); - } else if self.timer.should_retransmit(timestamp).is_some() { + } else if self.timer.should_retransmit(cx.now).is_some() { // If we have packets to retransmit, do it. net_trace!("{}:{}:{}: retransmit timer expired", self.meta.handle, self.local_endpoint, self.remote_endpoint); - } else if self.timer.should_keep_alive(timestamp) { + } else if self.timer.should_keep_alive(cx.now) { // If we need to transmit a keep-alive packet, do it. net_trace!("{}:{}:{}: keep-alive timer expired", self.meta.handle, self.local_endpoint, self.remote_endpoint); - } else if self.timer.should_close(timestamp) { + } else if self.timer.should_close(cx.now) { // If we have spent enough time in the TIME-WAIT state, close the socket. net_trace!("{}:{}:{}: TIME-WAIT timer expired", self.meta.handle, self.local_endpoint, self.remote_endpoint); @@ -1795,6 +1862,8 @@ impl<'a> TcpSocket<'a> { // We transmit a SYN|ACK in the SYN-RECEIVED state. State::SynSent | State::SynReceived => { repr.control = TcpControl::Syn; + // window len must NOT be scaled in SYNs. + repr.window_len = self.rx_buffer.window().min((1<<16)-1) as u16; if self.state == State::SynSent { repr.ack_number = None; repr.window_scale = Some(self.remote_win_shift); @@ -1833,7 +1902,7 @@ impl<'a> TcpSocket<'a> { // 3. MSS we can send, determined by our MTU. let size = win_limit .min(self.remote_mss) - .min(ip_mtu - ip_repr.buffer_len() - repr.mss_header_len()); + .min(cx.caps.ip_mtu() - ip_repr.buffer_len() - TCP_HEADER_LEN); let offset = self.remote_last_seq - self.local_seq_no; repr.payload = self.tx_buffer.get_allocated(offset, size); @@ -1860,7 +1929,7 @@ impl<'a> TcpSocket<'a> { // sequence space will elicit an ACK, we only need to send an explicit packet if we // couldn't fill the sequence space with anything. let is_keep_alive; - if self.timer.should_keep_alive(timestamp) && repr.is_empty() { + if self.timer.should_keep_alive(cx.now) && repr.is_empty() { repr.seq_number = repr.seq_number - 1; repr.payload = b"\x00"; // RFC 1122 says we should do this is_keep_alive = true; @@ -1895,9 +1964,7 @@ impl<'a> TcpSocket<'a> { if repr.control == TcpControl::Syn { // Fill the MSS option. See RFC 6691 for an explanation of this calculation. - let mut max_segment_size = ip_mtu; - max_segment_size -= ip_repr.buffer_len(); - max_segment_size -= repr.mss_header_len(); + let max_segment_size = cx.caps.ip_mtu() - ip_repr.buffer_len() - TCP_HEADER_LEN; repr.max_seg_size = Some(max_segment_size as u16); } @@ -1913,7 +1980,7 @@ impl<'a> TcpSocket<'a> { // We've sent something, whether useful data or a keep-alive packet, so rewind // the keep-alive timer. - self.timer.rewind_keep_alive(timestamp, self.keep_alive); + self.timer.rewind_keep_alive(cx.now, self.keep_alive); // Reset delayed-ack timer if self.ack_delay_until.is_some() { @@ -1934,13 +2001,13 @@ impl<'a> TcpSocket<'a> { self.remote_last_win = repr.window_len; if repr.segment_len() > 0 { - self.rtte.on_send(timestamp, repr.seq_number + repr.segment_len()); + self.rtte.on_send(cx.now, repr.seq_number + repr.segment_len()); } - if !self.seq_to_transmit() && repr.segment_len() > 0 { + if !self.seq_to_transmit(cx) && repr.segment_len() > 0 { // If we've transmitted all data we could (and there was something at all, // data or flag, to transmit, not just an ACK), wind up the retransmit timer. - self.timer.set_for_retransmit(timestamp, self.rtte.retransmission_timeout()); + self.timer.set_for_retransmit(cx.now, self.rtte.retransmission_timeout()); } if self.state == State::Closed { @@ -1953,7 +2020,7 @@ impl<'a> TcpSocket<'a> { } #[allow(clippy::if_same_then_else)] - pub(crate) fn poll_at(&self) -> PollAt { + pub(crate) fn poll_at(&self, cx: &Context) -> PollAt { // The logic here mirrors the beginning of dispatch() closely. if !self.remote_endpoint.is_specified() { // No one to talk to, nothing to transmit. @@ -1964,7 +2031,7 @@ impl<'a> TcpSocket<'a> { } else if self.state == State::Closed { // Socket was aborted, we have an RST packet to transmit. PollAt::Now - } else if self.seq_to_transmit() { + } else if self.seq_to_transmit(cx) { // We have a data or flag packet to transmit. PollAt::Now } else { @@ -2059,9 +2126,9 @@ mod test { }; #[cfg(feature = "proto-ipv6")] - const BASE_MSS: u16 = 1460; + const BASE_MSS: u16 = 1440; #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] - const BASE_MSS: u16 = 1480; + const BASE_MSS: u16 = 1460; // =========================================================================================// // Helper functions @@ -2079,7 +2146,10 @@ mod test { net_trace!("send: {}", repr); assert!(socket.accepts(&ip_repr, repr)); - match socket.process(timestamp, &ip_repr, repr) { + + let mut cx = Context::DUMMY.clone(); + cx.now = timestamp; + match socket.process(&cx, &ip_repr, repr) { Ok(Some((_ip_repr, repr))) => { net_trace!("recv: {}", repr); Ok(Some(repr)) @@ -2091,8 +2161,9 @@ mod test { fn recv(socket: &mut TcpSocket, timestamp: Instant, mut f: F) where F: FnMut(Result) { - let mtu = 1520; - let result = socket.dispatch(timestamp, mtu, |(ip_repr, tcp_repr)| { + let mut cx = Context::DUMMY.clone(); + cx.now = timestamp; + let result = socket.dispatch(&cx, |(ip_repr, tcp_repr)| { let ip_repr = ip_repr.lower(&[IpCidr::new(LOCAL_END.addr, 24)]).unwrap(); assert_eq!(ip_repr.protocol(), IpProtocol::Tcp); @@ -2216,8 +2287,11 @@ mod test { socket_syn_received_with_buffer_sizes(64, 64) } - fn socket_syn_sent() -> TcpSocket<'static> { - let mut s = socket(); + fn socket_syn_sent_with_buffer_sizes( + tx_len: usize, + rx_len: usize + ) -> TcpSocket<'static> { + let mut s = socket_with_buffer_sizes(tx_len, rx_len); s.state = State::SynSent; s.local_endpoint = IpEndpoint::new(MOCK_UNSPECIFIED, LOCAL_PORT); s.remote_endpoint = REMOTE_END; @@ -2226,6 +2300,10 @@ mod test { s } + fn socket_syn_sent() -> TcpSocket<'static> { + socket_syn_sent_with_buffer_sizes(64, 64) + } + fn socket_syn_sent_with_local_ipendpoint(local: IpEndpoint) -> TcpSocket<'static> { let mut s = socket(); s.state = State::SynSent; @@ -2431,7 +2509,7 @@ mod test { ack_number: Some(REMOTE_SEQ + 1), max_seg_size: Some(BASE_MSS), window_scale: Some(*shift_amt), - window_len: cmp::min(*buffer_size >> *shift_amt, 65535) as u16, + window_len: cmp::min(*buffer_size, 65535) as u16, ..RECV_TEMPL }]); } @@ -2603,6 +2681,7 @@ mod test { window_scale: None, ..SEND_TEMPL }); + assert_eq!(s.remote_win_shift, 0); assert_eq!(s.remote_win_scale, None); } @@ -2859,13 +2938,70 @@ mod test { ack_number: None, max_seg_size: Some(BASE_MSS), window_scale: Some(*shift_amt), - window_len: cmp::min(*buffer_size >> *shift_amt, 65535) as u16, + window_len: cmp::min(*buffer_size, 65535) as u16, sack_permitted: true, ..RECV_TEMPL }]); } } + #[test] + fn test_syn_sent_syn_ack_no_window_scaling() { + let mut s = socket_syn_sent_with_buffer_sizes(1048576, 1048576); + recv!(s, [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + // scaling does NOT apply to the window value in SYN packets + window_len: 65535, + window_scale: Some(5), + sack_permitted: true, + ..RECV_TEMPL + }]); + assert_eq!(s.remote_win_shift, 5); + send!(s, TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + max_seg_size: Some(BASE_MSS - 80), + window_scale: None, + window_len: 42, + ..SEND_TEMPL + }); + assert_eq!(s.state, State::Established); + assert_eq!(s.remote_win_shift, 0); + assert_eq!(s.remote_win_scale, None); + assert_eq!(s.remote_win_len, 42); + } + + #[test] + fn test_syn_sent_syn_ack_window_scaling() { + let mut s = socket_syn_sent(); + recv!(s, [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + sack_permitted: true, + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + max_seg_size: Some(BASE_MSS - 80), + window_scale: Some(7), + window_len: 42, + ..SEND_TEMPL + }); + assert_eq!(s.state, State::Established); + assert_eq!(s.remote_win_scale, Some(7)); + // scaling does NOT apply to the window value in SYN packets + assert_eq!(s.remote_win_len, 42); + } + // =========================================================================================// // Tests for the ESTABLISHED state. // =========================================================================================// @@ -3040,6 +3176,7 @@ mod test { #[test] fn test_established_send_no_ack_send() { let mut s = socket_established(); + s.set_nagle_enabled(false); s.send_slice(b"abcdef").unwrap(); recv!(s, [TcpRepr { seq_number: LOCAL_SEQ + 1, @@ -4811,7 +4948,7 @@ mod test { fn test_listen_timeout() { let mut s = socket_listen(); s.set_timeout(Some(Duration::from_millis(100))); - assert_eq!(s.poll_at(), PollAt::Ingress); + assert_eq!(s.poll_at(&Context::DUMMY), PollAt::Ingress); } #[test] @@ -4830,7 +4967,7 @@ mod test { ..RECV_TEMPL })); assert_eq!(s.state, State::SynSent); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(250))); + assert_eq!(s.poll_at(&Context::DUMMY), PollAt::Time(Instant::from_millis(250))); recv!(s, time 250, Ok(TcpRepr { control: TcpControl::Rst, seq_number: LOCAL_SEQ + 1, @@ -4846,23 +4983,23 @@ mod test { let mut s = socket_established(); s.set_timeout(Some(Duration::from_millis(1000))); recv!(s, time 250, Err(Error::Exhausted)); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(1250))); + assert_eq!(s.poll_at(&Context::DUMMY), PollAt::Time(Instant::from_millis(1250))); s.send_slice(b"abcdef").unwrap(); - assert_eq!(s.poll_at(), PollAt::Now); + assert_eq!(s.poll_at(&Context::DUMMY), PollAt::Now); recv!(s, time 255, Ok(TcpRepr { seq_number: LOCAL_SEQ + 1, ack_number: Some(REMOTE_SEQ + 1), payload: &b"abcdef"[..], ..RECV_TEMPL })); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(955))); + assert_eq!(s.poll_at(&Context::DUMMY), PollAt::Time(Instant::from_millis(955))); recv!(s, time 955, Ok(TcpRepr { seq_number: LOCAL_SEQ + 1, ack_number: Some(REMOTE_SEQ + 1), payload: &b"abcdef"[..], ..RECV_TEMPL })); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(1255))); + assert_eq!(s.poll_at(&Context::DUMMY), PollAt::Time(Instant::from_millis(1255))); recv!(s, time 1255, Ok(TcpRepr { control: TcpControl::Rst, seq_number: LOCAL_SEQ + 1 + 6, @@ -4884,13 +5021,13 @@ mod test { ..RECV_TEMPL })); recv!(s, time 100, Err(Error::Exhausted)); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(150))); + assert_eq!(s.poll_at(&Context::DUMMY), PollAt::Time(Instant::from_millis(150))); send!(s, time 105, TcpRepr { seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1), ..SEND_TEMPL }); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(155))); + assert_eq!(s.poll_at(&Context::DUMMY), PollAt::Time(Instant::from_millis(155))); recv!(s, time 155, Ok(TcpRepr { seq_number: LOCAL_SEQ, ack_number: Some(REMOTE_SEQ + 1), @@ -4898,7 +5035,7 @@ mod test { ..RECV_TEMPL })); recv!(s, time 155, Err(Error::Exhausted)); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(205))); + assert_eq!(s.poll_at(&Context::DUMMY), PollAt::Time(Instant::from_millis(205))); recv!(s, time 200, Err(Error::Exhausted)); recv!(s, time 205, Ok(TcpRepr { control: TcpControl::Rst, @@ -4954,14 +5091,14 @@ mod test { s.set_timeout(Some(Duration::from_millis(200))); s.remote_last_ts = Some(Instant::from_millis(100)); s.abort(); - assert_eq!(s.poll_at(), PollAt::Now); + assert_eq!(s.poll_at(&Context::DUMMY), PollAt::Now); recv!(s, time 100, Ok(TcpRepr { control: TcpControl::Rst, seq_number: LOCAL_SEQ + 1, ack_number: Some(REMOTE_SEQ + 1), ..RECV_TEMPL })); - assert_eq!(s.poll_at(), PollAt::Ingress); + assert_eq!(s.poll_at(&Context::DUMMY), PollAt::Ingress); } // =========================================================================================// @@ -4988,7 +5125,7 @@ mod test { s.set_keep_alive(Some(Duration::from_millis(100))); // drain the forced keep-alive packet - assert_eq!(s.poll_at(), PollAt::Now); + assert_eq!(s.poll_at(&Context::DUMMY), PollAt::Now); recv!(s, time 0, Ok(TcpRepr { seq_number: LOCAL_SEQ, ack_number: Some(REMOTE_SEQ + 1), @@ -4996,7 +5133,7 @@ mod test { ..RECV_TEMPL })); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(100))); + assert_eq!(s.poll_at(&Context::DUMMY), PollAt::Time(Instant::from_millis(100))); recv!(s, time 95, Err(Error::Exhausted)); recv!(s, time 100, Ok(TcpRepr { seq_number: LOCAL_SEQ, @@ -5005,7 +5142,7 @@ mod test { ..RECV_TEMPL })); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(200))); + assert_eq!(s.poll_at(&Context::DUMMY), PollAt::Time(Instant::from_millis(200))); recv!(s, time 195, Err(Error::Exhausted)); recv!(s, time 200, Ok(TcpRepr { seq_number: LOCAL_SEQ, @@ -5019,7 +5156,7 @@ mod test { ack_number: Some(LOCAL_SEQ + 1), ..SEND_TEMPL }); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(350))); + assert_eq!(s.poll_at(&Context::DUMMY), PollAt::Time(Instant::from_millis(350))); recv!(s, time 345, Err(Error::Exhausted)); recv!(s, time 350, Ok(TcpRepr { seq_number: LOCAL_SEQ, @@ -5036,10 +5173,9 @@ mod test { #[test] fn test_set_hop_limit() { let mut s = socket_syn_received(); - let mtu = 1520; s.set_hop_limit(Some(0x2a)); - assert_eq!(s.dispatch(Instant::from_millis(0), mtu, |(ip_repr, _)| { + assert_eq!(s.dispatch(&Context::DUMMY, |(ip_repr, _)| { assert_eq!(ip_repr.hop_limit(), 0x2a); Ok(()) }), Ok(())); @@ -5119,6 +5255,8 @@ mod test { #[test] fn test_buffer_wraparound_tx() { let mut s = socket_established(); + s.set_nagle_enabled(false); + s.tx_buffer = SocketBuffer::new(vec![b'.'; 9]); assert_eq!(s.send_slice(b"xxxyyy"), Ok(6)); assert_eq!(s.tx_buffer.dequeue_many(3), &b"xxx"[..]); @@ -5407,6 +5545,57 @@ mod test { })); } + // =========================================================================================// + // Tests for Nagle's Algorithm + // =========================================================================================// + + #[test] + fn test_nagle() { + let mut s = socket_established(); + s.remote_mss = 6; + + s.send_slice(b"abcdef").unwrap(); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }]); + + // If there's data in flight, full segments get sent. + s.send_slice(b"foobar").unwrap(); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"foobar"[..], + ..RECV_TEMPL + }]); + + s.send_slice(b"aaabbbccc").unwrap(); + // If there's data in flight, not-full segments don't get sent. + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"aaabbb"[..], + ..RECV_TEMPL + }]); + + // Data gets ACKd, so there's no longer data in flight + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 6 + 6), + ..SEND_TEMPL + }); + + // Now non-full segment gets sent. + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"ccc"[..], + ..RECV_TEMPL + }]); + } + // =========================================================================================// // Tests for packet filtering. // =========================================================================================// diff --git a/src/socket/udp.rs b/src/socket/udp.rs index f39b0dd..7503711 100644 --- a/src/socket/udp.rs +++ b/src/socket/udp.rs @@ -3,7 +3,7 @@ use core::cmp::min; use core::task::Waker; use crate::{Error, Result}; -use crate::socket::{Socket, SocketMeta, SocketHandle, PollAt}; +use crate::socket::{Socket, SocketMeta, SocketHandle, PollAt, Context}; use crate::storage::{PacketBuffer, PacketMetadata}; use crate::wire::{IpProtocol, IpRepr, IpEndpoint, UdpRepr}; #[cfg(feature = "async")] @@ -293,7 +293,7 @@ impl<'a> UdpSocket<'a> { true } - pub(crate) fn process(&mut self, ip_repr: &IpRepr, repr: &UdpRepr, payload: &[u8]) -> Result<()> { + pub(crate) fn process(&mut self, _cx: &Context, ip_repr: &IpRepr, repr: &UdpRepr, payload: &[u8]) -> Result<()> { debug_assert!(self.accepts(ip_repr, repr)); let size = payload.len(); @@ -311,7 +311,7 @@ impl<'a> UdpSocket<'a> { Ok(()) } - pub(crate) fn dispatch(&mut self, emit: F) -> Result<()> + pub(crate) fn dispatch(&mut self, _cx: &Context, emit: F) -> Result<()> where F: FnOnce((IpRepr, UdpRepr, &[u8])) -> Result<()> { let handle = self.handle(); let endpoint = self.endpoint; @@ -342,7 +342,7 @@ impl<'a> UdpSocket<'a> { Ok(()) } - pub(crate) fn poll_at(&self) -> PollAt { + pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt { if self.tx_buffer.is_empty() { PollAt::Ingress } else { @@ -465,14 +465,14 @@ mod test { assert_eq!(socket.bind(LOCAL_END), Ok(())); assert!(socket.can_send()); - assert_eq!(socket.dispatch(|_| unreachable!()), + assert_eq!(socket.dispatch(&Context::DUMMY, |_| unreachable!()), Err(Error::Exhausted)); assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Ok(())); assert_eq!(socket.send_slice(b"123456", REMOTE_END), Err(Error::Exhausted)); assert!(!socket.can_send()); - assert_eq!(socket.dispatch(|(ip_repr, udp_repr, payload)| { + assert_eq!(socket.dispatch(&Context::DUMMY, |(ip_repr, udp_repr, payload)| { assert_eq!(ip_repr, LOCAL_IP_REPR); assert_eq!(udp_repr, LOCAL_UDP_REPR); assert_eq!(payload, PAYLOAD); @@ -480,7 +480,7 @@ mod test { }), Err(Error::Unaddressable)); assert!(!socket.can_send()); - assert_eq!(socket.dispatch(|(ip_repr, udp_repr, payload)| { + assert_eq!(socket.dispatch(&Context::DUMMY, |(ip_repr, udp_repr, payload)| { assert_eq!(ip_repr, LOCAL_IP_REPR); assert_eq!(udp_repr, LOCAL_UDP_REPR); assert_eq!(payload, PAYLOAD); @@ -498,12 +498,12 @@ mod test { assert_eq!(socket.recv(), Err(Error::Exhausted)); assert!(socket.accepts(&remote_ip_repr(), &REMOTE_UDP_REPR)); - assert_eq!(socket.process(&remote_ip_repr(), &REMOTE_UDP_REPR, PAYLOAD), + assert_eq!(socket.process(&Context::DUMMY, &remote_ip_repr(), &REMOTE_UDP_REPR, PAYLOAD), Ok(())); assert!(socket.can_recv()); assert!(socket.accepts(&remote_ip_repr(), &REMOTE_UDP_REPR)); - assert_eq!(socket.process(&remote_ip_repr(), &REMOTE_UDP_REPR, PAYLOAD), + assert_eq!(socket.process(&Context::DUMMY, &remote_ip_repr(), &REMOTE_UDP_REPR, PAYLOAD), Err(Error::Exhausted)); assert_eq!(socket.recv(), Ok((&b"abcdef"[..], REMOTE_END))); assert!(!socket.can_recv()); @@ -516,7 +516,7 @@ mod test { assert_eq!(socket.peek(), Err(Error::Exhausted)); - assert_eq!(socket.process(&remote_ip_repr(), &REMOTE_UDP_REPR, PAYLOAD), + assert_eq!(socket.process(&Context::DUMMY, &remote_ip_repr(), &REMOTE_UDP_REPR, PAYLOAD), Ok(())); assert_eq!(socket.peek(), Ok((&b"abcdef"[..], &REMOTE_END))); assert_eq!(socket.recv(), Ok((&b"abcdef"[..], REMOTE_END))); @@ -529,7 +529,7 @@ mod test { assert_eq!(socket.bind(LOCAL_PORT), Ok(())); assert!(socket.accepts(&remote_ip_repr(), &REMOTE_UDP_REPR)); - assert_eq!(socket.process(&remote_ip_repr(), &REMOTE_UDP_REPR, PAYLOAD), + assert_eq!(socket.process(&Context::DUMMY, &remote_ip_repr(), &REMOTE_UDP_REPR, PAYLOAD), Ok(())); let mut slice = [0; 4]; @@ -542,7 +542,7 @@ mod test { let mut socket = socket(buffer(1), buffer(0)); assert_eq!(socket.bind(LOCAL_PORT), Ok(())); - assert_eq!(socket.process(&remote_ip_repr(), &REMOTE_UDP_REPR, PAYLOAD), + assert_eq!(socket.process(&Context::DUMMY, &remote_ip_repr(), &REMOTE_UDP_REPR, PAYLOAD), Ok(())); let mut slice = [0; 4]; @@ -560,7 +560,7 @@ mod test { s.set_hop_limit(Some(0x2a)); assert_eq!(s.send_slice(b"abcdef", REMOTE_END), Ok(())); - assert_eq!(s.dispatch(|(ip_repr, _, _)| { + assert_eq!(s.dispatch(&Context::DUMMY, |(ip_repr, _, _)| { assert_eq!(ip_repr, IpRepr::Unspecified{ src_addr: MOCK_IP_ADDR_1, dst_addr: MOCK_IP_ADDR_2, @@ -637,7 +637,7 @@ mod test { src_port: REMOTE_PORT, dst_port: LOCAL_PORT, }; - assert_eq!(socket.process(&remote_ip_repr(), &repr, &[]), Ok(())); + assert_eq!(socket.process(&Context::DUMMY, &remote_ip_repr(), &repr, &[]), Ok(())); assert_eq!(socket.recv(), Ok((&[][..], REMOTE_END))); } diff --git a/src/wire/ipv4.rs b/src/wire/ipv4.rs index 687b4b9..221d656 100644 --- a/src/wire/ipv4.rs +++ b/src/wire/ipv4.rs @@ -263,6 +263,9 @@ mod field { pub const DST_ADDR: Field = 16..20; } +pub const HEADER_LEN: usize = field::DST_ADDR.end; + + impl> Packet { /// Imbue a raw octet buffer with IPv4 packet structure. pub fn new_unchecked(buffer: T) -> Packet { diff --git a/src/wire/ipv6.rs b/src/wire/ipv6.rs index 6f54a0c..575219c 100644 --- a/src/wire/ipv6.rs +++ b/src/wire/ipv6.rs @@ -380,6 +380,9 @@ mod field { pub const DST_ADDR: Field = 24..40; } +/// Length of an IPv6 header. +pub const HEADER_LEN: usize = field::DST_ADDR.end; + impl> Packet { /// Create a raw octet buffer with an IPv6 packet structure. #[inline] diff --git a/src/wire/mod.rs b/src/wire/mod.rs index 7fb37b5..7a7f39c 100644 --- a/src/wire/mod.rs +++ b/src/wire/mod.rs @@ -140,13 +140,15 @@ pub use self::ipv4::{Address as Ipv4Address, Packet as Ipv4Packet, Repr as Ipv4Repr, Cidr as Ipv4Cidr, - MIN_MTU as IPV4_MIN_MTU}; + HEADER_LEN as IPV4_HEADER_LEN, + MIN_MTU as IPV4_MIN_MTU}; #[cfg(feature = "proto-ipv6")] pub use self::ipv6::{Address as Ipv6Address, Packet as Ipv6Packet, Repr as Ipv6Repr, Cidr as Ipv6Cidr, + HEADER_LEN as IPV6_HEADER_LEN, MIN_MTU as IPV6_MIN_MTU}; #[cfg(feature = "proto-ipv6")] @@ -218,7 +220,8 @@ pub use self::tcp::{SeqNumber as TcpSeqNumber, Packet as TcpPacket, TcpOption, Repr as TcpRepr, - Control as TcpControl}; + Control as TcpControl, + HEADER_LEN as TCP_HEADER_LEN}; #[cfg(feature = "proto-dhcpv4")] pub use self::dhcpv4::{Packet as DhcpPacket, diff --git a/src/wire/tcp.rs b/src/wire/tcp.rs index c0fff41..d5622d9 100644 --- a/src/wire/tcp.rs +++ b/src/wire/tcp.rs @@ -109,6 +109,8 @@ mod field { pub const OPT_SACKRNG: u8 = 0x05; } +pub const HEADER_LEN: usize = field::URGENT.end; + impl> Packet { /// Imbue a raw octet buffer with TCP packet structure. pub fn new_unchecked(buffer: T) -> Packet { @@ -857,14 +859,6 @@ impl<'a> Repr<'a> { length } - /// Return the length of the header for the TCP protocol. - /// - /// Per RFC 6691, this should be used for MSS calculations. It may be smaller than the buffer - /// space required to accomodate this packet's data. - pub fn mss_header_len(&self) -> usize { - field::URGENT.end - } - /// Return the length of a packet that will be emitted from this high-level representation. pub fn buffer_len(&self) -> usize { self.header_len() + self.payload.len()