firmware: fix race condition between TCP listen and accept.

This commit is contained in:
whitequark 2017-01-25 00:17:46 +00:00
parent 2de3770c06
commit 6414e40deb
4 changed files with 135 additions and 104 deletions

View File

@ -1,6 +1,6 @@
use std::io::{self, Write};
use board::{csr, cache};
use sched::{Io, TcpSocket};
use sched::{Io, TcpListener, TcpStream};
use analyzer_proto::*;
const BUFFER_SIZE: usize = 512 * 1024;
@ -40,7 +40,7 @@ fn disarm() {
}
}
fn worker(socket: &mut TcpSocket) -> io::Result<()> {
fn worker(stream: &mut TcpStream) -> io::Result<()> {
let data = unsafe { &BUFFER.data[..] };
let overflow_occurred = unsafe { csr::rtio_analyzer::message_encoder_overflow_read() != 0 };
let total_byte_count = unsafe { csr::rtio_analyzer::dma_byte_count_read() };
@ -56,12 +56,12 @@ fn worker(socket: &mut TcpSocket) -> io::Result<()> {
};
trace!("{:?}", header);
try!(header.write_to(socket));
try!(header.write_to(stream));
if wraparound {
try!(socket.write_all(&data[pointer..]));
try!(socket.write_all(&data[..pointer]));
try!(stream.write_all(&data[pointer..]));
try!(stream.write_all(&data[..pointer]));
} else {
try!(socket.write_all(&data[..pointer]));
try!(stream.write_all(&data[..pointer]));
}
Ok(())
@ -71,20 +71,21 @@ pub fn thread(io: Io) {
// verify that the hack above works
assert!(::core::mem::align_of::<Buffer>() == 64);
let mut socket = TcpSocket::with_buffer_size(&io, 65535);
let listener = TcpListener::new(&io, 65535);
listener.listen(1382).expect("analyzer: cannot listen");
loop {
arm();
socket.listen(1382).expect("analyzer: cannot listen");
socket.accept().expect("analyzer: cannot accept");
info!("connection from {}", socket.remote_endpoint());
let mut stream = listener.accept().expect("analyzer: cannot accept");
info!("connection from {}", stream.remote_endpoint());
disarm();
match worker(&mut socket) {
match worker(&mut stream) {
Ok(()) => (),
Err(err) => error!("analyzer aborted: {}", err)
}
socket.close().expect("analyzer: cannot close");
stream.close().expect("analyzer: cannot close");
}
}

View File

@ -114,7 +114,7 @@ fn worker(socket: &mut UdpSocket) -> io::Result<()> {
}
pub fn thread(io: Io) {
let mut socket = UdpSocket::with_buffer_size(&io, 1, 512);
let mut socket = UdpSocket::new(&io, 1, 512);
socket.bind(3250);
loop {

View File

@ -1,7 +1,7 @@
#![allow(dead_code)]
use std::mem;
use std::cell::{RefCell, RefMut};
use std::cell::{Cell, RefCell, RefMut};
use std::vec::Vec;
use std::io::{Read, Write, Result, Error, ErrorKind};
use fringe::OwnedStack;
@ -244,7 +244,7 @@ macro_rules! until {
let (sockets, handle) = ($socket.io.sockets.clone(), $socket.handle);
$socket.io.until(move || {
let mut sockets = borrow_mut!(sockets);
let $var = sockets.get_mut(handle).as_socket() as &mut $ty;
let $var: &mut $ty = sockets.get_mut(handle).as_socket();
$cond
})
})
@ -260,27 +260,21 @@ pub struct UdpSocket<'a> {
}
impl<'a> UdpSocket<'a> {
pub fn new(io: &'a Io<'a>, rx_buffer: UdpSocketBuffer, tx_buffer: UdpSocketBuffer) ->
UdpSocket<'a> {
let handle = borrow_mut!(io.sockets)
.add(UdpSocketLower::new(rx_buffer, tx_buffer));
UdpSocket {
io: io,
handle: handle
}
}
pub fn with_buffer_size(io: &'a Io<'a>, buffer_depth: usize, buffer_width: usize) ->
UdpSocket<'a> {
pub fn new(io: &'a Io<'a>, buffer_depth: usize, buffer_width: usize) -> UdpSocket<'a> {
let mut rx_buffer = vec![];
let mut tx_buffer = vec![];
for _ in 0..buffer_depth {
rx_buffer.push(UdpPacketBuffer::new(vec![0; buffer_width]));
tx_buffer.push(UdpPacketBuffer::new(vec![0; buffer_width]));
}
Self::new(io,
UdpSocketBuffer::new(rx_buffer),
UdpSocketBuffer::new(tx_buffer))
let handle = borrow_mut!(io.sockets)
.add(UdpSocketLower::new(
UdpSocketBuffer::new(rx_buffer),
UdpSocketBuffer::new(tx_buffer)));
UdpSocket {
io: io,
handle: handle
}
}
fn as_lower<'b>(&'b self) -> RefMut<'b, UdpSocketLower> {
@ -326,38 +320,104 @@ type TcpSocketLower = ::smoltcp::socket::TcpSocket<'static>;
pub struct TcpSocketHandle(SocketHandle);
pub struct TcpSocket<'a> {
pub struct TcpListener<'a> {
io: &'a Io<'a>,
handle: Cell<SocketHandle>,
buffer_size: Cell<usize>,
endpoint: Cell<IpEndpoint>
}
impl<'a> TcpListener<'a> {
fn new_lower(io: &'a Io<'a>, buffer_size: usize) -> SocketHandle {
let rx_buffer = vec![0; buffer_size];
let tx_buffer = vec![0; buffer_size];
borrow_mut!(io.sockets)
.add(TcpSocketLower::new(
TcpSocketBuffer::new(rx_buffer),
TcpSocketBuffer::new(tx_buffer)))
}
pub fn new(io: &'a Io<'a>, buffer_size: usize) -> TcpListener<'a> {
TcpListener {
io: io,
handle: Cell::new(Self::new_lower(io, buffer_size)),
buffer_size: Cell::new(buffer_size),
endpoint: Cell::new(IpEndpoint::default())
}
}
fn as_lower<'b>(&'b self) -> RefMut<'b, TcpSocketLower> {
RefMut::map(borrow_mut!(self.io.sockets),
|sockets| sockets.get_mut(self.handle.get()).as_socket())
}
pub fn is_open(&self) -> bool {
self.as_lower().is_open()
}
pub fn can_accept(&self) -> bool {
self.as_lower().is_active()
}
pub fn local_endpoint(&self) -> IpEndpoint {
self.as_lower().local_endpoint()
}
pub fn listen<T: Into<IpEndpoint>>(&self, endpoint: T) -> Result<()> {
let endpoint = endpoint.into();
try!(self.as_lower().listen(endpoint)
.map_err(|()| Error::new(ErrorKind::Other,
"cannot listen: already connected")));
self.endpoint.set(endpoint);
Ok(())
}
pub fn accept(&self) -> Result<TcpStream<'a>> {
// We're waiting until at least one half of the connection becomes open.
// This handles the case where a remote socket immediately sends a FIN--
// that still counts as accepting even though nothing may be sent.
let (sockets, handle) = (self.io.sockets.clone(), self.handle.get());
try!(self.io.until(move || {
let mut sockets = borrow_mut!(sockets);
let socket: &mut TcpSocketLower = sockets.get_mut(handle).as_socket();
socket.may_send() || socket.may_recv()
}));
let accepted = self.handle.get();
self.handle.set(Self::new_lower(self.io, self.buffer_size.get()));
self.listen(self.endpoint.get()).unwrap();
Ok(TcpStream {
io: self.io,
handle: accepted
})
}
pub fn close(&self) {
self.as_lower().close()
}
}
impl<'a> Drop for TcpListener<'a> {
fn drop(&mut self) {
self.as_lower().close();
borrow_mut!(self.io.sockets).release(self.handle.get())
}
}
pub struct TcpStream<'a> {
io: &'a Io<'a>,
handle: SocketHandle
}
impl<'a> TcpSocket<'a> {
pub fn new(io: &'a Io<'a>, rx_buffer: TcpSocketBuffer, tx_buffer: TcpSocketBuffer) ->
TcpSocket<'a> {
let handle = borrow_mut!(io.sockets)
.add(TcpSocketLower::new(rx_buffer, tx_buffer));
TcpSocket {
io: io,
handle: handle
}
}
pub fn with_buffer_size(io: &'a Io<'a>, buffer_size: usize) -> TcpSocket<'a> {
let rx_buffer = vec![0; buffer_size];
let tx_buffer = vec![0; buffer_size];
Self::new(io,
TcpSocketBuffer::new(rx_buffer),
TcpSocketBuffer::new(tx_buffer))
}
impl<'a> TcpStream<'a> {
pub fn into_handle(self) -> TcpSocketHandle {
let handle = self.handle;
mem::forget(self);
TcpSocketHandle(handle)
}
pub fn from_handle(io: &'a Io<'a>, handle: TcpSocketHandle) -> TcpSocket<'a> {
TcpSocket {
pub fn from_handle(io: &'a Io<'a>, handle: TcpSocketHandle) -> TcpStream<'a> {
TcpStream {
io: io,
handle: handle.0
}
@ -372,14 +432,6 @@ impl<'a> TcpSocket<'a> {
self.as_lower().is_open()
}
pub fn is_listening(&self) -> bool {
self.as_lower().is_listening()
}
pub fn is_active(&self) -> bool {
self.as_lower().is_active()
}
pub fn may_send(&self) -> bool {
self.as_lower().may_send()
}
@ -404,19 +456,6 @@ impl<'a> TcpSocket<'a> {
self.as_lower().remote_endpoint()
}
pub fn listen<T: Into<IpEndpoint>>(&self, endpoint: T) -> Result<()> {
self.as_lower().listen(endpoint)
.map_err(|()| Error::new(ErrorKind::Other,
"cannot listen: already connected"))
}
pub fn accept(&self) -> Result<()> {
// We're waiting until at least one half of the connection becomes open.
// This handles the case where a remote socket immediately sends a FIN--
// that still counts as accepting even though nothing may be sent.
until!(self, TcpSocketLower, |s| s.may_send() || s.may_recv())
}
pub fn close(&self) -> Result<()> {
self.as_lower().close();
try!(until!(self, TcpSocketLower, |s| !s.is_open()));
@ -427,7 +466,7 @@ impl<'a> TcpSocket<'a> {
}
}
impl<'a> Read for TcpSocket<'a> {
impl<'a> Read for TcpStream<'a> {
fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
// fast path
let result = self.as_lower().recv_slice(buf);
@ -444,7 +483,7 @@ impl<'a> Read for TcpSocket<'a> {
}
}
impl<'a> Write for TcpSocket<'a> {
impl<'a> Write for TcpStream<'a> {
fn write(&mut self, buf: &[u8]) -> Result<usize> {
// fast path
let result = self.as_lower().send_slice(buf);
@ -466,12 +505,9 @@ impl<'a> Write for TcpSocket<'a> {
}
}
impl<'a> Drop for TcpSocket<'a> {
impl<'a> Drop for TcpStream<'a> {
fn drop(&mut self) {
if self.is_open() {
// scheduler will remove any closed sockets with zero references.
self.as_lower().close()
}
self.as_lower().close();
borrow_mut!(self.io.sockets).release(self.handle)
}
}

View File

@ -8,7 +8,7 @@ use logger_artiq::BufferLogger;
use cache::Cache;
use urc::Urc;
use sched::{ThreadHandle, Io};
use sched::{TcpSocket};
use sched::{TcpListener, TcpStream};
use byteorder::{ByteOrder, NetworkEndian};
use board;
@ -97,7 +97,7 @@ impl<'a> Drop for Session<'a> {
}
}
fn check_magic(stream: &mut TcpSocket) -> io::Result<()> {
fn check_magic(stream: &mut TcpStream) -> io::Result<()> {
const MAGIC: &'static [u8] = b"ARTIQ coredev\n";
let mut magic: [u8; 14] = [0; 14];
@ -109,7 +109,7 @@ fn check_magic(stream: &mut TcpSocket) -> io::Result<()> {
}
}
fn host_read(stream: &mut TcpSocket) -> io::Result<host::Request> {
fn host_read(stream: &mut TcpStream) -> io::Result<host::Request> {
let request = try!(host::Request::read_from(stream));
match &request {
&host::Request::LoadKernel(_) => trace!("comm<-host LoadLibrary(...)"),
@ -198,7 +198,7 @@ fn kern_run(session: &mut Session) -> io::Result<()> {
}
fn process_host_message(io: &Io,
stream: &mut TcpSocket,
stream: &mut TcpStream,
session: &mut Session) -> io::Result<()> {
match try!(host_read(stream)) {
host::Request::Ident =>
@ -338,7 +338,7 @@ fn process_host_message(io: &Io,
}
fn process_kern_message(io: &Io,
mut stream: Option<&mut TcpSocket>,
mut stream: Option<&mut TcpStream>,
session: &mut Session) -> io::Result<bool> {
kern_recv_notrace(io, |request| {
match (request, session.kernel_state) {
@ -535,7 +535,7 @@ fn process_kern_message(io: &Io,
})
}
fn process_kern_queued_rpc(stream: &mut TcpSocket,
fn process_kern_queued_rpc(stream: &mut TcpStream,
_session: &mut Session) -> io::Result<()> {
rpc_queue::dequeue(|slice| {
trace!("comm<-kern (async RPC)");
@ -548,7 +548,7 @@ fn process_kern_queued_rpc(stream: &mut TcpSocket,
}
fn host_kernel_worker(io: &Io,
stream: &mut TcpSocket,
stream: &mut TcpStream,
congress: &mut Congress) -> io::Result<()> {
let mut session = Session::new(congress);
@ -652,36 +652,30 @@ pub fn thread(io: Io) {
BufferLogger::with_instance(|logger| logger.disable_trace_to_uart());
const BUFFER_SIZE: usize = 65535;
let mut listener = TcpSocket::with_buffer_size(&io, BUFFER_SIZE);
let listener = TcpListener::new(&io, 65535);
listener.listen(1381).expect("session: cannot listen");
info!("accepting network sessions");
let mut kernel_thread = None;
loop {
if !listener.is_open() {
listener.listen(1381).expect("session: cannot listen")
}
if listener.is_active() {
listener.accept().expect("session: cannot accept");
match check_magic(&mut listener) {
if listener.can_accept() {
let mut stream = listener.accept().expect("session: cannot accept");
match check_magic(&mut stream) {
Ok(()) => (),
Err(_) => {
warn!("wrong magic from {}", listener.remote_endpoint());
listener.close().expect("session: cannot close");
warn!("wrong magic from {}", stream.remote_endpoint());
stream.close().expect("session: cannot close");
continue
}
}
info!("new connection from {}", listener.remote_endpoint());
let socket = listener.into_handle();
listener = TcpSocket::with_buffer_size(&io, BUFFER_SIZE);
info!("new connection from {}", stream.remote_endpoint());
let congress = congress.clone();
let stream = stream.into_handle();
respawn(&io, &mut kernel_thread, move |io| {
let mut congress = borrow_mut!(congress);
let mut socket = TcpSocket::from_handle(&io, socket);
match host_kernel_worker(&io, &mut socket, &mut *congress) {
let mut stream = TcpStream::from_handle(&io, stream);
match host_kernel_worker(&io, &mut stream, &mut *congress) {
Ok(()) => (),
Err(err) => {
if err.kind() == io::ErrorKind::UnexpectedEof {