diff --git a/src/iface/ethernet.rs b/src/iface/ethernet.rs index 1b84cd7..85192b4 100644 --- a/src/iface/ethernet.rs +++ b/src/iface/ethernet.rs @@ -412,13 +412,13 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> { for tcp_socket in sockets.iter_mut().filter_map( >::try_as_socket) { + if !tcp_socket.accepts(&ip_repr, &tcp_repr) { continue } + match tcp_socket.process(timestamp, &ip_repr, &tcp_repr) { // The packet is valid and handled by socket. Ok(reply) => return Ok(reply.map_or(Packet::None, Packet::Tcp)), - // The packet isn't addressed to the socket. - // Send RST only if no other socket accepts the packet. - Err(Error::Rejected) => continue, - // The packet is malformed, or addressed to the socket but cannot be accepted. + // The packet is malformed, or doesn't match the socket state, + // or the socket buffer is full. Err(e) => return Err(e) } } diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index e5380a1..23dc07e 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -666,25 +666,31 @@ impl<'a> TcpSocket<'a> { (ip_reply_repr, reply_repr) } - pub(crate) fn process(&mut self, timestamp: u64, ip_repr: &IpRepr, repr: &TcpRepr) -> - Result)>> { - if self.state == State::Closed { return Err(Error::Rejected) } + pub(crate) fn accepts(&self, ip_repr: &IpRepr, repr: &TcpRepr) -> bool { + if self.state == State::Closed { return false } // If we're still listening for SYNs and the packet has an ACK, it cannot // be destined to this socket, but another one may well listen on the same // local endpoint. - if self.state == State::Listen && repr.ack_number.is_some() { return Err(Error::Rejected) } + if self.state == State::Listen && repr.ack_number.is_some() { return false } // Reject packets with a wrong destination. - if self.local_endpoint.port != repr.dst_port { return Err(Error::Rejected) } + if self.local_endpoint.port != repr.dst_port { return false } if !self.local_endpoint.addr.is_unspecified() && - self.local_endpoint.addr != ip_repr.dst_addr() { return Err(Error::Rejected) } + self.local_endpoint.addr != ip_repr.dst_addr() { return false } // Reject packets from a source to which we aren't connected. if self.remote_endpoint.port != 0 && - self.remote_endpoint.port != repr.src_port { return Err(Error::Rejected) } + self.remote_endpoint.port != repr.src_port { return false } if !self.remote_endpoint.addr.is_unspecified() && - self.remote_endpoint.addr != ip_repr.src_addr() { return Err(Error::Rejected) } + self.remote_endpoint.addr != ip_repr.src_addr() { return false } + + true + } + + pub(crate) fn process(&mut self, timestamp: u64, ip_repr: &IpRepr, repr: &TcpRepr) -> + Result)>> { + debug_assert!(self.accepts(ip_repr, repr)); // Consider how much the sequence number space differs from the transmit buffer space. let (sent_syn, sent_fin) = match self.state { @@ -1241,6 +1247,7 @@ mod test { const LOCAL_IP: IpAddress = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 1])); const REMOTE_IP: IpAddress = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 2])); + const THIRD_PARTY_IP: IpAddress = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 3])); const LOCAL_PORT: u16 = 80; const REMOTE_PORT: u16 = 49500; const LOCAL_END: IpEndpoint = IpEndpoint { addr: LOCAL_IP, port: LOCAL_PORT }; @@ -1272,6 +1279,11 @@ mod test { protocol: IpProtocol::Tcp, payload_len: repr.buffer_len() }; + + if !socket.accepts(&ip_repr, repr) { + return Err(Error::Rejected); + } + match socket.process(timestamp, &ip_repr, repr) { Ok(Some((_ip_repr, repr))) => { trace!("recv: {}", repr); @@ -2930,4 +2942,66 @@ mod test { ..RECV_TEMPL }]); } + + // =========================================================================================// + // Tests for packet filtering + // =========================================================================================// + + #[test] + fn test_doesnt_accept_wrong_port() { + let mut s = socket_established(); + s.rx_buffer = SocketBuffer::new(vec![0; 6]); + + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + dst_port: LOCAL_PORT + 1, + ..SEND_TEMPL + }, Err(Error::Rejected)); + + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + src_port: REMOTE_PORT + 1, + ..SEND_TEMPL + }, Err(Error::Rejected)); + } + + #[test] + fn test_doesnt_accept_wrong_ip() { + let s = socket_established(); + + let tcp_repr = TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }; + + let ip_repr = IpRepr::Unspecified { + src_addr: REMOTE_IP, + dst_addr: LOCAL_IP, + protocol: IpProtocol::Tcp, + payload_len: tcp_repr.buffer_len() + }; + assert!(s.accepts(&ip_repr, &tcp_repr)); + + let ip_repr_wrong_src = IpRepr::Unspecified { + src_addr: THIRD_PARTY_IP, + dst_addr: LOCAL_IP, + protocol: IpProtocol::Tcp, + payload_len: tcp_repr.buffer_len() + }; + assert!(!s.accepts(&ip_repr_wrong_src, &tcp_repr)); + + let ip_repr_wrong_dst = IpRepr::Unspecified { + src_addr: REMOTE_IP, + dst_addr: THIRD_PARTY_IP, + protocol: IpProtocol::Tcp, + payload_len: tcp_repr.buffer_len() + }; + assert!(!s.accepts(&ip_repr_wrong_dst, &tcp_repr)); + } }