firmware: fix race condition between TCP listen and accept.

This commit is contained in:
whitequark 2017-01-25 00:17:46 +00:00
parent 2de3770c06
commit 6414e40deb
4 changed files with 135 additions and 104 deletions

View File

@ -1,6 +1,6 @@
use std::io::{self, Write}; use std::io::{self, Write};
use board::{csr, cache}; use board::{csr, cache};
use sched::{Io, TcpSocket}; use sched::{Io, TcpListener, TcpStream};
use analyzer_proto::*; use analyzer_proto::*;
const BUFFER_SIZE: usize = 512 * 1024; const BUFFER_SIZE: usize = 512 * 1024;
@ -40,7 +40,7 @@ fn disarm() {
} }
} }
fn worker(socket: &mut TcpSocket) -> io::Result<()> { fn worker(stream: &mut TcpStream) -> io::Result<()> {
let data = unsafe { &BUFFER.data[..] }; let data = unsafe { &BUFFER.data[..] };
let overflow_occurred = unsafe { csr::rtio_analyzer::message_encoder_overflow_read() != 0 }; let overflow_occurred = unsafe { csr::rtio_analyzer::message_encoder_overflow_read() != 0 };
let total_byte_count = unsafe { csr::rtio_analyzer::dma_byte_count_read() }; let total_byte_count = unsafe { csr::rtio_analyzer::dma_byte_count_read() };
@ -56,12 +56,12 @@ fn worker(socket: &mut TcpSocket) -> io::Result<()> {
}; };
trace!("{:?}", header); trace!("{:?}", header);
try!(header.write_to(socket)); try!(header.write_to(stream));
if wraparound { if wraparound {
try!(socket.write_all(&data[pointer..])); try!(stream.write_all(&data[pointer..]));
try!(socket.write_all(&data[..pointer])); try!(stream.write_all(&data[..pointer]));
} else { } else {
try!(socket.write_all(&data[..pointer])); try!(stream.write_all(&data[..pointer]));
} }
Ok(()) Ok(())
@ -71,20 +71,21 @@ pub fn thread(io: Io) {
// verify that the hack above works // verify that the hack above works
assert!(::core::mem::align_of::<Buffer>() == 64); assert!(::core::mem::align_of::<Buffer>() == 64);
let mut socket = TcpSocket::with_buffer_size(&io, 65535); let listener = TcpListener::new(&io, 65535);
listener.listen(1382).expect("analyzer: cannot listen");
loop { loop {
arm(); arm();
socket.listen(1382).expect("analyzer: cannot listen"); let mut stream = listener.accept().expect("analyzer: cannot accept");
socket.accept().expect("analyzer: cannot accept"); info!("connection from {}", stream.remote_endpoint());
info!("connection from {}", socket.remote_endpoint());
disarm(); disarm();
match worker(&mut socket) { match worker(&mut stream) {
Ok(()) => (), Ok(()) => (),
Err(err) => error!("analyzer aborted: {}", err) Err(err) => error!("analyzer aborted: {}", err)
} }
socket.close().expect("analyzer: cannot close"); stream.close().expect("analyzer: cannot close");
} }
} }

View File

