Factor out RawSocket::accepts.

v0.7.x
Egor Karavaev 2017-09-01 00:43:22 +03:00 committed by whitequark
parent 9b242c7099
commit b4d6a53e34
2 changed files with 32 additions and 8 deletions

View File

@ -289,13 +289,15 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
let mut handled_by_raw_socket = false;
for raw_socket in sockets.iter_mut().filter_map(
<Socket as AsSocket<RawSocket>>::try_as_socket) {
if !raw_socket.accepts(&ip_repr) { continue }
match raw_socket.process(&ip_repr, ip_payload) {
// The packet is valid and handled by socket.
Ok(()) => handled_by_raw_socket = true,
// The packet isn't addressed to the socket, or cannot be accepted by it.
Err(Error::Rejected) | Err(Error::Exhausted) => (),
// Raw sockets either accept or reject packets, not parse them.
_ => unreachable!(),
// The socket buffer is full.
Err(Error::Exhausted) => (),
// Raw sockets don't validate the packets in any way.
Err(_) => unreachable!(),
}
}

View File

@ -166,9 +166,15 @@ impl<'a, 'b> RawSocket<'a, 'b> {
Ok(length)
}
pub(crate) fn accepts(&self, ip_repr: &IpRepr) -> bool {
if ip_repr.version() != self.ip_version { return false }
if ip_repr.protocol() != self.ip_protocol { return false }
true
}
pub(crate) fn process(&mut self, ip_repr: &IpRepr, payload: &[u8]) -> Result<()> {
if ip_repr.version() != self.ip_version { return Err(Error::Rejected) }
if ip_repr.protocol() != self.ip_protocol { return Err(Error::Rejected) }
debug_assert!(self.accepts(ip_repr));
let header_len = ip_repr.buffer_len();
let total_len = header_len + payload.len();
@ -246,17 +252,18 @@ mod test {
fn socket(rx_buffer: SocketBuffer<'static, 'static>,
tx_buffer: SocketBuffer<'static, 'static>)
-> RawSocket<'static, 'static> {
match RawSocket::new(IpVersion::Ipv4, IpProtocol::Unknown(63),
match RawSocket::new(IpVersion::Ipv4, IpProtocol::Unknown(IP_PROTO),
rx_buffer, tx_buffer) {
Socket::Raw(socket) => socket,
_ => unreachable!()
}
}
const IP_PROTO: u8 = 63;
const HEADER_REPR: IpRepr = IpRepr::Ipv4(Ipv4Repr {
src_addr: Ipv4Address([10, 0, 0, 1]),
dst_addr: Ipv4Address([10, 0, 0, 2]),
protocol: IpProtocol::Unknown(63),
protocol: IpProtocol::Unknown(IP_PROTO),
payload_len: 4
});
const PACKET_BYTES: [u8; 24] = [
@ -332,10 +339,12 @@ mod test {
Ipv4Packet::new(&mut cksumd_packet).fill_checksum();
assert_eq!(socket.recv(), Err(Error::Exhausted));
assert!(socket.accepts(&HEADER_REPR));
assert_eq!(socket.process(&HEADER_REPR, &PACKET_PAYLOAD),
Ok(()));
assert!(socket.can_recv());
assert!(socket.accepts(&HEADER_REPR));
assert_eq!(socket.process(&HEADER_REPR, &PACKET_PAYLOAD),
Err(Error::Exhausted));
assert_eq!(socket.recv(), Ok(&cksumd_packet[..]));
@ -346,6 +355,7 @@ mod test {
fn test_recv_truncated_slice() {
let mut socket = socket(buffer(1), buffer(0));
assert!(socket.accepts(&HEADER_REPR));
assert_eq!(socket.process(&HEADER_REPR, &PACKET_PAYLOAD),
Ok(()));
@ -361,7 +371,19 @@ mod test {
let mut buffer = vec![0; 128];
buffer[..PACKET_BYTES.len()].copy_from_slice(&PACKET_BYTES[..]);
assert!(socket.accepts(&HEADER_REPR));
assert_eq!(socket.process(&HEADER_REPR, &buffer),
Err(Error::Truncated));
}
#[test]
fn test_doesnt_accept_wrong_proto() {
let socket = match RawSocket::new(IpVersion::Ipv4,
IpProtocol::Unknown(IP_PROTO+1),
buffer(1), buffer(1)) {
Socket::Raw(socket) => socket,
_ => unreachable!()
};
assert!(!socket.accepts(&HEADER_REPR));
}
}