From 6414e40deba69c597ff4056c804187b160333485 Mon Sep 17 00:00:00 2001 From: whitequark Date: Wed, 25 Jan 2017 00:17:46 +0000 Subject: [PATCH] firmware: fix race condition between TCP listen and accept. --- artiq/firmware/runtime/analyzer.rs | 25 +++-- artiq/firmware/runtime/moninj.rs | 2 +- artiq/firmware/runtime/sched.rs | 170 +++++++++++++++++------------ artiq/firmware/runtime/session.rs | 42 +++---- 4 files changed, 135 insertions(+), 104 deletions(-) diff --git a/artiq/firmware/runtime/analyzer.rs b/artiq/firmware/runtime/analyzer.rs index 9d4eac0a2..bea395afe 100644 --- a/artiq/firmware/runtime/analyzer.rs +++ b/artiq/firmware/runtime/analyzer.rs @@ -1,6 +1,6 @@ use std::io::{self, Write}; use board::{csr, cache}; -use sched::{Io, TcpSocket}; +use sched::{Io, TcpListener, TcpStream}; use analyzer_proto::*; const BUFFER_SIZE: usize = 512 * 1024; @@ -40,7 +40,7 @@ fn disarm() { } } -fn worker(socket: &mut TcpSocket) -> io::Result<()> { +fn worker(stream: &mut TcpStream) -> io::Result<()> { let data = unsafe { &BUFFER.data[..] }; let overflow_occurred = unsafe { csr::rtio_analyzer::message_encoder_overflow_read() != 0 }; let total_byte_count = unsafe { csr::rtio_analyzer::dma_byte_count_read() }; @@ -56,12 +56,12 @@ fn worker(socket: &mut TcpSocket) -> io::Result<()> { }; trace!("{:?}", header); - try!(header.write_to(socket)); + try!(header.write_to(stream)); if wraparound { - try!(socket.write_all(&data[pointer..])); - try!(socket.write_all(&data[..pointer])); + try!(stream.write_all(&data[pointer..])); + try!(stream.write_all(&data[..pointer])); } else { - try!(socket.write_all(&data[..pointer])); + try!(stream.write_all(&data[..pointer])); } Ok(()) @@ -71,20 +71,21 @@ pub fn thread(io: Io) { // verify that the hack above works assert!(::core::mem::align_of::() == 64); - let mut socket = TcpSocket::with_buffer_size(&io, 65535); + let listener = TcpListener::new(&io, 65535); + listener.listen(1382).expect("analyzer: cannot listen"); + loop { arm(); - socket.listen(1382).expect("analyzer: cannot listen"); - socket.accept().expect("analyzer: cannot accept"); - info!("connection from {}", socket.remote_endpoint()); + let mut stream = listener.accept().expect("analyzer: cannot accept"); + info!("connection from {}", stream.remote_endpoint()); disarm(); - match worker(&mut socket) { + match worker(&mut stream) { Ok(()) => (), Err(err) => error!("analyzer aborted: {}", err) } - socket.close().expect("analyzer: cannot close"); + stream.close().expect("analyzer: cannot close"); } } diff --git a/artiq/firmware/runtime/moninj.rs b/artiq/firmware/runtime/moninj.rs index 9b52d414f..cfb1f16bb 100644 --- a/artiq/firmware/runtime/moninj.rs +++ b/artiq/firmware/runtime/moninj.rs @@ -114,7 +114,7 @@ fn worker(socket: &mut UdpSocket) -> io::Result<()> { } pub fn thread(io: Io) { - let mut socket = UdpSocket::with_buffer_size(&io, 1, 512); + let mut socket = UdpSocket::new(&io, 1, 512); socket.bind(3250); loop { diff --git a/artiq/firmware/runtime/sched.rs b/artiq/firmware/runtime/sched.rs index 22e67895a..7552a298a 100644 --- a/artiq/firmware/runtime/sched.rs +++ b/artiq/firmware/runtime/sched.rs @@ -1,7 +1,7 @@ #![allow(dead_code)] use std::mem; -use std::cell::{RefCell, RefMut}; +use std::cell::{Cell, RefCell, RefMut}; use std::vec::Vec; use std::io::{Read, Write, Result, Error, ErrorKind}; use fringe::OwnedStack; @@ -244,7 +244,7 @@ macro_rules! until { let (sockets, handle) = ($socket.io.sockets.clone(), $socket.handle); $socket.io.until(move || { let mut sockets = borrow_mut!(sockets); - let $var = sockets.get_mut(handle).as_socket() as &mut $ty; + let $var: &mut $ty = sockets.get_mut(handle).as_socket(); $cond }) }) @@ -260,27 +260,21 @@ pub struct UdpSocket<'a> { } impl<'a> UdpSocket<'a> { - pub fn new(io: &'a Io<'a>, rx_buffer: UdpSocketBuffer, tx_buffer: UdpSocketBuffer) -> - UdpSocket<'a> { - let handle = borrow_mut!(io.sockets) - .add(UdpSocketLower::new(rx_buffer, tx_buffer)); - UdpSocket { - io: io, - handle: handle - } - } - - pub fn with_buffer_size(io: &'a Io<'a>, buffer_depth: usize, buffer_width: usize) -> - UdpSocket<'a> { + pub fn new(io: &'a Io<'a>, buffer_depth: usize, buffer_width: usize) -> UdpSocket<'a> { let mut rx_buffer = vec![]; let mut tx_buffer = vec![]; for _ in 0..buffer_depth { rx_buffer.push(UdpPacketBuffer::new(vec![0; buffer_width])); tx_buffer.push(UdpPacketBuffer::new(vec![0; buffer_width])); } - Self::new(io, - UdpSocketBuffer::new(rx_buffer), - UdpSocketBuffer::new(tx_buffer)) + let handle = borrow_mut!(io.sockets) + .add(UdpSocketLower::new( + UdpSocketBuffer::new(rx_buffer), + UdpSocketBuffer::new(tx_buffer))); + UdpSocket { + io: io, + handle: handle + } } fn as_lower<'b>(&'b self) -> RefMut<'b, UdpSocketLower> { @@ -326,38 +320,104 @@ type TcpSocketLower = ::smoltcp::socket::TcpSocket<'static>; pub struct TcpSocketHandle(SocketHandle); -pub struct TcpSocket<'a> { +pub struct TcpListener<'a> { + io: &'a Io<'a>, + handle: Cell, + buffer_size: Cell, + endpoint: Cell +} + +impl<'a> TcpListener<'a> { + fn new_lower(io: &'a Io<'a>, buffer_size: usize) -> SocketHandle { + let rx_buffer = vec![0; buffer_size]; + let tx_buffer = vec![0; buffer_size]; + borrow_mut!(io.sockets) + .add(TcpSocketLower::new( + TcpSocketBuffer::new(rx_buffer), + TcpSocketBuffer::new(tx_buffer))) + } + + pub fn new(io: &'a Io<'a>, buffer_size: usize) -> TcpListener<'a> { + TcpListener { + io: io, + handle: Cell::new(Self::new_lower(io, buffer_size)), + buffer_size: Cell::new(buffer_size), + endpoint: Cell::new(IpEndpoint::default()) + } + } + + fn as_lower<'b>(&'b self) -> RefMut<'b, TcpSocketLower> { + RefMut::map(borrow_mut!(self.io.sockets), + |sockets| sockets.get_mut(self.handle.get()).as_socket()) + } + + pub fn is_open(&self) -> bool { + self.as_lower().is_open() + } + + pub fn can_accept(&self) -> bool { + self.as_lower().is_active() + } + + pub fn local_endpoint(&self) -> IpEndpoint { + self.as_lower().local_endpoint() + } + + pub fn listen>(&self, endpoint: T) -> Result<()> { + let endpoint = endpoint.into(); + try!(self.as_lower().listen(endpoint) + .map_err(|()| Error::new(ErrorKind::Other, + "cannot listen: already connected"))); + self.endpoint.set(endpoint); + Ok(()) + } + + pub fn accept(&self) -> Result> { + // We're waiting until at least one half of the connection becomes open. + // This handles the case where a remote socket immediately sends a FIN-- + // that still counts as accepting even though nothing may be sent. + let (sockets, handle) = (self.io.sockets.clone(), self.handle.get()); + try!(self.io.until(move || { + let mut sockets = borrow_mut!(sockets); + let socket: &mut TcpSocketLower = sockets.get_mut(handle).as_socket(); + socket.may_send() || socket.may_recv() + })); + + let accepted = self.handle.get(); + self.handle.set(Self::new_lower(self.io, self.buffer_size.get())); + self.listen(self.endpoint.get()).unwrap(); + Ok(TcpStream { + io: self.io, + handle: accepted + }) + } + + pub fn close(&self) { + self.as_lower().close() + } +} + +impl<'a> Drop for TcpListener<'a> { + fn drop(&mut self) { + self.as_lower().close(); + borrow_mut!(self.io.sockets).release(self.handle.get()) + } +} + +pub struct TcpStream<'a> { io: &'a Io<'a>, handle: SocketHandle } -impl<'a> TcpSocket<'a> { - pub fn new(io: &'a Io<'a>, rx_buffer: TcpSocketBuffer, tx_buffer: TcpSocketBuffer) -> - TcpSocket<'a> { - let handle = borrow_mut!(io.sockets) - .add(TcpSocketLower::new(rx_buffer, tx_buffer)); - TcpSocket { - io: io, - handle: handle - } - } - - pub fn with_buffer_size(io: &'a Io<'a>, buffer_size: usize) -> TcpSocket<'a> { - let rx_buffer = vec![0; buffer_size]; - let tx_buffer = vec![0; buffer_size]; - Self::new(io, - TcpSocketBuffer::new(rx_buffer), - TcpSocketBuffer::new(tx_buffer)) - } - +impl<'a> TcpStream<'a> { pub fn into_handle(self) -> TcpSocketHandle { let handle = self.handle; mem::forget(self); TcpSocketHandle(handle) } - pub fn from_handle(io: &'a Io<'a>, handle: TcpSocketHandle) -> TcpSocket<'a> { - TcpSocket { + pub fn from_handle(io: &'a Io<'a>, handle: TcpSocketHandle) -> TcpStream<'a> { + TcpStream { io: io, handle: handle.0 } @@ -372,14 +432,6 @@ impl<'a> TcpSocket<'a> { self.as_lower().is_open() } - pub fn is_listening(&self) -> bool { - self.as_lower().is_listening() - } - - pub fn is_active(&self) -> bool { - self.as_lower().is_active() - } - pub fn may_send(&self) -> bool { self.as_lower().may_send() } @@ -404,19 +456,6 @@ impl<'a> TcpSocket<'a> { self.as_lower().remote_endpoint() } - pub fn listen>(&self, endpoint: T) -> Result<()> { - self.as_lower().listen(endpoint) - .map_err(|()| Error::new(ErrorKind::Other, - "cannot listen: already connected")) - } - - pub fn accept(&self) -> Result<()> { - // We're waiting until at least one half of the connection becomes open. - // This handles the case where a remote socket immediately sends a FIN-- - // that still counts as accepting even though nothing may be sent. - until!(self, TcpSocketLower, |s| s.may_send() || s.may_recv()) - } - pub fn close(&self) -> Result<()> { self.as_lower().close(); try!(until!(self, TcpSocketLower, |s| !s.is_open())); @@ -427,7 +466,7 @@ impl<'a> TcpSocket<'a> { } } -impl<'a> Read for TcpSocket<'a> { +impl<'a> Read for TcpStream<'a> { fn read(&mut self, buf: &mut [u8]) -> Result { // fast path let result = self.as_lower().recv_slice(buf); @@ -444,7 +483,7 @@ impl<'a> Read for TcpSocket<'a> { } } -impl<'a> Write for TcpSocket<'a> { +impl<'a> Write for TcpStream<'a> { fn write(&mut self, buf: &[u8]) -> Result { // fast path let result = self.as_lower().send_slice(buf); @@ -466,12 +505,9 @@ impl<'a> Write for TcpSocket<'a> { } } -impl<'a> Drop for TcpSocket<'a> { +impl<'a> Drop for TcpStream<'a> { fn drop(&mut self) { - if self.is_open() { - // scheduler will remove any closed sockets with zero references. - self.as_lower().close() - } + self.as_lower().close(); borrow_mut!(self.io.sockets).release(self.handle) } } diff --git a/artiq/firmware/runtime/session.rs b/artiq/firmware/runtime/session.rs index cd81e7479..ba1c3544c 100644 --- a/artiq/firmware/runtime/session.rs +++ b/artiq/firmware/runtime/session.rs @@ -8,7 +8,7 @@ use logger_artiq::BufferLogger; use cache::Cache; use urc::Urc; use sched::{ThreadHandle, Io}; -use sched::{TcpSocket}; +use sched::{TcpListener, TcpStream}; use byteorder::{ByteOrder, NetworkEndian}; use board; @@ -97,7 +97,7 @@ impl<'a> Drop for Session<'a> { } } -fn check_magic(stream: &mut TcpSocket) -> io::Result<()> { +fn check_magic(stream: &mut TcpStream) -> io::Result<()> { const MAGIC: &'static [u8] = b"ARTIQ coredev\n"; let mut magic: [u8; 14] = [0; 14]; @@ -109,7 +109,7 @@ fn check_magic(stream: &mut TcpSocket) -> io::Result<()> { } } -fn host_read(stream: &mut TcpSocket) -> io::Result { +fn host_read(stream: &mut TcpStream) -> io::Result { let request = try!(host::Request::read_from(stream)); match &request { &host::Request::LoadKernel(_) => trace!("comm<-host LoadLibrary(...)"), @@ -198,7 +198,7 @@ fn kern_run(session: &mut Session) -> io::Result<()> { } fn process_host_message(io: &Io, - stream: &mut TcpSocket, + stream: &mut TcpStream, session: &mut Session) -> io::Result<()> { match try!(host_read(stream)) { host::Request::Ident => @@ -338,7 +338,7 @@ fn process_host_message(io: &Io, } fn process_kern_message(io: &Io, - mut stream: Option<&mut TcpSocket>, + mut stream: Option<&mut TcpStream>, session: &mut Session) -> io::Result { kern_recv_notrace(io, |request| { match (request, session.kernel_state) { @@ -535,7 +535,7 @@ fn process_kern_message(io: &Io, }) } -fn process_kern_queued_rpc(stream: &mut TcpSocket, +fn process_kern_queued_rpc(stream: &mut TcpStream, _session: &mut Session) -> io::Result<()> { rpc_queue::dequeue(|slice| { trace!("comm<-kern (async RPC)"); @@ -548,7 +548,7 @@ fn process_kern_queued_rpc(stream: &mut TcpSocket, } fn host_kernel_worker(io: &Io, - stream: &mut TcpSocket, + stream: &mut TcpStream, congress: &mut Congress) -> io::Result<()> { let mut session = Session::new(congress); @@ -652,36 +652,30 @@ pub fn thread(io: Io) { BufferLogger::with_instance(|logger| logger.disable_trace_to_uart()); - const BUFFER_SIZE: usize = 65535; - let mut listener = TcpSocket::with_buffer_size(&io, BUFFER_SIZE); + let listener = TcpListener::new(&io, 65535); + listener.listen(1381).expect("session: cannot listen"); info!("accepting network sessions"); let mut kernel_thread = None; loop { - if !listener.is_open() { - listener.listen(1381).expect("session: cannot listen") - } - - if listener.is_active() { - listener.accept().expect("session: cannot accept"); - match check_magic(&mut listener) { + if listener.can_accept() { + let mut stream = listener.accept().expect("session: cannot accept"); + match check_magic(&mut stream) { Ok(()) => (), Err(_) => { - warn!("wrong magic from {}", listener.remote_endpoint()); - listener.close().expect("session: cannot close"); + warn!("wrong magic from {}", stream.remote_endpoint()); + stream.close().expect("session: cannot close"); continue } } - info!("new connection from {}", listener.remote_endpoint()); - - let socket = listener.into_handle(); - listener = TcpSocket::with_buffer_size(&io, BUFFER_SIZE); + info!("new connection from {}", stream.remote_endpoint()); let congress = congress.clone(); + let stream = stream.into_handle(); respawn(&io, &mut kernel_thread, move |io| { let mut congress = borrow_mut!(congress); - let mut socket = TcpSocket::from_handle(&io, socket); - match host_kernel_worker(&io, &mut socket, &mut *congress) { + let mut stream = TcpStream::from_handle(&io, stream); + match host_kernel_worker(&io, &mut stream, &mut *congress) { Ok(()) => (), Err(err) => { if err.kind() == io::ErrorKind::UnexpectedEof {