From ea78053dc1bc57cf5dd4de96e50ad6034fd99daa Mon Sep 17 00:00:00 2001 From: whitequark Date: Mon, 26 Dec 2016 11:20:20 +0000 Subject: [PATCH] Factor out IpRepr into the wire module. --- src/iface/ethernet.rs | 66 +++++----------------- src/socket/mod.rs | 22 ++------ src/socket/tcp.rs | 63 ++++++++++----------- src/socket/udp.rs | 37 ++++++------ src/wire/ip.rs | 128 +++++++++++++++++++++++++++++++++++++++++- src/wire/mod.rs | 1 + 6 files changed, 195 insertions(+), 122 deletions(-) diff --git a/src/iface/ethernet.rs b/src/iface/ethernet.rs index b417f6a..bf8c459 100644 --- a/src/iface/ethernet.rs +++ b/src/iface/ethernet.rs @@ -5,11 +5,11 @@ use Error; use phy::Device; use wire::{EthernetAddress, EthernetProtocol, EthernetFrame}; use wire::{ArpPacket, ArpRepr, ArpOperation}; -use wire::{IpAddress, IpProtocol}; -use wire::{Ipv4Address, Ipv4Packet, Ipv4Repr}; +use wire::{Ipv4Packet, Ipv4Repr}; use wire::{Icmpv4Packet, Icmpv4Repr, Icmpv4DstUnreachable}; +use wire::{IpAddress, IpProtocol, IpRepr}; use wire::{TcpPacket, TcpRepr, TcpControl}; -use socket::{Socket, IpRepr}; +use socket::Socket; use super::{ArpCache}; /// An Ethernet network interface. @@ -213,13 +213,8 @@ impl<'a, 'b: 'a, Ipv4Repr { src_addr, dst_addr, protocol } => { let mut handled = false; for socket in self.sockets.borrow_mut() { - let ip_repr = IpRepr { - src_addr: src_addr.into(), - dst_addr: dst_addr.into(), - protocol: protocol, - payload: ipv4_packet.payload() - }; - match socket.collect(&ip_repr) { + let ip_repr = IpRepr::Ipv4(ipv4_repr); + match socket.collect(&ip_repr, ipv4_packet.payload()) { Ok(()) => { // The packet was valid and handled by socket. handled = true; @@ -370,62 +365,27 @@ impl<'a, 'b: 'a, let mut nothing_to_transmit = true; for socket in self.sockets.borrow_mut() { - let result = socket.dispatch(&mut |repr| { - let src_addr = - try!(match &repr.src_addr { - &IpAddress::Unspecified | - &IpAddress::Ipv4(Ipv4Address([0, _, _, _])) => { - let mut assigned_addr = None; - for addr in src_protocol_addrs { - match addr { - addr @ &IpAddress::Ipv4(_) => { - assigned_addr = Some(addr); - break - } - _ => () - } - } - assigned_addr.ok_or(Error::Unaddressable) - }, - addr => Ok(addr) - }); - - let ipv4_repr = - match (src_addr, &repr.dst_addr) { - (&IpAddress::Ipv4(src_addr), - &IpAddress::Ipv4(dst_addr)) => { - Ipv4Repr { - src_addr: src_addr, - dst_addr: dst_addr, - protocol: repr.protocol - } - }, - _ => unreachable!() - }; + let result = socket.dispatch(&mut |repr, payload| { + let repr = try!(repr.lower(src_protocol_addrs)); let dst_hardware_addr = - match arp_cache.lookup(&repr.dst_addr) { + match arp_cache.lookup(&repr.dst_addr()) { None => return Err(Error::Unaddressable), Some(hardware_addr) => hardware_addr }; - let tx_len = EthernetFrame::<&[u8]>::buffer_len(ipv4_repr.buffer_len() + - repr.payload.buffer_len()); + let tx_len = EthernetFrame::<&[u8]>::buffer_len(repr.buffer_len() + + payload.buffer_len()); let mut tx_buffer = try!(device.transmit(tx_len)); let mut frame = try!(EthernetFrame::new(&mut tx_buffer)); frame.set_src_addr(src_hardware_addr); frame.set_dst_addr(dst_hardware_addr); frame.set_ethertype(EthernetProtocol::Ipv4); - let mut ip_packet = try!(Ipv4Packet::new(frame.payload_mut())); - ipv4_repr.emit(&mut ip_packet, repr.payload.buffer_len()); + repr.emit(frame.payload_mut(), payload.buffer_len()); - repr.payload.emit(&mut IpRepr { - src_addr: repr.src_addr, - dst_addr: repr.dst_addr, - protocol: repr.protocol, - payload: ip_packet.payload_mut() - }); + let mut ip_packet = try!(Ipv4Packet::new(frame.payload_mut())); + payload.emit(&repr, ip_packet.payload_mut()); Ok(()) }); diff --git a/src/socket/mod.rs b/src/socket/mod.rs index 864b528..db3bd47 100644 --- a/src/socket/mod.rs +++ b/src/socket/mod.rs @@ -11,7 +11,7 @@ //! size for a buffer, allocate it, and let the networking stack use it. use Error; -use wire::{IpAddress, IpProtocol}; +use wire::IpRepr; mod udp; mod tcp; @@ -52,12 +52,12 @@ impl<'a, 'b> Socket<'a, 'b> { /// is returned. /// /// This function is used internally by the networking stack. - pub fn collect(&mut self, repr: &IpRepr<&[u8]>) -> Result<(), Error> { + pub fn collect(&mut self, ip_repr: &IpRepr, payload: &[u8]) -> Result<(), Error> { match self { &mut Socket::Udp(ref mut socket) => - socket.collect(repr), + socket.collect(ip_repr, payload), &mut Socket::Tcp(ref mut socket) => - socket.collect(repr), + socket.collect(ip_repr, payload), &mut Socket::__Nonexhaustive => unreachable!() } } @@ -70,7 +70,7 @@ impl<'a, 'b> Socket<'a, 'b> { /// /// This function is used internally by the networking stack. pub fn dispatch(&mut self, emit: &mut F) -> Result<(), Error> - where F: FnMut(&IpRepr<&IpPayload>) -> Result<(), Error> { + where F: FnMut(&IpRepr, &IpPayload) -> Result<(), Error> { match self { &mut Socket::Udp(ref mut socket) => socket.dispatch(emit), @@ -81,16 +81,6 @@ impl<'a, 'b> Socket<'a, 'b> { } } -/// An IP packet representation. -/// -/// This struct abstracts the various versions of IP packets. -pub struct IpRepr { - pub src_addr: IpAddress, - pub dst_addr: IpAddress, - pub protocol: IpProtocol, - pub payload: T -} - /// An IP-encapsulated packet representation. /// /// This trait abstracts the various types of packets layered under the IP protocol, @@ -100,7 +90,7 @@ pub trait IpPayload { fn buffer_len(&self) -> usize; /// Emit this high-level representation into a sequence of octets. - fn emit(&self, repr: &mut IpRepr<&mut [u8]>); + fn emit(&self, ip_repr: &IpRepr, payload: &mut [u8]); } /// A conversion trait for network sockets. diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index 3eec2fa..9e4c422 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -241,27 +241,27 @@ impl<'a> TcpSocket<'a> { } /// See [Socket::collect](enum.Socket.html#method.collect). - pub fn collect(&mut self, ip_repr: &IpRepr<&[u8]>) -> Result<(), Error> { - if ip_repr.protocol != IpProtocol::Tcp { return Err(Error::Rejected) } + pub fn collect(&mut self, ip_repr: &IpRepr, payload: &[u8]) -> Result<(), Error> { + if ip_repr.protocol() != IpProtocol::Tcp { return Err(Error::Rejected) } - let packet = try!(TcpPacket::new(ip_repr.payload)); - let repr = try!(TcpRepr::parse(&packet, &ip_repr.src_addr, &ip_repr.dst_addr)); + let packet = try!(TcpPacket::new(payload)); + let repr = try!(TcpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr())); // Reject packets with a wrong destination. if self.local_endpoint.port != repr.dst_port { return Err(Error::Rejected) } if !self.local_endpoint.addr.is_unspecified() && - self.local_endpoint.addr != ip_repr.dst_addr { return Err(Error::Rejected) } + self.local_endpoint.addr != ip_repr.dst_addr() { return Err(Error::Rejected) } // Reject packets from a source to which we aren't connected. if self.remote_endpoint.port != 0 && self.remote_endpoint.port != repr.src_port { return Err(Error::Rejected) } if !self.remote_endpoint.addr.is_unspecified() && - self.remote_endpoint.addr != ip_repr.src_addr { return Err(Error::Rejected) } + self.remote_endpoint.addr != ip_repr.src_addr() { return Err(Error::Rejected) } // Reject packets addressed to a closed socket. if self.state == State::Closed { net_trace!("tcp:{}:{}:{}: packet sent to a closed socket", - self.local_endpoint, ip_repr.src_addr, repr.src_port); + self.local_endpoint, ip_repr.src_addr(), repr.src_port); return Err(Error::Malformed) } @@ -315,8 +315,8 @@ impl<'a> TcpSocket<'a> { (State::Listen, TcpRepr { src_port, dst_port, control: TcpControl::Syn, seq_number, ack_number: None, .. }) => { - self.local_endpoint = IpEndpoint::new(ip_repr.dst_addr, dst_port); - self.remote_endpoint = IpEndpoint::new(ip_repr.src_addr, src_port); + self.local_endpoint = IpEndpoint::new(ip_repr.dst_addr(), dst_port); + self.remote_endpoint = IpEndpoint::new(ip_repr.src_addr(), src_port); self.local_seq_no = -seq_number; // FIXME: use something more secure self.remote_seq_no = seq_number + 1; self.set_state(State::SynReceived); @@ -369,7 +369,7 @@ impl<'a> TcpSocket<'a> { /// See [Socket::dispatch](enum.Socket.html#method.dispatch). pub fn dispatch(&mut self, emit: &mut F) -> Result<(), Error> - where F: FnMut(&IpRepr<&IpPayload>) -> Result<(), Error> { + where F: FnMut(&IpRepr, &IpPayload) -> Result<(), Error> { let mut repr = TcpRepr { src_port: self.local_endpoint.port, dst_port: self.remote_endpoint.port, @@ -413,12 +413,12 @@ impl<'a> TcpSocket<'a> { _ => unreachable!() } - emit(&IpRepr { + let ip_repr = IpRepr::Unspecified { src_addr: self.local_endpoint.addr, dst_addr: self.remote_endpoint.addr, protocol: IpProtocol::Tcp, - payload: &repr as &IpPayload - }) + }; + emit(&ip_repr, &repr) } } @@ -427,9 +427,9 @@ impl<'a> IpPayload for TcpRepr<'a> { self.buffer_len() } - fn emit(&self, repr: &mut IpRepr<&mut [u8]>) { - let mut packet = TcpPacket::new(&mut repr.payload).expect("undersized payload"); - self.emit(&mut packet, &repr.src_addr, &repr.dst_addr) + fn emit(&self, ip_repr: &IpRepr, payload: &mut [u8]) { + let mut packet = TcpPacket::new(payload).expect("undersized payload"); + self.emit(&mut packet, &ip_repr.src_addr(), &ip_repr.dst_addr()) } } @@ -486,37 +486,32 @@ mod test { let mut buffer = vec![0; repr.buffer_len()]; let mut packet = TcpPacket::new(&mut buffer).unwrap(); repr.emit(&mut packet, &REMOTE_IP, &LOCAL_IP); - let result = $socket.collect(&IpRepr { + let ip_repr = IpRepr::Unspecified { src_addr: REMOTE_IP, dst_addr: LOCAL_IP, - protocol: IpProtocol::Tcp, - payload: &packet.into_inner()[..] - }); + protocol: IpProtocol::Tcp + }; + let result = $socket.collect(&ip_repr, &packet.into_inner()[..]); result.expect("send error") }) } macro_rules! recv { ($socket:ident, $expected:expr) => ({ - let result = $socket.dispatch(&mut |repr| { - assert_eq!(repr.protocol, IpProtocol::Tcp); - assert_eq!(repr.src_addr, LOCAL_IP); - assert_eq!(repr.dst_addr, REMOTE_IP); + let result = $socket.dispatch(&mut |ip_repr, payload| { + assert_eq!(ip_repr.protocol(), IpProtocol::Tcp); + assert_eq!(ip_repr.src_addr(), LOCAL_IP); + assert_eq!(ip_repr.dst_addr(), REMOTE_IP); - let mut buffer = vec![0; repr.payload.buffer_len()]; - repr.payload.emit(&mut IpRepr { - src_addr: repr.src_addr, - dst_addr: repr.dst_addr, - protocol: repr.protocol, - payload: &mut buffer[..] - }); + let mut buffer = vec![0; payload.buffer_len()]; + payload.emit(&ip_repr, &mut buffer[..]); let packet = TcpPacket::new(&buffer[..]).unwrap(); - let repr = TcpRepr::parse(&packet, &repr.src_addr, &repr.dst_addr).unwrap(); - assert_eq!(repr, $expected); + let repr = TcpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr()); + assert_eq!(repr, Ok($expected)); Ok(()) }); assert_eq!(result, Ok(())); - let result = $socket.dispatch(&mut |_repr| { + let result = $socket.dispatch(&mut |_repr, _payload| { Ok(()) }); assert_eq!(result, Err(Error::Exhausted)); diff --git a/src/socket/udp.rs b/src/socket/udp.rs index 978c570..3ce46be 100644 --- a/src/socket/udp.rs +++ b/src/socket/udp.rs @@ -168,19 +168,19 @@ impl<'a, 'b> UdpSocket<'a, 'b> { } /// See [Socket::collect](enum.Socket.html#method.collect). - pub fn collect(&mut self, ip_repr: &IpRepr<&[u8]>) -> Result<(), Error> { - if ip_repr.protocol != IpProtocol::Udp { return Err(Error::Rejected) } + pub fn collect(&mut self, ip_repr: &IpRepr, payload: &[u8]) -> Result<(), Error> { + if ip_repr.protocol() != IpProtocol::Udp { return Err(Error::Rejected) } - let packet = try!(UdpPacket::new(ip_repr.payload)); - let repr = try!(UdpRepr::parse(&packet, &ip_repr.src_addr, &ip_repr.dst_addr)); + let packet = try!(UdpPacket::new(payload)); + let repr = try!(UdpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr())); if repr.dst_port != self.endpoint.port { return Err(Error::Rejected) } if !self.endpoint.addr.is_unspecified() { - if self.endpoint.addr != ip_repr.dst_addr { return Err(Error::Rejected) } + if self.endpoint.addr != ip_repr.dst_addr() { return Err(Error::Rejected) } } let packet_buf = try!(self.rx_buffer.enqueue()); - packet_buf.endpoint = IpEndpoint { addr: ip_repr.src_addr, port: repr.src_port }; + packet_buf.endpoint = IpEndpoint { addr: ip_repr.src_addr(), port: repr.src_port }; packet_buf.size = repr.payload.len(); packet_buf.as_mut()[..repr.payload.len()].copy_from_slice(repr.payload); net_trace!("udp:{}:{}: collect {} octets", @@ -190,20 +190,21 @@ impl<'a, 'b> UdpSocket<'a, 'b> { /// See [Socket::dispatch](enum.Socket.html#method.dispatch). pub fn dispatch(&mut self, emit: &mut F) -> Result<(), Error> - where F: FnMut(&IpRepr<&IpPayload>) -> Result<(), Error> { + where F: FnMut(&IpRepr, &IpPayload) -> Result<(), Error> { let packet_buf = try!(self.tx_buffer.dequeue()); net_trace!("udp:{}:{}: dispatch {} octets", self.endpoint, packet_buf.endpoint, packet_buf.size); - emit(&IpRepr { + let ip_repr = IpRepr::Unspecified { src_addr: self.endpoint.addr, dst_addr: packet_buf.endpoint.addr, - protocol: IpProtocol::Udp, - payload: &UdpRepr { - src_port: self.endpoint.port, - dst_port: packet_buf.endpoint.port, - payload: &packet_buf.as_ref()[..] - } as &IpPayload - }) + protocol: IpProtocol::Udp + }; + let payload = UdpRepr { + src_port: self.endpoint.port, + dst_port: packet_buf.endpoint.port, + payload: &packet_buf.as_ref()[..] + }; + emit(&ip_repr, &payload) } } @@ -212,9 +213,9 @@ impl<'a> IpPayload for UdpRepr<'a> { self.buffer_len() } - fn emit(&self, repr: &mut IpRepr<&mut [u8]>) { - let mut packet = UdpPacket::new(&mut repr.payload).expect("undersized payload"); - self.emit(&mut packet, &repr.src_addr, &repr.dst_addr) + fn emit(&self, repr: &IpRepr, payload: &mut [u8]) { + let mut packet = UdpPacket::new(payload).expect("undersized payload"); + self.emit(&mut packet, &repr.src_addr(), &repr.dst_addr()) } } diff --git a/src/wire/ip.rs b/src/wire/ip.rs index 5d43d70..71626d7 100644 --- a/src/wire/ip.rs +++ b/src/wire/ip.rs @@ -1,6 +1,7 @@ use core::fmt; -use super::Ipv4Address; +use Error; +use super::{Ipv4Address, Ipv4Packet, Ipv4Repr}; enum_with_unknown! { /// Internetworking protocol. @@ -98,6 +99,131 @@ impl fmt::Display for Endpoint { } } +/// An IP packet representation. +/// +/// This enum abstracts the various versions of IP packets. It either contains a concrete +/// high-level representation for some IP protocol version, or an unspecified representation, +/// which permits the `IpAddress::Unspecified` addresses. +#[derive(Debug, Clone)] +pub enum IpRepr { + Unspecified { + src_addr: Address, + dst_addr: Address, + protocol: Protocol + }, + Ipv4(Ipv4Repr), + #[doc(hidden)] + __Nonexhaustive +} + +impl IpRepr { + /// Return the source address. + pub fn src_addr(&self) -> Address { + match self { + &IpRepr::Unspecified { src_addr, .. } => src_addr, + &IpRepr::Ipv4(repr) => Address::Ipv4(repr.src_addr), + &IpRepr::__Nonexhaustive => unreachable!() + } + } + + /// Return the destination address. + pub fn dst_addr(&self) -> Address { + match self { + &IpRepr::Unspecified { dst_addr, .. } => dst_addr, + &IpRepr::Ipv4(repr) => Address::Ipv4(repr.dst_addr), + &IpRepr::__Nonexhaustive => unreachable!() + } + } + + /// Return the protocol. + pub fn protocol(&self) -> Protocol { + match self { + &IpRepr::Unspecified { protocol, .. } => protocol, + &IpRepr::Ipv4(repr) => repr.protocol, + &IpRepr::__Nonexhaustive => unreachable!() + } + } + + /// Convert an unspecified representation into a concrete one, or return + /// `Err(Error::Unaddressable)` if not possible. + /// + /// # Panics + /// This function panics if source and destination addresses belong to different families, + /// or the destination address is unspecified, since this indicates a logic error. + pub fn lower(&self, fallback_src_addrs: &[Address]) -> Result { + match self { + &IpRepr::Unspecified { + src_addr: Address::Ipv4(src_addr), + dst_addr: Address::Ipv4(dst_addr), + protocol + } => { + Ok(IpRepr::Ipv4(Ipv4Repr { + src_addr: src_addr, + dst_addr: dst_addr, + protocol: protocol + })) + } + + &IpRepr::Unspecified { + src_addr: Address::Unspecified, + dst_addr: Address::Ipv4(dst_addr), + protocol + } => { + let mut src_addr = None; + for addr in fallback_src_addrs { + match addr { + &Address::Ipv4(addr) => { + src_addr = Some(addr); + break + } + _ => () + } + } + Ok(IpRepr::Ipv4(Ipv4Repr { + src_addr: try!(src_addr.ok_or(Error::Unaddressable)), + dst_addr: dst_addr, + protocol: protocol + })) + } + + &IpRepr::Unspecified { dst_addr: Address::Unspecified, .. } => + panic!("unspecified destination IP address"), + // &IpRepr::Unspecified { .. } => + // panic!("source and destination IP address families do not match"), + + repr @ &IpRepr::Ipv4(_) => Ok(repr.clone()), + &IpRepr::__Nonexhaustive => unreachable!() + } + } + + /// Return the length of a header that will be emitted from this high-level representation. + /// + /// # Panics + /// This function panics if invoked on an unspecified representation. + pub fn buffer_len(&self) -> usize { + match self { + &IpRepr::Unspecified { .. } => panic!("unspecified IP representation"), + &IpRepr::Ipv4(repr) => repr.buffer_len(), + &IpRepr::__Nonexhaustive => unreachable!() + } + } + + /// Emit this high-level representation into a buffer. + /// + /// # Panics + /// This function panics if invoked on an unspecified representation. + pub fn emit + AsMut<[u8]>>(&self, buffer: T, payload_len: usize) { + match self { + &IpRepr::Unspecified { .. } => panic!("unspecified IP representation"), + &IpRepr::Ipv4(repr) => { + let mut packet = Ipv4Packet::new(buffer).expect("undersized buffer"); + repr.emit(&mut packet, payload_len) + } + &IpRepr::__Nonexhaustive => unreachable!() + } + } +} + pub mod checksum { use byteorder::{ByteOrder, NetworkEndian}; diff --git a/src/wire/mod.rs b/src/wire/mod.rs index bf04854..793a4c8 100644 --- a/src/wire/mod.rs +++ b/src/wire/mod.rs @@ -102,6 +102,7 @@ pub use self::arp::Repr as ArpRepr; pub use self::ip::Protocol as IpProtocol; pub use self::ip::Address as IpAddress; pub use self::ip::Endpoint as IpEndpoint; +pub use self::ip::IpRepr as IpRepr; pub use self::ipv4::Address as Ipv4Address; pub use self::ipv4::Packet as Ipv4Packet;