@ -114,7 +114,7 @@ fn worker(socket: &mut UdpSocket) -> io::Result<()> {
} }
pub fn thread(io: Io) { pub fn thread(io: Io) {
let mut socket = UdpSocket::with_buffer_size(&io, 1, 512); let mut socket = UdpSocket::new(&io, 1, 512);
socket.bind(3250); socket.bind(3250);
loop { loop {

View File

@ -1,7 +1,7 @@
#![allow(dead_code)] #![allow(dead_code)]
use std::mem; use std::mem;
use std::cell::{RefCell, RefMut}; use std::cell::{Cell, RefCell, RefMut};
use std::vec::Vec; use std::vec::Vec;
use std::io::{Read, Write, Result, Error, ErrorKind}; use std::io::{Read, Write, Result, Error, ErrorKind};
use fringe::OwnedStack; use fringe::OwnedStack;
@ -244,7 +244,7 @@ macro_rules! until {
let (sockets, handle) = ($socket.io.sockets.clone(), $socket.handle); let (sockets, handle) = ($socket.io.sockets.clone(), $socket.handle);
$socket.io.until(move || { $socket.io.until(move || {
let mut sockets = borrow_mut!(sockets); let mut sockets = borrow_mut!(sockets);
let $var = sockets.get_mut(handle).as_socket() as &mut $ty; let $var: &mut $ty = sockets.get_mut(handle).as_socket();
$cond $cond
}) })
}) })
@ -260,27 +260,21 @@ pub struct UdpSocket<'a> {
} }
impl<'a> UdpSocket<'a> { impl<'a> UdpSocket<'a> {
pub fn new(io: &'a Io<'a>, rx_buffer: UdpSocketBuffer, tx_buffer: UdpSocketBuffer) -> pub fn new(io: &'a Io<'a>, buffer_depth: usize, buffer_width: usize) -> UdpSocket<'a> {
UdpSocket<'a> {
let handle = borrow_mut!(io.sockets)
.add(UdpSocketLower::new(rx_buffer, tx_buffer));
UdpSocket {
io: io,
handle: handle
}
}
pub fn with_buffer_size(io: &'a Io<'a>, buffer_depth: usize, buffer_width: usize) ->
UdpSocket<'a> {
let mut rx_buffer = vec![]; let mut rx_buffer = vec![];
let mut tx_buffer = vec![]; let mut tx_buffer = vec![];
for _ in 0..buffer_depth { for _ in 0..buffer_depth {
rx_buffer.push(UdpPacketBuffer::new(vec![0; buffer_width])); rx_buffer.push(UdpPacketBuffer::new(vec![0; buffer_width]));
tx_buffer.push(UdpPacketBuffer::new(vec![0; buffer_width])); tx_buffer.push(UdpPacketBuffer::new(vec![0; buffer_width]));
} }
Self::new(io, let handle = borrow_mut!(io.sockets)
.add(UdpSocketLower::new(
UdpSocketBuffer::new(rx_buffer), UdpSocketBuffer::new(rx_buffer),
UdpSocketBuffer::new(tx_buffer)) UdpSocketBuffer::new(tx_buffer)));
UdpSocket {
io: io,
handle: handle
}
} }
fn as_lower<'b>(&'b self) -> RefMut<'b, UdpSocketLower> { fn as_lower<'b>(&'b self) -> RefMut<'b, UdpSocketLower> {
@ -326,38 +320,104 @@ type TcpSocketLower = ::smoltcp::socket::TcpSocket<'static>;
pub struct TcpSocketHandle(SocketHandle); pub struct TcpSocketHandle(SocketHandle);
pub struct TcpSocket<'a> { pub struct TcpListener<'a> {
io: &'a Io<'a>,
handle: Cell<SocketHandle>,
buffer_size: Cell<usize>,
endpoint: Cell<IpEndpoint>
}
impl<'a> TcpListener<'a> {
fn new_lower(io: &'a Io<'a>, buffer_size: usize) -> SocketHandle {
let rx_buffer = vec![0; buffer_size];
let tx_buffer = vec![0; buffer_size];
borrow_mut!(io.sockets)
.add(TcpSocketLower::new(
TcpSocketBuffer::new(rx_buffer),
TcpSocketBuffer::new(tx_buffer)))
}
pub fn new(io: &'a Io<'a>, buffer_size: usize) -> TcpListener<'a> {
TcpListener {
io: io,
handle: Cell::new(Self::new_lower(io, buffer_size)),
buffer_size: Cell::new(buffer_size),
endpoint: Cell::new(IpEndpoint::default())
}
}
fn as_lower<'b>(&'b self) -> RefMut<'b, TcpSocketLower> {
RefMut::map(borrow_mut!(self.io.sockets),
|sockets| sockets.get_mut(self.handle.get()).as_socket())
}
pub fn is_open(&self) -> bool {
self.as_lower().is_open()
}
pub fn can_accept(&self) -> bool {
self.as_lower().is_active()
}
pub fn local_endpoint(&self) -> IpEndpoint {
self.as_lower().local_endpoint()
}
pub fn listen<T: Into<IpEndpoint>>(&self, endpoint: T) -> Result<()> {
let endpoint = endpoint.into();
try!(self.as_lower().listen(endpoint)
.map_err(|()| Error::new(ErrorKind::Other,
"cannot listen: already connected")));
self.endpoint.set(endpoint);
Ok(())
}
pub fn accept(&self) -> Result<TcpStream<'a>> {
// We're waiting until at least one half of the connection becomes open.
// This handles the case where a remote socket immediately sends a FIN--
// that still counts as accepting even though nothing may be sent.
let (sockets, handle) = (self.io.sockets.clone(), self.handle.get());
try!(self.io.until(move || {
let mut sockets = borrow_mut!(sockets);
let socket: &mut TcpSocketLower = sockets.get_mut(handle).as_socket();
socket.may_send() || socket.may_recv()
}));
let accepted = self.handle.get();
self.handle.set(Self::new_lower(self.io, self.buffer_size.get()));
self.listen(self.endpoint.get()).unwrap();
Ok(TcpStream {
io: self.io,
handle: accepted
})
}
pub fn close(&self) {
self.as_lower().close()
}
}
impl<'a> Drop for TcpListener<'a> {
fn drop(&mut self) {
self.as_lower().close();
borrow_mut!(self.io.sockets).release(self.handle.get())
}
}
pub struct TcpStream<'a> {
io: &'a Io<'a>, io: &'a Io<'a>,
handle: SocketHandle handle: SocketHandle
} }
impl<'a> TcpSocket<'a> { impl<'a> TcpStream<'a> {
pub fn new(io: &'a Io<'a>, rx_buffer: TcpSocketBuffer, tx_buffer: TcpSocketBuffer) ->
TcpSocket<'a> {
let handle = borrow_mut!(io.sockets)
.add(TcpSocketLower::new(rx_buffer, tx_buffer));
TcpSocket {
io: io,
handle: handle
}
}
pub fn with_buffer_size(io: &'a Io<'a>, buffer_size: usize) -> TcpSocket<'a> {
let rx_buffer = vec![0; buffer_size];
let tx_buffer = vec![0; buffer_size];
Self::new(io,
TcpSocketBuffer::new(rx_buffer),
TcpSocketBuffer::new(tx_buffer))
}
pub fn into_handle(self) -> TcpSocketHandle { pub fn into_handle(self) -> TcpSocketHandle {
let handle = self.handle; let handle = self.handle;
mem::forget(self); mem::forget(self);
TcpSocketHandle(handle) TcpSocketHandle(handle)
} }
pub fn from_handle(io: &'a Io<'a>, handle: TcpSocketHandle) -> TcpSocket<'a> { pub fn from_handle(io: &'a Io<'a>, handle: TcpSocketHandle) -> TcpStream<'a> {
TcpSocket { TcpStream {
io: io, io: io,
handle: handle.0 handle: handle.0
} }
@ -372,14 +432,6 @@ impl<'a> TcpSocket<'a> {
self.as_lower().is_open() self.as_lower().is_open()
} }
pub fn is_listening(&self) -> bool {
self.as_lower().is_listening()
}
pub fn is_active(&self) -> bool {
self.as_lower().is_active()
}
pub fn may_send(&self) -> bool { pub fn may_send(&self) -> bool {
self.as_lower().may_send() self.as_lower().may_send()
} }
@ -404,19 +456,6 @@ impl<'a> TcpSocket<'a> {
self.as_lower().remote_endpoint() self.as_lower().remote_endpoint()
} }
pub fn listen<T: Into<IpEndpoint>>(&self, endpoint: T) -> Result<()> {
self.as_lower().listen(endpoint)
.map_err(|()| Error::new(ErrorKind::Other,
"cannot listen: already connected"))
}
pub fn accept(&self) -> Result<()> {
// We're waiting until at least one half of the connection becomes open.
// This handles the case where a remote socket immediately sends a FIN--
// that still counts as accepting even though nothing may be sent.
until!(self, TcpSocketLower, |s| s.may_send() || s.may_recv())
}
pub fn close(&self) -> Result<()> { pub fn close(&self) -> Result<()> {
self.as_lower().close(); self.as_lower().close();
try!(until!(self, TcpSocketLower, |s| !s.is_open())); try!(until!(self, TcpSocketLower, |s| !s.is_open()));
@ -427,7 +466,7 @@ impl<'a> TcpSocket<'a> {
} }
} }
impl<'a> Read for TcpSocket<'a> { impl<'a> Read for TcpStream<'a> {
fn read(&mut self, buf: &mut [u8]) -> Result<usize> { fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
// fast path // fast path
let result = self.as_lower().recv_slice(buf); let result = self.as_lower().recv_slice(buf);
@ -444,7 +483,7 @@ impl<'a> Read for TcpSocket<'a> {
} }
} }
impl<'a> Write for TcpSocket<'a> { impl<'a> Write for TcpStream<'a> {
fn write(&mut self, buf: &[u8]) -> Result<usize> { fn write(&mut self, buf: &[u8]) -> Result<usize> {
// fast path // fast path
let result = self.as_lower().send_slice(buf); let result = self.as_lower().send_slice(buf);
@ -466,12 +505,9 @@ impl<'a> Write for TcpSocket<'a> {
} }
} }
impl<'a> Drop for TcpSocket<'a> { impl<'a> Drop for TcpStream<'a> {
fn drop(&mut self) { fn drop(&mut self) {
if self.is_open() { self.as_lower().close();
// scheduler will remove any closed sockets with zero references.
self.as_lower().close()
}
borrow_mut!(self.io.sockets).release(self.handle) borrow_mut!(self.io.sockets).release(self.handle)
} }
} }

View File

@ -8,7 +8,7 @@ use logger_artiq::BufferLogger;
use cache::Cache; use cache::Cache;
use urc::Urc; use urc::Urc;
use sched::{ThreadHandle, Io}; use sched::{ThreadHandle, Io};
use sched::{TcpSocket}; use sched::{TcpListener, TcpStream};
use byteorder::{ByteOrder, NetworkEndian}; use byteorder::{ByteOrder, NetworkEndian};
use board; use board;
@ -97,7 +97,7 @@ impl<'a> Drop for Session<'a> {
} }
} }
fn check_magic(stream: &mut TcpSocket) -> io::Result<()> { fn check_magic(stream: &mut TcpStream) -> io::Result<()> {
const MAGIC: &'static [u8] = b"ARTIQ coredev\n"; const MAGIC: &'static [u8] = b"ARTIQ coredev\n";
let mut magic: [u8; 14] = [0; 14]; let mut magic: [u8; 14] = [0; 14];
@ -109,7 +109,7 @@ fn check_magic(stream: &mut TcpSocket) -> io::Result<()> {
} }
} }
fn host_read(stream: &mut TcpSocket) -> io::Result<host::Request> { fn host_read(stream: &mut TcpStream) -> io::Result<host::Request> {
let request = try!(host::Request::read_from(stream)); let request = try!(host::Request::read_from(stream));
match &request { match &request {
&host::Request::LoadKernel(_) => trace!("comm<-host LoadLibrary(...)"), &host::Request::LoadKernel(_) => trace!("comm<-host LoadLibrary(...)"),
@ -198,7 +198,7 @@ fn kern_run(session: &mut Session) -> io::Result<()> {
} }
fn process_host_message(io: &Io, fn process_host_message(io: &Io,
stream: &mut TcpSocket, stream: &mut TcpStream,
session: &mut Session) -> io::Result<()> { session: &mut Session) -> io::Result<()> {
match try!(host_read(stream)) { match try!(host_read(stream)) {
host::Request::Ident => host::Request::Ident =>
@ -338,7 +338,7 @@ fn process_host_message(io: &Io,
} }
fn process_kern_message(io: &Io, fn process_kern_message(io: &Io,
mut stream: Option<&mut TcpSocket>, mut stream: Option<&mut TcpStream>,
session: &mut Session) -> io::Result<bool> { session: &mut Session) -> io::Result<bool> {
kern_recv_notrace(io, |request| { kern_recv_notrace(io, |request| {
match (request, session.kernel_state) { match (request, session.kernel_state) {
@ -535,7 +535,7 @@ fn process_kern_message(io: &Io,
}) })
} }
fn process_kern_queued_rpc(stream: &mut TcpSocket, fn process_kern_queued_rpc(stream: &mut TcpStream,
_session: &mut Session) -> io::Result<()> { _session: &mut Session) -> io::Result<()> {
rpc_queue::dequeue(|slice| { rpc_queue::dequeue(|slice| {
trace!("comm<-kern (async RPC)"); trace!("comm<-kern (async RPC)");
@ -548,7 +548,7 @@ fn process_kern_queued_rpc(stream: &mut TcpSocket,
} }
fn host_kernel_worker(io: &Io, fn host_kernel_worker(io: &Io,
stream: &mut TcpSocket, stream: &mut TcpStream,
congress: &mut Congress) -> io::Result<()> { congress: &mut Congress) -> io::Result<()> {
let mut session = Session::new(congress); let mut session = Session::new(congress);
@ -652,36 +652,30 @@ pub fn thread(io: Io) {
BufferLogger::with_instance(|logger| logger.disable_trace_to_uart()); BufferLogger::with_instance(|logger| logger.disable_trace_to_uart());
const BUFFER_SIZE: usize = 65535; let listener = TcpListener::new(&io, 65535);
let mut listener = TcpSocket::with_buffer_size(&io, BUFFER_SIZE); listener.listen(1381).expect("session: cannot listen");
info!("accepting network sessions"); info!("accepting network sessions");
let mut kernel_thread = None; let mut kernel_thread = None;
loop { loop {
if !listener.is_open() { if listener.can_accept() {
listener.listen(1381).expect("session: cannot listen") let mut stream = listener.accept().expect("session: cannot accept");
} match check_magic(&mut stream) {
if listener.is_active() {
listener.accept().expect("session: cannot accept");
match check_magic(&mut listener) {
Ok(()) => (), Ok(()) => (),
Err(_) => { Err(_) => {
warn!("wrong magic from {}", listener.remote_endpoint()); warn!("wrong magic from {}", stream.remote_endpoint());
listener.close().expect("session: cannot close"); stream.close().expect("session: cannot close");
continue continue
} }
} }
info!("new connection from {}", listener.remote_endpoint()); info!("new connection from {}", stream.remote_endpoint());
let socket = listener.into_handle();
listener = TcpSocket::with_buffer_size(&io, BUFFER_SIZE);
let congress = congress.clone(); let congress = congress.clone();
let stream = stream.into_handle();
respawn(&io, &mut kernel_thread, move |io| { respawn(&io, &mut kernel_thread, move |io| {
let mut congress = borrow_mut!(congress); let mut congress = borrow_mut!(congress);
let mut socket = TcpSocket::from_handle(&io, socket); let mut stream = TcpStream::from_handle(&io, stream);
match host_kernel_worker(&io, &mut socket, &mut *congress) { match host_kernel_worker(&io, &mut stream, &mut *congress) {
Ok(()) => (), Ok(()) => (),
Err(err) => { Err(err) => {
if err.kind() == io::ErrorKind::UnexpectedEof { if err.kind() == io::ErrorKind::UnexpectedEof {