runtime: fix race condition in log extraction code paths (#979).

The core device used to panic if certain combinations of borrows
of the log buffer happened. Now they all use .try_borrow_mut().
This commit is contained in:
whitequark 2018-04-20 15:26:00 +00:00
parent 4c65fb79b9
commit 84d807a5e4
3 changed files with 110 additions and 71 deletions

View File

@ -4,13 +4,46 @@ extern crate log;
extern crate log_buffer; extern crate log_buffer;
extern crate board; extern crate board;
use core::{mem, ptr}; use core::ptr;
use core::cell::{Cell, RefCell}; use core::cell::{Cell, RefCell, Ref, RefMut};
use core::fmt::Write; use core::fmt::Write;
use log::{Log, LogMetadata, LogRecord, LogLevelFilter, MaxLogLevelFilter}; use log::{Log, LogMetadata, LogRecord, LogLevelFilter, MaxLogLevelFilter};
use log_buffer::LogBuffer; use log_buffer::LogBuffer;
use board::{Console, clock}; use board::{Console, clock};
pub struct LogBufferRef<'a> {
buffer: RefMut<'a, LogBuffer<&'static mut [u8]>>,
filter: Ref<'a, MaxLogLevelFilter>,
old_log_level: LogLevelFilter
}
impl<'a> LogBufferRef<'a> {
fn new(buffer: RefMut<'a, LogBuffer<&'static mut [u8]>>,
filter: Ref<'a, MaxLogLevelFilter>) -> LogBufferRef<'a> {
let old_log_level = filter.get();
filter.set(LogLevelFilter::Off);
LogBufferRef { buffer, filter, old_log_level }
}
pub fn is_empty(&mut self) -> bool {
self.buffer.extract().len() == 0
}
pub fn clear(&mut self) {
self.buffer.clear()
}
pub fn extract(&mut self) -> &str {
self.buffer.extract()
}
}
impl<'a> Drop for LogBufferRef<'a> {
fn drop(&mut self) {
self.filter.set(self.old_log_level)
}
}
pub struct BufferLogger { pub struct BufferLogger {
buffer: RefCell<LogBuffer<&'static mut [u8]>>, buffer: RefCell<LogBuffer<&'static mut [u8]>>,
filter: RefCell<Option<MaxLogLevelFilter>>, filter: RefCell<Option<MaxLogLevelFilter>>,
@ -48,23 +81,15 @@ impl BufferLogger {
} }
pub fn with<R, F: FnOnce(&BufferLogger) -> R>(f: F) -> R { pub fn with<R, F: FnOnce(&BufferLogger) -> R>(f: F) -> R {
f(unsafe { mem::transmute::<*const BufferLogger, &BufferLogger>(LOGGER) }) f(unsafe { &*LOGGER })
} }
pub fn clear(&self) { pub fn buffer<'a>(&'a self) -> Result<LogBufferRef<'a>, ()> {
self.buffer.borrow_mut().clear() let filter = Ref::map(self.filter.borrow(), |f| f.as_ref().unwrap());
} self.buffer
.try_borrow_mut()
pub fn is_empty(&self) -> bool { .map(|buffer| LogBufferRef::new(buffer, filter))
self.buffer.borrow_mut().extract().len() == 0 .map_err(|_| ())
}
pub fn extract<R, F: FnOnce(&str) -> R>(&self, f: F) -> R {
let old_log_level = self.max_log_level();
self.set_max_log_level(LogLevelFilter::Off);
let result = f(self.buffer.borrow_mut().extract());
self.set_max_log_level(old_log_level);
result
} }
pub fn max_log_level(&self) -> LogLevelFilter { pub fn max_log_level(&self) -> LogLevelFilter {
@ -106,9 +131,10 @@ impl Log for BufferLogger {
let seconds = timestamp / 1_000_000; let seconds = timestamp / 1_000_000;
let micros = timestamp % 1_000_000; let micros = timestamp % 1_000_000;
writeln!(self.buffer.borrow_mut(), if let Ok(mut buffer) = self.buffer.try_borrow_mut() {
"[{:6}.{:06}s] {:>5}({}): {}", seconds, micros, writeln!(buffer, "[{:6}.{:06}s] {:>5}({}): {}", seconds, micros,
record.level(), record.target(), record.args()).unwrap(); record.level(), record.target(), record.args()).unwrap();
}
if record.level() <= self.uart_filter.get() { if record.level() <= self.uart_filter.get() {
writeln!(Console, writeln!(Console,

View File

@ -27,26 +27,31 @@ fn worker(io: &Io, stream: &mut TcpStream) -> io::Result<()> {
match Request::read_from(stream)? { match Request::read_from(stream)? {
Request::GetLog => { Request::GetLog => {
BufferLogger::with(|logger| { BufferLogger::with(|logger| {
logger.extract(|log| { let mut buffer = io.until_ok(|| logger.buffer())?;
Reply::LogContent(log).write_to(stream) Reply::LogContent(buffer.extract()).write_to(stream)
})
})?; })?;
}, },
Request::ClearLog => { Request::ClearLog => {
BufferLogger::with(|logger| BufferLogger::with(|logger| -> io::Result<()> {
logger.clear()); let mut buffer = io.until_ok(|| logger.buffer())?;
Ok(buffer.clear())
})?;
Reply::Success.write_to(stream)?; Reply::Success.write_to(stream)?;
}, },
Request::PullLog => { Request::PullLog => {
BufferLogger::with(|logger| -> io::Result<()> {
loop { loop {
io.until(|| BufferLogger::with(|logger| !logger.is_empty()))?; // Do this *before* acquiring the buffer, since that sets the log level
// to OFF.
BufferLogger::with(|logger| {
let log_level = logger.max_log_level(); let log_level = logger.max_log_level();
logger.extract(|log| {
stream.write_string(log)?; let mut buffer = io.until_ok(|| logger.buffer())?;
if buffer.is_empty() { continue }
stream.write_string(buffer.extract())?;
if log_level == LogLevelFilter::Trace { if log_level == LogLevelFilter::Trace {
// Hold exclusive access over the logger until we get positive // Hold exclusive access over the logger until we get positive
@ -56,15 +61,14 @@ fn worker(io: &Io, stream: &mut TcpStream) -> io::Result<()> {
// //
// Any messages unrelated to this management socket that arrive // Any messages unrelated to this management socket that arrive
// while it is flushed are lost, but such is life. // while it is flushed are lost, but such is life.
stream.flush() stream.flush()?;
} else {
Ok(())
} }
})?;
Ok(logger.clear()) as io::Result<()> // Clear the log *after* flushing the network buffers, or we're just
})?; // going to resend all the trace messages on the next iteration.
buffer.clear();
} }
})?;
}, },
Request::SetLogFilter(level) => { Request::SetLogFilter(level) => {

View File

@ -1,6 +1,7 @@
#![allow(dead_code)] #![allow(dead_code)]
use std::mem; use std::mem;
use std::result;
use std::cell::{Cell, RefCell}; use std::cell::{Cell, RefCell};
use std::vec::Vec; use std::vec::Vec;
use std::io::{Read, Write, Result, Error, ErrorKind}; use std::io::{Read, Write, Result, Error, ErrorKind};
@ -17,7 +18,7 @@ type SocketSet = ::smoltcp::socket::SocketSet<'static, 'static, 'static>;
#[derive(Debug)] #[derive(Debug)]
struct WaitRequest { struct WaitRequest {
event: Option<*const (Fn() -> bool + 'static)>, event: Option<*mut FnMut() -> bool>,
timeout: Option<u64> timeout: Option<u64>
} }
@ -133,27 +134,23 @@ impl Scheduler {
self.run_idx = (self.run_idx + 1) % self.threads.len(); self.run_idx = (self.run_idx + 1) % self.threads.len();
let result = { let result = {
let mut thread = self.threads[self.run_idx].0.borrow_mut(); let &mut Thread { ref mut generator, ref mut interrupted, ref waiting_for } =
match thread.waiting_for { &mut *self.threads[self.run_idx].0.borrow_mut();
_ if thread.interrupted => { if *interrupted {
thread.interrupted = false; *interrupted = false;
thread.generator.resume(WaitResult::Interrupted) generator.resume(WaitResult::Interrupted)
} } else if waiting_for.event.is_none() && waiting_for.timeout.is_none() {
WaitRequest { event: None, timeout: None } => generator.resume(WaitResult::Completed)
thread.generator.resume(WaitResult::Completed), } else if waiting_for.timeout.map(|instant| now >= instant).unwrap_or(false) {
WaitRequest { timeout: Some(instant), .. } if now >= instant => generator.resume(WaitResult::TimedOut)
thread.generator.resume(WaitResult::TimedOut), } else if waiting_for.event.map(|event| unsafe { (*event)() }).unwrap_or(false) {
WaitRequest { event: Some(event), .. } if unsafe { (*event)() } => generator.resume(WaitResult::Completed)
thread.generator.resume(WaitResult::Completed), } else if self.run_idx == start_idx {
_ => {
if self.run_idx == start_idx {
// We've checked every thread and none of them are runnable. // We've checked every thread and none of them are runnable.
break break
} else { } else {
continue continue
} }
}
}
}; };
match result { match result {
@ -225,13 +222,25 @@ impl<'a> Io<'a> {
}) })
} }
pub fn until<F: Fn() -> bool + 'static>(&self, f: F) -> Result<()> { pub fn until<F: FnMut() -> bool>(&self, mut f: F) -> Result<()> {
let f = unsafe { mem::transmute::<&mut FnMut() -> bool, *mut FnMut() -> bool>(&mut f) };
self.suspend(WaitRequest { self.suspend(WaitRequest {
timeout: None, timeout: None,
event: Some(&f as *const _) event: Some(f)
}) })
} }
pub fn until_ok<T, E, F: FnMut() -> result::Result<T, E>>(&self, mut f: F) -> Result<T> {
let mut value = None;
self.until(|| {
if let Ok(result) = f() {
value = Some(result)
}
value.is_some()
})?;
Ok(value.unwrap())
}
pub fn join(&self, handle: ThreadHandle) -> Result<()> { pub fn join(&self, handle: ThreadHandle) -> Result<()> {
self.until(move || handle.terminated()) self.until(move || handle.terminated())
} }
@ -250,7 +259,7 @@ macro_rules! until {
use ::smoltcp::Error as ErrorLower; use ::smoltcp::Error as ErrorLower;
// https://github.com/rust-lang/rust/issues/44057 // https://github.com/rust-lang/rust/issues/26264
// type ErrorLower = ::smoltcp::Error; // type ErrorLower = ::smoltcp::Error;
type TcpSocketBuffer = ::smoltcp::socket::TcpSocketBuffer<'static>; type TcpSocketBuffer = ::smoltcp::socket::TcpSocketBuffer<'static>;