Factor out TcpSocket::accepts.

This commit is contained in:
Egor Karavaev 2017-09-01 00:44:41 +03:00 committed by whitequark
parent 02b699e18c
commit 8404fe908c
2 changed files with 86 additions and 12 deletions

View File

@ -412,13 +412,13 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> {
for tcp_socket in sockets.iter_mut().filter_map(
<Socket as AsSocket<TcpSocket>>::try_as_socket) {
if !tcp_socket.accepts(&ip_repr, &tcp_repr) { continue }
match tcp_socket.process(timestamp, &ip_repr, &tcp_repr) {
// The packet is valid and handled by socket.
Ok(reply) => return Ok(reply.map_or(Packet::None, Packet::Tcp)),
// The packet isn't addressed to the socket.
// Send RST only if no other socket accepts the packet.
Err(Error::Rejected) => continue,
// The packet is malformed, or addressed to the socket but cannot be accepted.
// The packet is malformed, or doesn't match the socket state,
// or the socket buffer is full.
Err(e) => return Err(e)
}
}

View File

@ -666,25 +666,31 @@ impl<'a> TcpSocket<'a> {
(ip_reply_repr, reply_repr)
}
pub(crate) fn process(&mut self, timestamp: u64, ip_repr: &IpRepr, repr: &TcpRepr) ->
Result<Option<(IpRepr, TcpRepr<'static>)>> {
if self.state == State::Closed { return Err(Error::Rejected) }
pub(crate) fn accepts(&self, ip_repr: &IpRepr, repr: &TcpRepr) -> bool {
if self.state == State::Closed { return false }
// If we're still listening for SYNs and the packet has an ACK, it cannot
// be destined to this socket, but another one may well listen on the same
// local endpoint.
if self.state == State::Listen && repr.ack_number.is_some() { return Err(Error::Rejected) }
if self.state == State::Listen && repr.ack_number.is_some() { return false }
// Reject packets with a wrong destination.
if self.local_endpoint.port != repr.dst_port { return Err(Error::Rejected) }
if self.local_endpoint.port != repr.dst_port { return false }
if !self.local_endpoint.addr.is_unspecified() &&
self.local_endpoint.addr != ip_repr.dst_addr() { return Err(Error::Rejected) }
self.local_endpoint.addr != ip_repr.dst_addr() { return false }
// Reject packets from a source to which we aren't connected.
if self.remote_endpoint.port != 0 &&
self.remote_endpoint.port != repr.src_port { return Err(Error::Rejected) }
self.remote_endpoint.port != repr.src_port { return false }
if !self.remote_endpoint.addr.is_unspecified() &&
self.remote_endpoint.addr != ip_repr.src_addr() { return Err(Error::Rejected) }
self.remote_endpoint.addr != ip_repr.src_addr() { return false }
true
}
pub(crate) fn process(&mut self, timestamp: u64, ip_repr: &IpRepr, repr: &TcpRepr) ->
Result<Option<(IpRepr, TcpRepr<'static>)>> {
debug_assert!(self.accepts(ip_repr, repr));
// Consider how much the sequence number space differs from the transmit buffer space.
let (sent_syn, sent_fin) = match self.state {
@ -1241,6 +1247,7 @@ mod test {
const LOCAL_IP: IpAddress = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 1]));
const REMOTE_IP: IpAddress = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 2]));
const THIRD_PARTY_IP: IpAddress = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 3]));
const LOCAL_PORT: u16 = 80;
const REMOTE_PORT: u16 = 49500;
const LOCAL_END: IpEndpoint = IpEndpoint { addr: LOCAL_IP, port: LOCAL_PORT };
@ -1272,6 +1279,11 @@ mod test {
protocol: IpProtocol::Tcp,
payload_len: repr.buffer_len()
};
if !socket.accepts(&ip_repr, repr) {
return Err(Error::Rejected);
}
match socket.process(timestamp, &ip_repr, repr) {
Ok(Some((_ip_repr, repr))) => {
trace!("recv: {}", repr);
@ -2930,4 +2942,66 @@ mod test {
..RECV_TEMPL
}]);
}
// =========================================================================================//
// Tests for packet filtering
// =========================================================================================//
#[test]
fn test_doesnt_accept_wrong_port() {
let mut s = socket_established();
s.rx_buffer = SocketBuffer::new(vec![0; 6]);
send!(s, TcpRepr {
seq_number: REMOTE_SEQ + 1,
ack_number: Some(LOCAL_SEQ + 1),
payload: &b"abcdef"[..],
dst_port: LOCAL_PORT + 1,
..SEND_TEMPL
}, Err(Error::Rejected));
send!(s, TcpRepr {
seq_number: REMOTE_SEQ + 1,
ack_number: Some(LOCAL_SEQ + 1),
payload: &b"abcdef"[..],
src_port: REMOTE_PORT + 1,
..SEND_TEMPL
}, Err(Error::Rejected));
}
#[test]
fn test_doesnt_accept_wrong_ip() {
let s = socket_established();
let tcp_repr = TcpRepr {
seq_number: REMOTE_SEQ + 1,
ack_number: Some(LOCAL_SEQ + 1),
payload: &b"abcdef"[..],
..SEND_TEMPL
};
let ip_repr = IpRepr::Unspecified {
src_addr: REMOTE_IP,
dst_addr: LOCAL_IP,
protocol: IpProtocol::Tcp,
payload_len: tcp_repr.buffer_len()
};
assert!(s.accepts(&ip_repr, &tcp_repr));
let ip_repr_wrong_src = IpRepr::Unspecified {
src_addr: THIRD_PARTY_IP,
dst_addr: LOCAL_IP,
protocol: IpProtocol::Tcp,
payload_len: tcp_repr.buffer_len()
};
assert!(!s.accepts(&ip_repr_wrong_src, &tcp_repr));
let ip_repr_wrong_dst = IpRepr::Unspecified {
src_addr: REMOTE_IP,
dst_addr: THIRD_PARTY_IP,
protocol: IpProtocol::Tcp,
payload_len: tcp_repr.buffer_len()
};
assert!(!s.accepts(&ip_repr_wrong_dst, &tcp_repr));
}
}