runtime: fix race condition in log extraction code paths (#979).

The core device used to panic if certain combinations of borrows
of the log buffer happened. Now they all use .try_borrow_mut().
This commit is contained in:
whitequark 2018-04-20 15:26:00 +00:00
parent b4e3c30d8c
commit 0d5fd1e83d
3 changed files with 105 additions and 69 deletions

View File

@ -5,12 +5,43 @@ extern crate log_buffer;
#[macro_use] #[macro_use]
extern crate board; extern crate board;
use core::cell::{Cell, RefCell}; use core::cell::{Cell, RefCell, RefMut};
use core::fmt::Write; use core::fmt::Write;
use log::{Log, LevelFilter}; use log::{Log, LevelFilter};
use log_buffer::LogBuffer; use log_buffer::LogBuffer;
use board::clock; use board::clock;
pub struct LogBufferRef<'a> {
buffer: RefMut<'a, LogBuffer<&'static mut [u8]>>,
old_log_level: LevelFilter
}
impl<'a> LogBufferRef<'a> {
fn new(buffer: RefMut<'a, LogBuffer<&'static mut [u8]>>) -> LogBufferRef<'a> {
let old_log_level = log::max_level();
log::set_max_level(LevelFilter::Off);
LogBufferRef { buffer, old_log_level }
}
pub fn is_empty(&mut self) -> bool {
self.buffer.extract().len() == 0
}
pub fn clear(&mut self) {
self.buffer.clear()
}
pub fn extract(&mut self) -> &str {
self.buffer.extract()
}
}
impl<'a> Drop for LogBufferRef<'a> {
fn drop(&mut self) {
log::set_max_level(self.old_log_level)
}
}
pub struct BufferLogger { pub struct BufferLogger {
buffer: RefCell<LogBuffer<&'static mut [u8]>>, buffer: RefCell<LogBuffer<&'static mut [u8]>>,
uart_filter: Cell<LevelFilter> uart_filter: Cell<LevelFilter>
@ -40,20 +71,11 @@ impl BufferLogger {
f(unsafe { &*LOGGER }) f(unsafe { &*LOGGER })
} }
pub fn clear(&self) { pub fn buffer<'a>(&'a self) -> Result<LogBufferRef<'a>, ()> {
self.buffer.borrow_mut().clear() self.buffer
} .try_borrow_mut()
.map(LogBufferRef::new)
pub fn is_empty(&self) -> bool { .map_err(|_| ())
self.buffer.borrow_mut().extract().len() == 0
}
pub fn extract<R, F: FnOnce(&str) -> R>(&self, f: F) -> R {
let old_log_level = log::max_level();
log::set_max_level(LevelFilter::Off);
let result = f(self.buffer.borrow_mut().extract());
log::set_max_level(old_log_level);
result
} }
pub fn uart_log_level(&self) -> LevelFilter { pub fn uart_log_level(&self) -> LevelFilter {
@ -79,9 +101,10 @@ impl Log for BufferLogger {
let seconds = timestamp / 1_000_000; let seconds = timestamp / 1_000_000;
let micros = timestamp % 1_000_000; let micros = timestamp % 1_000_000;
writeln!(self.buffer.borrow_mut(), if let Ok(mut buffer) = self.buffer.try_borrow_mut() {
"[{:6}.{:06}s] {:>5}({}): {}", seconds, micros, writeln!(buffer, "[{:6}.{:06}s] {:>5}({}): {}", seconds, micros,
record.level(), record.target(), record.args()).unwrap(); record.level(), record.target(), record.args()).unwrap();
}
if record.level() <= self.uart_filter.get() { if record.level() <= self.uart_filter.get() {
println!("[{:6}.{:06}s] {:>5}({}): {}", seconds, micros, println!("[{:6}.{:06}s] {:>5}({}): {}", seconds, micros,

View File

@ -27,44 +27,48 @@ fn worker(io: &Io, stream: &mut TcpStream) -> io::Result<()> {
match Request::read_from(stream)? { match Request::read_from(stream)? {
Request::GetLog => { Request::GetLog => {
BufferLogger::with(|logger| { BufferLogger::with(|logger| {
logger.extract(|log| { let mut buffer = io.until_ok(|| logger.buffer())?;
Reply::LogContent(log).write_to(stream) Reply::LogContent(buffer.extract()).write_to(stream)
})
})?; })?;
}, },
Request::ClearLog => { Request::ClearLog => {
BufferLogger::with(|logger| BufferLogger::with(|logger| -> io::Result<()> {
logger.clear()); let mut buffer = io.until_ok(|| logger.buffer())?;
Ok(buffer.clear())
})?;
Reply::Success.write_to(stream)?; Reply::Success.write_to(stream)?;
}, },
Request::PullLog => { Request::PullLog => {
loop { BufferLogger::with(|logger| -> io::Result<()> {
io.until(|| BufferLogger::with(|logger| !logger.is_empty()))?; loop {
// Do this *before* acquiring the buffer, since that sets the log level
BufferLogger::with(|logger| { // to OFF.
let log_level = log::max_level(); let log_level = log::max_level();
logger.extract(|log| {
stream.write_string(log)?;
if log_level == LevelFilter::Trace { let mut buffer = io.until_ok(|| logger.buffer())?;
// Hold exclusive access over the logger until we get positive if buffer.is_empty() { continue }
// acknowledgement; otherwise we get an infinite loop of network
// trace messages being transmitted and causing more network
// trace messages to be emitted.
//
// Any messages unrelated to this management socket that arrive
// while it is flushed are lost, but such is life.
stream.flush()
} else {
Ok(())
}
})?;
Ok(logger.clear()) as io::Result<()> stream.write_string(buffer.extract())?;
})?;
} if log_level == LevelFilter::Trace {
// Hold exclusive access over the logger until we get positive
// acknowledgement; otherwise we get an infinite loop of network
// trace messages being transmitted and causing more network
// trace messages to be emitted.
//
// Any messages unrelated to this management socket that arrive
// while it is flushed are lost, but such is life.
stream.flush()?;
}
// Clear the log *after* flushing the network buffers, or we're just
// going to resend all the trace messages on the next iteration.
buffer.clear();
}
})?;
}, },
Request::SetLogFilter(level) => { Request::SetLogFilter(level) => {

View File

@ -1,6 +1,7 @@
#![allow(dead_code)] #![allow(dead_code)]
use std::mem; use std::mem;
use std::result;
use std::cell::{Cell, RefCell}; use std::cell::{Cell, RefCell};
use std::vec::Vec; use std::vec::Vec;
use std::io::{Read, Write, Result, Error, ErrorKind}; use std::io::{Read, Write, Result, Error, ErrorKind};
@ -17,7 +18,7 @@ type SocketSet = ::smoltcp::socket::SocketSet<'static, 'static, 'static>;
#[derive(Debug)] #[derive(Debug)]
struct WaitRequest { struct WaitRequest {
event: Option<*const (Fn() -> bool + 'static)>, event: Option<*mut FnMut() -> bool>,
timeout: Option<u64> timeout: Option<u64>
} }
@ -133,26 +134,22 @@ impl Scheduler {
self.run_idx = (self.run_idx + 1) % self.threads.len(); self.run_idx = (self.run_idx + 1) % self.threads.len();
let result = { let result = {
let mut thread = self.threads[self.run_idx].0.borrow_mut(); let &mut Thread { ref mut generator, ref mut interrupted, ref waiting_for } =
match thread.waiting_for { &mut *self.threads[self.run_idx].0.borrow_mut();
_ if thread.interrupted => { if *interrupted {
thread.interrupted = false; *interrupted = false;
thread.generator.resume(WaitResult::Interrupted) generator.resume(WaitResult::Interrupted)
} } else if waiting_for.event.is_none() && waiting_for.timeout.is_none() {
WaitRequest { event: None, timeout: None } => generator.resume(WaitResult::Completed)
thread.generator.resume(WaitResult::Completed), } else if waiting_for.timeout.map(|instant| now >= instant).unwrap_or(false) {
WaitRequest { timeout: Some(instant), .. } if now >= instant => generator.resume(WaitResult::TimedOut)
thread.generator.resume(WaitResult::TimedOut), } else if waiting_for.event.map(|event| unsafe { (*event)() }).unwrap_or(false) {
WaitRequest { event: Some(event), .. } if unsafe { (*event)() } => generator.resume(WaitResult::Completed)
thread.generator.resume(WaitResult::Completed), } else if self.run_idx == start_idx {
_ => { // We've checked every thread and none of them are runnable.
if self.run_idx == start_idx { break
// We've checked every thread and none of them are runnable. } else {
break continue
} else {
continue
}
}
} }
}; };
@ -225,13 +222,25 @@ impl<'a> Io<'a> {
}) })
} }
pub fn until<F: Fn() -> bool + 'static>(&self, f: F) -> Result<()> { pub fn until<F: FnMut() -> bool>(&self, mut f: F) -> Result<()> {
let f = unsafe { mem::transmute::<&mut FnMut() -> bool, *mut FnMut() -> bool>(&mut f) };
self.suspend(WaitRequest { self.suspend(WaitRequest {
timeout: None, timeout: None,
event: Some(&f as *const _) event: Some(f)
}) })
} }
pub fn until_ok<T, E, F: FnMut() -> result::Result<T, E>>(&self, mut f: F) -> Result<T> {
let mut value = None;
self.until(|| {
if let Ok(result) = f() {
value = Some(result)
}
value.is_some()
})?;
Ok(value.unwrap())
}
pub fn join(&self, handle: ThreadHandle) -> Result<()> { pub fn join(&self, handle: ThreadHandle) -> Result<()> {
self.until(move || handle.terminated()) self.until(move || handle.terminated())
} }
@ -250,7 +259,7 @@ macro_rules! until {
use ::smoltcp::Error as ErrorLower; use ::smoltcp::Error as ErrorLower;
// https://github.com/rust-lang/rust/issues/44057 // https://github.com/rust-lang/rust/issues/26264
// type ErrorLower = ::smoltcp::Error; // type ErrorLower = ::smoltcp::Error;
type TcpSocketBuffer = ::smoltcp::socket::TcpSocketBuffer<'static>; type TcpSocketBuffer = ::smoltcp::socket::TcpSocketBuffer<'static>;