Use the correct wrapping operations on TCP sequence numbers.

v0.7.x
whitequark 2016-12-27 18:34:13 +00:00
parent bbff907e87
commit 6b592742fd
3 changed files with 98 additions and 52 deletions

View File

@ -3,7 +3,7 @@ use core::fmt;
use Error;
use Managed;
use wire::{IpProtocol, IpAddress, IpEndpoint};
use wire::{TcpPacket, TcpRepr, TcpControl};
use wire::{TcpSeqNumber, TcpPacket, TcpRepr, TcpControl};
use socket::{Socket, IpRepr, IpPayload};
/// A TCP stream ring buffer.
@ -185,16 +185,16 @@ pub struct TcpSocket<'a> {
remote_endpoint: IpEndpoint,
/// The sequence number corresponding to the beginning of the transmit buffer.
/// I.e. an ACK(local_seq_no+n) packet removes n bytes from the transmit buffer.
local_seq_no: i32,
local_seq_no: TcpSeqNumber,
/// The sequence number corresponding to the beginning of the receive buffer.
/// I.e. userspace reading n bytes adds n to remote_seq_no.
remote_seq_no: i32,
remote_seq_no: TcpSeqNumber,
/// The last sequence number sent.
/// I.e. in an idle socket, local_seq_no+tx_buffer.len().
remote_last_seq: i32,
remote_last_seq: TcpSeqNumber,
/// The last acknowledgement number sent.
/// I.e. in an idle socket, remote_seq_no+rx_buffer.len().
remote_last_ack: i32,
remote_last_ack: TcpSeqNumber,
/// The speculative remote window size.
/// I.e. the actual remote window size minus the count of in-flight octets.
remote_win_len: usize,
@ -218,10 +218,10 @@ impl<'a> TcpSocket<'a> {
listen_address: IpAddress::default(),
local_endpoint: IpEndpoint::default(),
remote_endpoint: IpEndpoint::default(),
local_seq_no: 0,
remote_seq_no: 0,
remote_last_seq: 0,
remote_last_ack: 0,
local_seq_no: TcpSeqNumber(0),
remote_seq_no: TcpSeqNumber(0),
remote_last_seq: TcpSeqNumber(0),
remote_last_ack: TcpSeqNumber(0),
remote_win_len: 0,
retransmit: Retransmit::new(),
tx_buffer: tx_buffer.into(),
@ -341,7 +341,7 @@ impl<'a> TcpSocket<'a> {
if !self.can_recv() { return Err(()) }
let buffer = self.rx_buffer.dequeue(size);
self.remote_seq_no += buffer.len() as i32;
self.remote_seq_no += buffer.len();
if buffer.len() > 0 {
net_trace!("tcp:{}:{}: rx buffer: dequeueing {} octets",
self.local_endpoint, self.remote_endpoint, buffer.len());
@ -450,9 +450,9 @@ impl<'a> TcpSocket<'a> {
// all of the control flags we sent.
_ => 0
};
let unacknowledged = self.tx_buffer.len() as i32 + control_len;
if !(ack_number - self.local_seq_no >= 0 &&
ack_number - (self.local_seq_no + unacknowledged) <= 0) {
let unacknowledged = self.tx_buffer.len() + control_len;
if !(ack_number >= self.local_seq_no &&
ack_number <= (self.local_seq_no + unacknowledged)) {
net_trace!("tcp:{}:{}: unacceptable ACK ({} not in {}..{})",
self.local_endpoint, self.remote_endpoint,
ack_number, self.local_seq_no, self.local_seq_no + unacknowledged);
@ -468,13 +468,13 @@ impl<'a> TcpSocket<'a> {
// In all other states, segments must occupy a valid portion of the receive window.
// For now, do not try to reassemble out-of-order segments.
(_, TcpRepr { seq_number, .. }) => {
let next_remote_seq = self.remote_seq_no + self.rx_buffer.len() as i32;
if seq_number - next_remote_seq > 0 {
let next_remote_seq = self.remote_seq_no + self.rx_buffer.len();
if seq_number > next_remote_seq {
net_trace!("tcp:{}:{}: unacceptable SEQ ({} not in {}..)",
self.local_endpoint, self.remote_endpoint,
seq_number, next_remote_seq);
return Err(Error::Malformed)
} else if seq_number - next_remote_seq != 0 {
} else if seq_number != next_remote_seq {
net_trace!("tcp:{}:{}: duplicate SEQ ({} in ..{})",
self.local_endpoint, self.remote_endpoint,
seq_number, next_remote_seq);
@ -511,7 +511,8 @@ impl<'a> TcpSocket<'a> {
}) => {
self.local_endpoint = IpEndpoint::new(ip_repr.dst_addr(), dst_port);
self.remote_endpoint = IpEndpoint::new(ip_repr.src_addr(), src_port);
self.local_seq_no = -seq_number; // FIXME: use something more secure
// FIXME: use something more secure here
self.local_seq_no = TcpSeqNumber(-seq_number.0);
self.remote_last_seq = self.local_seq_no + 1;
self.remote_seq_no = seq_number + 1;
self.set_state(State::SynReceived);
@ -607,7 +608,7 @@ impl<'a> TcpSocket<'a> {
// 1. the retransmit timer has expired, or...
let mut may_send = self.retransmit.check();
// 2. we've got new data in the transmit buffer.
let remote_next_seq = self.local_seq_no + self.tx_buffer.len() as i32;
let remote_next_seq = self.local_seq_no + self.tx_buffer.len();
if self.remote_last_seq != remote_next_seq {
may_send = true;
}
@ -627,9 +628,9 @@ impl<'a> TcpSocket<'a> {
repr.payload = data;
// Speculatively shrink the remote window. This will get updated the next
// time we receive a packet.
self.remote_win_len -= data.len();
self.remote_win_len -= data.len();
// Advance the in-flight sequence number.
self.remote_last_seq += data.len() as i32;
self.remote_last_seq += data.len();
should_send = true;
}
}
@ -637,7 +638,7 @@ impl<'a> TcpSocket<'a> {
_ => unreachable!()
}
let ack_number = self.remote_seq_no + self.rx_buffer.len() as i32;
let ack_number = self.remote_seq_no + self.rx_buffer.len();
if !should_send && self.remote_last_ack != ack_number {
// Acknowledge all data we have received, since it is all in order.
net_trace!("tcp:{}:{}: sending ACK",
@ -692,25 +693,25 @@ mod test {
buffer.enqueue_slice(&b"bazhoge"[..]); // zhobarba
}
const LOCAL_IP: IpAddress = IpAddress::v4(10, 0, 0, 1);
const REMOTE_IP: IpAddress = IpAddress::v4(10, 0, 0, 2);
const LOCAL_PORT: u16 = 80;
const REMOTE_PORT: u16 = 49500;
const LOCAL_END: IpEndpoint = IpEndpoint::new(LOCAL_IP, LOCAL_PORT);
const REMOTE_END: IpEndpoint = IpEndpoint::new(REMOTE_IP, REMOTE_PORT);
const LOCAL_SEQ: i32 = 10000;
const REMOTE_SEQ: i32 = -10000;
const LOCAL_IP: IpAddress = IpAddress::v4(10, 0, 0, 1);
const REMOTE_IP: IpAddress = IpAddress::v4(10, 0, 0, 2);
const LOCAL_PORT: u16 = 80;
const REMOTE_PORT: u16 = 49500;
const LOCAL_END: IpEndpoint = IpEndpoint::new(LOCAL_IP, LOCAL_PORT);
const REMOTE_END: IpEndpoint = IpEndpoint::new(REMOTE_IP, REMOTE_PORT);
const LOCAL_SEQ: TcpSeqNumber = TcpSeqNumber(10000);
const REMOTE_SEQ: TcpSeqNumber = TcpSeqNumber(-10000);
const SEND_TEMPL: TcpRepr<'static> = TcpRepr {
src_port: REMOTE_PORT, dst_port: LOCAL_PORT,
control: TcpControl::None,
seq_number: 0, ack_number: Some(0),
seq_number: TcpSeqNumber(0), ack_number: Some(TcpSeqNumber(0)),
window_len: 256, payload: &[]
};
const RECV_TEMPL: TcpRepr<'static> = TcpRepr {
src_port: LOCAL_PORT, dst_port: REMOTE_PORT,
control: TcpControl::None,
seq_number: 0, ack_number: Some(0),
seq_number: TcpSeqNumber(0), ack_number: Some(TcpSeqNumber(0)),
window_len: 64, payload: &[]
};
@ -917,7 +918,7 @@ mod test {
send!(s, TcpRepr {
control: TcpControl::Rst,
seq_number: REMOTE_SEQ,
ack_number: Some(1234),
ack_number: Some(TcpSeqNumber(1234)),
..SEND_TEMPL
}, Err(Error::Malformed));
assert_eq!(s.state, State::SynSent);
@ -1005,7 +1006,7 @@ mod test {
// Already acknowledged data.
send!(s, TcpRepr {
seq_number: REMOTE_SEQ + 1,
ack_number: Some(LOCAL_SEQ - 1),
ack_number: Some(TcpSeqNumber(LOCAL_SEQ.0 - 1)),
..SEND_TEMPL
}, Err(Error::Malformed));
assert_eq!(s.local_seq_no, LOCAL_SEQ + 1);

View File

@ -119,6 +119,7 @@ pub use self::icmpv4::Repr as Icmpv4Repr;
pub use self::udp::Packet as UdpPacket;
pub use self::udp::Repr as UdpRepr;
pub use self::tcp::SeqNumber as TcpSeqNumber;
pub use self::tcp::Packet as TcpPacket;
pub use self::tcp::Repr as TcpRepr;
pub use self::tcp::Control as TcpControl;

View File

@ -1,10 +1,54 @@
use core::fmt;
use core::{i32, ops, cmp, fmt};
use byteorder::{ByteOrder, NetworkEndian};
use Error;
use super::{IpProtocol, IpAddress};
use super::ip::checksum;
/// A TCP sequence number.
///
/// A sequence number is a monotonically advancing integer modulo 2<sup>32</sup>.
/// Sequence numbers do not have a discontiguity when compared pairwise across a signed overflow.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct SeqNumber(pub i32);
impl fmt::Display for SeqNumber {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.0 as u32)
}
}
impl ops::Add<usize> for SeqNumber {
type Output = SeqNumber;
fn add(self, rhs: usize) -> SeqNumber {
if rhs > i32::MAX as usize {
panic!("attempt to add to sequence number with unsigned overflow")
}
SeqNumber(self.0.wrapping_add(rhs as i32))
}
}
impl ops::AddAssign<usize> for SeqNumber {
fn add_assign(&mut self, rhs: usize) {
*self = *self + rhs;
}
}
impl ops::Sub for SeqNumber {
type Output = usize;
fn sub(self, rhs: SeqNumber) -> usize {
(self.0 - rhs.0) as usize
}
}
impl cmp::PartialOrd for SeqNumber {
fn partial_cmp(&self, other: &SeqNumber) -> Option<cmp::Ordering> {
(self.0 - other.0).partial_cmp(&0)
}
}
/// A read/write wrapper around an Transmission Control Protocol packet buffer.
#[derive(Debug)]
pub struct Packet<T: AsRef<[u8]>> {
@ -69,16 +113,16 @@ impl<T: AsRef<[u8]>> Packet<T> {
/// Return the sequence number field.
#[inline(always)]
pub fn seq_number(&self) -> i32 {
pub fn seq_number(&self) -> SeqNumber {
let data = self.buffer.as_ref();
NetworkEndian::read_i32(&data[field::SEQ_NUM])
SeqNumber(NetworkEndian::read_i32(&data[field::SEQ_NUM]))
}
/// Return the acknowledgement number field.
#[inline(always)]
pub fn ack_number(&self) -> i32 {
pub fn ack_number(&self) -> SeqNumber {
let data = self.buffer.as_ref();
NetworkEndian::read_i32(&data[field::ACK_NUM])
SeqNumber(NetworkEndian::read_i32(&data[field::ACK_NUM]))
}
/// Return the FIN flag.
@ -184,12 +228,12 @@ impl<T: AsRef<[u8]>> Packet<T> {
/// Return the length of the segment, in terms of sequence space.
#[inline(always)]
pub fn segment_len(&self) -> i32 {
pub fn segment_len(&self) -> usize {
let data = self.buffer.as_ref();
let mut length = data.len() - self.header_len() as usize;
if self.syn() { length += 1 }
if self.fin() { length += 1 }
length as i32
length
}
/// Validate the packet checksum.
@ -234,16 +278,16 @@ impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> {
/// Set the sequence number field.
#[inline(always)]
pub fn set_seq_number(&mut self, value: i32) {
pub fn set_seq_number(&mut self, value: SeqNumber) {
let mut data = self.buffer.as_mut();
NetworkEndian::write_i32(&mut data[field::SEQ_NUM], value)
NetworkEndian::write_i32(&mut data[field::SEQ_NUM], value.0)
}
/// Set the acknowledgement number field.
#[inline(always)]
pub fn set_ack_number(&mut self, value: i32) {
pub fn set_ack_number(&mut self, value: SeqNumber) {
let mut data = self.buffer.as_mut();
NetworkEndian::write_i32(&mut data[field::ACK_NUM], value)
NetworkEndian::write_i32(&mut data[field::ACK_NUM], value.0)
}
/// Clear the entire flags field.
@ -422,8 +466,8 @@ pub struct Repr<'a> {
pub src_port: u16,
pub dst_port: u16,
pub control: Control,
pub seq_number: i32,
pub ack_number: Option<i32>,
pub seq_number: SeqNumber,
pub ack_number: Option<SeqNumber>,
pub window_len: u16,
pub payload: &'a [u8]
}
@ -482,7 +526,7 @@ impl<'a> Repr<'a> {
packet.set_src_port(self.src_port);
packet.set_dst_port(self.dst_port);
packet.set_seq_number(self.seq_number);
packet.set_ack_number(self.ack_number.unwrap_or(0));
packet.set_ack_number(self.ack_number.unwrap_or(SeqNumber(0)));
packet.set_window_len(self.window_len);
packet.set_header_len(field::URGENT.end as u8);
packet.clear_flags();
@ -579,8 +623,8 @@ mod test {
let packet = Packet::new(&PACKET_BYTES[..]).unwrap();
assert_eq!(packet.src_port(), 48896);
assert_eq!(packet.dst_port(), 80);
assert_eq!(packet.seq_number(), 0x01234567);
assert_eq!(packet.ack_number(), 0x89abcdefu32 as i32);
assert_eq!(packet.seq_number(), SeqNumber(0x01234567));
assert_eq!(packet.ack_number(), SeqNumber(0x89abcdefu32 as i32));
assert_eq!(packet.header_len(), 20);
assert_eq!(packet.fin(), true);
assert_eq!(packet.syn(), false);
@ -601,8 +645,8 @@ mod test {
let mut packet = Packet::new(&mut bytes).unwrap();
packet.set_src_port(48896);
packet.set_dst_port(80);
packet.set_seq_number(0x01234567);
packet.set_ack_number(0x89abcdefu32 as i32);
packet.set_seq_number(SeqNumber(0x01234567));
packet.set_ack_number(SeqNumber(0x89abcdefu32 as i32));
packet.set_header_len(20);
packet.set_fin(true);
packet.set_syn(false);
@ -630,7 +674,7 @@ mod test {
Repr {
src_port: 48896,
dst_port: 80,
seq_number: 0x01234567,
seq_number: SeqNumber(0x01234567),
ack_number: None,
window_len: 0x0123,
control: Control::Syn,