diff --git a/src/wire/ipv4.rs b/src/wire/ipv4.rs index cabe483..9505b15 100644 --- a/src/wire/ipv4.rs +++ b/src/wire/ipv4.rs @@ -119,19 +119,21 @@ impl> Packet { /// Ensure that no accessor method will panic if called. /// Returns `Err(Error::Truncated)` if the buffer is too short. /// - /// The result of this check is invalidated by calling [set_header_len]. + /// The result of this check is invalidated by calling [set_header_len] + /// and [set_total_len]. /// /// [set_header_len]: #method.set_header_len + /// [set_total_len]: #method.set_total_len pub fn check_len(&self) -> Result<()> { let len = self.buffer.as_ref().len(); if len < field::DST_ADDR.end { Err(Error::Truncated) + } else if len < self.header_len() as usize { + Err(Error::Truncated) + } else if len < self.total_len() as usize { + Err(Error::Truncated) } else { - if len < self.header_len() as usize { - Err(Error::Truncated) - } else { - Ok(()) - } + Ok(()) } } @@ -634,6 +636,16 @@ mod test { PAYLOAD_BYTES.len()); } + #[test] + fn test_total_len_overflow() { + let mut bytes = vec![]; + bytes.extend(&PACKET_BYTES[..]); + Packet::new(&mut bytes).set_total_len(128); + + assert_eq!(Packet::new_checked(&bytes).unwrap_err(), + Error::Truncated); + } + static REPR_PACKET_BYTES: [u8; 24] = [0x45, 0x00, 0x00, 0x18, 0x00, 0x00, 0x40, 0x00,