diff --git a/src/socket/dhcpv4.rs b/src/socket/dhcpv4.rs index 29f644f..1af3ab5 100644 --- a/src/socket/dhcpv4.rs +++ b/src/socket/dhcpv4.rs @@ -100,6 +100,8 @@ enum ClientState { } /// Return value for the `Dhcpv4Socket::poll` function +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Event<'a> { /// Configuration has been lost (for example, the lease has expired) Deconfigured, @@ -328,6 +330,16 @@ impl Dhcpv4Socket { Some((config, renew_at, expires_at)) } + #[cfg(not(test))] + fn random_transaction_id() -> u32 { + crate::rand::rand_u32() + } + + #[cfg(test)] + fn random_transaction_id() -> u32 { + 0x12345678 + } + pub(crate) fn dispatch(&mut self, cx: &Context, emit: F) -> Result<()> where F: FnOnce((Ipv4Repr, UdpRepr, DhcpRepr)) -> Result<()>, @@ -342,7 +354,7 @@ impl Dhcpv4Socket { // We don't directly modify self.transaction_id because sending the packet // may fail. We only want to update state after succesfully sending. - let next_transaction_id = crate::rand::rand_u32(); + let next_transaction_id = Self::random_transaction_id(); let mut dhcp_repr = DhcpRepr { message_type: DhcpMessageType::Discover, @@ -504,3 +516,288 @@ impl<'a> From for Socket<'a> { Socket::Dhcpv4(val) } } + +#[cfg(test)] +mod test { + + use super::*; + use crate::wire::EthernetAddress; + + // =========================================================================================// + // Helper functions + + fn send( + socket: &mut Dhcpv4Socket, + timestamp: Instant, + (ip_repr, udp_repr, dhcp_repr): (Ipv4Repr, UdpRepr, DhcpRepr), + ) -> Result<()> { + net_trace!("send: {:?}", ip_repr); + net_trace!(" {:?}", udp_repr); + net_trace!(" {:?}", dhcp_repr); + + let mut payload = vec![0; dhcp_repr.buffer_len()]; + dhcp_repr + .emit(&mut DhcpPacket::new_unchecked(&mut payload)) + .unwrap(); + + let mut cx = Context::DUMMY.clone(); + cx.now = timestamp; + socket.process(&cx, &ip_repr, &udp_repr, &payload) + } + + fn recv(socket: &mut Dhcpv4Socket, timestamp: Instant, mut f: F) + where + F: FnMut(Result<(Ipv4Repr, UdpRepr, DhcpRepr)>), + { + let mut cx = Context::DUMMY.clone(); + cx.now = timestamp; + let result = socket.dispatch(&cx, |(mut ip_repr, udp_repr, dhcp_repr)| { + assert_eq!(ip_repr.protocol, IpProtocol::Udp); + assert_eq!( + ip_repr.payload_len, + udp_repr.header_len() + dhcp_repr.buffer_len() + ); + + // We validated the payload len, change it to 0 to make equality testing easier + ip_repr.payload_len = 0; + + net_trace!("recv: {:?}", ip_repr); + net_trace!(" {:?}", udp_repr); + net_trace!(" {:?}", dhcp_repr); + Ok(f(Ok((ip_repr, udp_repr, dhcp_repr)))) + }); + match result { + Ok(()) => (), + Err(e) => f(Err(e)), + } + } + + macro_rules! send { + ($socket:ident, $repr:expr) => + (send!($socket, time 0, $repr)); + ($socket:ident, $repr:expr, $result:expr) => + (send!($socket, time 0, $repr, $result)); + ($socket:ident, time $time:expr, $repr:expr) => + (send!($socket, time $time, $repr, Ok(( )))); + ($socket:ident, time $time:expr, $repr:expr, $result:expr) => + (assert_eq!(send(&mut $socket, Instant::from_millis($time), $repr), $result)); + } + + macro_rules! recv { + ($socket:ident, [$( $repr:expr ),*]) => ({ + $( recv!($socket, Ok($repr)); )* + recv!($socket, Err(Error::Exhausted)) + }); + ($socket:ident, time $time:expr, [$( $repr:expr ),*]) => ({ + $( recv!($socket, time $time, Ok($repr)); )* + recv!($socket, time $time, Err(Error::Exhausted)) + }); + ($socket:ident, $result:expr) => + (recv!($socket, time 0, $result)); + ($socket:ident, time $time:expr, $result:expr) => + (recv(&mut $socket, Instant::from_millis($time), |result| { + assert_eq!(result, $result) + })); + } + + #[cfg(feature = "log")] + fn init_logger() { + struct Logger; + static LOGGER: Logger = Logger; + + impl log::Log for Logger { + fn enabled(&self, _metadata: &log::Metadata) -> bool { + true + } + + fn log(&self, record: &log::Record) { + println!("{}", record.args()); + } + + fn flush(&self) {} + } + + // If it fails, that just means we've already set it to the same value. + let _ = log::set_logger(&LOGGER); + log::set_max_level(log::LevelFilter::Trace); + + println!(); + } + + // =========================================================================================// + // Constants + + const TXID: u32 = 0x12345678; + + const MY_IP: Ipv4Address = Ipv4Address([192, 168, 1, 42]); + const SERVER_IP: Ipv4Address = Ipv4Address([192, 168, 1, 1]); + const DNS_IP_1: Ipv4Address = Ipv4Address([1, 1, 1, 1]); + const DNS_IP_2: Ipv4Address = Ipv4Address([1, 1, 1, 2]); + const DNS_IP_3: Ipv4Address = Ipv4Address([1, 1, 1, 3]); + const DNS_IPS: [Option; DHCP_MAX_DNS_SERVER_COUNT] = + [Some(DNS_IP_1), Some(DNS_IP_2), Some(DNS_IP_3)]; + const MASK_24: Ipv4Address = Ipv4Address([255, 255, 255, 0]); + + const MY_MAC: EthernetAddress = EthernetAddress([0x02, 0x02, 0x02, 0x02, 0x02, 0x02]); + + const IP_BROADCAST: Ipv4Repr = Ipv4Repr { + src_addr: Ipv4Address::UNSPECIFIED, + dst_addr: Ipv4Address::BROADCAST, + protocol: IpProtocol::Udp, + payload_len: 0, + hop_limit: 64, + }; + + const IP_RECV: Ipv4Repr = Ipv4Repr { + src_addr: SERVER_IP, + dst_addr: MY_IP, + protocol: IpProtocol::Udp, + payload_len: 0, + hop_limit: 64, + }; + + const IP_SEND: Ipv4Repr = Ipv4Repr { + src_addr: MY_IP, + dst_addr: SERVER_IP, + protocol: IpProtocol::Udp, + payload_len: 0, + hop_limit: 64, + }; + + const UDP_SEND: UdpRepr = UdpRepr { + src_port: 68, + dst_port: 67, + }; + const UDP_RECV: UdpRepr = UdpRepr { + src_port: 67, + dst_port: 68, + }; + + const DHCP_DEFAULT: DhcpRepr = DhcpRepr { + message_type: DhcpMessageType::Unknown(99), + transaction_id: TXID, + client_hardware_address: MY_MAC, + client_ip: Ipv4Address::UNSPECIFIED, + your_ip: Ipv4Address::UNSPECIFIED, + server_ip: Ipv4Address::UNSPECIFIED, + router: None, + subnet_mask: None, + relay_agent_ip: Ipv4Address::UNSPECIFIED, + broadcast: false, + requested_ip: None, + client_identifier: None, + server_identifier: None, + parameter_request_list: None, + dns_servers: None, + max_size: None, + lease_duration: None, + }; + + const DHCP_DISCOVER: DhcpRepr = DhcpRepr { + message_type: DhcpMessageType::Discover, + client_identifier: Some(MY_MAC), + parameter_request_list: Some(&[1, 3, 6]), + max_size: Some(1432), + ..DHCP_DEFAULT + }; + + const DHCP_OFFER: DhcpRepr = DhcpRepr { + message_type: DhcpMessageType::Offer, + server_ip: SERVER_IP, + server_identifier: Some(SERVER_IP), + + your_ip: MY_IP, + router: Some(SERVER_IP), + subnet_mask: Some(MASK_24), + dns_servers: Some(DNS_IPS), + lease_duration: Some(60), + + ..DHCP_DEFAULT + }; + + const DHCP_REQUEST: DhcpRepr = DhcpRepr { + message_type: DhcpMessageType::Request, + client_identifier: Some(MY_MAC), + server_identifier: Some(SERVER_IP), + max_size: Some(1432), + + requested_ip: Some(MY_IP), + parameter_request_list: Some(&[1, 3, 6]), + ..DHCP_DEFAULT + }; + + const DHCP_ACK: DhcpRepr = DhcpRepr { + message_type: DhcpMessageType::Ack, + server_ip: SERVER_IP, + server_identifier: Some(SERVER_IP), + + your_ip: MY_IP, + router: Some(SERVER_IP), + subnet_mask: Some(MASK_24), + dns_servers: Some(DNS_IPS), + lease_duration: Some(60), + + ..DHCP_DEFAULT + }; + + // =========================================================================================// + // Tests + + fn socket() -> Dhcpv4Socket { + #[cfg(feature = "log")] + init_logger(); + + let mut s = Dhcpv4Socket::new(); + assert_eq!(s.poll(), Some(Event::Deconfigured)); + s + } + + fn socket_bound() -> Dhcpv4Socket { + let mut s = socket(); + s.state = ClientState::Renewing(RenewState { + config: Config { + address: Ipv4Cidr::new(MY_IP, 24), + dns_servers: DNS_IPS, + router: Some(SERVER_IP), + }, + server: ServerInfo { + address: SERVER_IP, + identifier: SERVER_IP, + }, + renew_at: Instant::from_secs(30), + expires_at: Instant::from_secs(60), + }); + + s + } + + #[test] + fn test_bind() { + let mut s = socket(); + + recv!(s, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + assert_eq!(s.poll(), None); + send!(s, (IP_RECV, UDP_RECV, DHCP_OFFER)); + assert_eq!(s.poll(), None); + recv!(s, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + assert_eq!(s.poll(), None); + send!(s, (IP_RECV, UDP_RECV, DHCP_ACK)); + + assert_eq!( + s.poll(), + Some(Event::Configured(&Config { + address: Ipv4Cidr::new(MY_IP, 24), + dns_servers: DNS_IPS, + router: Some(SERVER_IP), + })) + ); + + match s.state { + ClientState::Renewing(r) => { + assert_eq!(r.renew_at, Instant::from_secs(30)); + assert_eq!(r.expires_at, Instant::from_secs(60)); + } + _ => panic!("Invalid state"), + } + } +} diff --git a/src/socket/mod.rs b/src/socket/mod.rs index d1de47d..d28bb2d 100644 --- a/src/socket/mod.rs +++ b/src/socket/mod.rs @@ -216,7 +216,9 @@ impl Context { max_transmission_unit: 1500, }, #[cfg(all(feature = "medium-ethernet", feature = "socket-dhcpv4"))] - ethernet_address: None, + ethernet_address: Some(crate::wire::EthernetAddress([ + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + ])), now: Instant::from_millis_const(0), }; }