runtime: fix race condition in log extraction code paths (#979).

The core device used to panic if certain combinations of borrows
of the log buffer happened. Now they all use .try_borrow_mut().
This commit is contained in:
whitequark 2018-04-20 15:26:00 +00:00
parent b4e3c30d8c
commit 0d5fd1e83d
3 changed files with 105 additions and 69 deletions

View File

@ -5,12 +5,43 @@ extern crate log_buffer;
#[macro_use]
extern crate board;
use core::cell::{Cell, RefCell};
use core::cell::{Cell, RefCell, RefMut};
use core::fmt::Write;
use log::{Log, LevelFilter};
use log_buffer::LogBuffer;
use board::clock;
pub struct LogBufferRef<'a> {
buffer: RefMut<'a, LogBuffer<&'static mut [u8]>>,
old_log_level: LevelFilter
}
impl<'a> LogBufferRef<'a> {
fn new(buffer: RefMut<'a, LogBuffer<&'static mut [u8]>>) -> LogBufferRef<'a> {
let old_log_level = log::max_level();
log::set_max_level(LevelFilter::Off);
LogBufferRef { buffer, old_log_level }
}
pub fn is_empty(&mut self) -> bool {
self.buffer.extract().len() == 0
}
pub fn clear(&mut self) {
self.buffer.clear()
}
pub fn extract(&mut self) -> &str {
self.buffer.extract()
}
}
impl<'a> Drop for LogBufferRef<'a> {
fn drop(&mut self) {
log::set_max_level(self.old_log_level)
}
}
pub struct BufferLogger {
buffer: RefCell<LogBuffer<&'static mut [u8]>>,
uart_filter: Cell<LevelFilter>
@ -40,20 +71,11 @@ impl BufferLogger {
f(unsafe { &*LOGGER })
}
pub fn clear(&self) {
self.buffer.borrow_mut().clear()
}
pub fn is_empty(&self) -> bool {
self.buffer.borrow_mut().extract().len() == 0
}
pub fn extract<R, F: FnOnce(&str) -> R>(&self, f: F) -> R {
let old_log_level = log::max_level();
log::set_max_level(LevelFilter::Off);
let result = f(self.buffer.borrow_mut().extract());
log::set_max_level(old_log_level);
result
pub fn buffer<'a>(&'a self) -> Result<LogBufferRef<'a>, ()> {
self.buffer
.try_borrow_mut()
.map(LogBufferRef::new)
.map_err(|_| ())
}
pub fn uart_log_level(&self) -> LevelFilter {
@ -79,9 +101,10 @@ impl Log for BufferLogger {
let seconds = timestamp / 1_000_000;
let micros = timestamp % 1_000_000;
writeln!(self.buffer.borrow_mut(),
"[{:6}.{:06}s] {:>5}({}): {}", seconds, micros,
if let Ok(mut buffer) = self.buffer.try_borrow_mut() {
writeln!(buffer, "[{:6}.{:06}s] {:>5}({}): {}", seconds, micros,
record.level(), record.target(), record.args()).unwrap();
}
if record.level() <= self.uart_filter.get() {
println!("[{:6}.{:06}s] {:>5}({}): {}", seconds, micros,

View File

@ -27,26 +27,31 @@ fn worker(io: &Io, stream: &mut TcpStream) -> io::Result<()> {
match Request::read_from(stream)? {
Request::GetLog => {
BufferLogger::with(|logger| {
logger.extract(|log| {
Reply::LogContent(log).write_to(stream)
})
let mut buffer = io.until_ok(|| logger.buffer())?;
Reply::LogContent(buffer.extract()).write_to(stream)
})?;
},
Request::ClearLog => {
BufferLogger::with(|logger|
logger.clear());
BufferLogger::with(|logger| -> io::Result<()> {
let mut buffer = io.until_ok(|| logger.buffer())?;
Ok(buffer.clear())
})?;
Reply::Success.write_to(stream)?;
},
Request::PullLog => {
BufferLogger::with(|logger| -> io::Result<()> {
loop {
io.until(|| BufferLogger::with(|logger| !logger.is_empty()))?;
BufferLogger::with(|logger| {
// Do this *before* acquiring the buffer, since that sets the log level
// to OFF.
let log_level = log::max_level();
logger.extract(|log| {
stream.write_string(log)?;
let mut buffer = io.until_ok(|| logger.buffer())?;
if buffer.is_empty() { continue }
stream.write_string(buffer.extract())?;
if log_level == LevelFilter::Trace {
// Hold exclusive access over the logger until we get positive
@ -56,15 +61,14 @@ fn worker(io: &Io, stream: &mut TcpStream) -> io::Result<()> {
//
// Any messages unrelated to this management socket that arrive
// while it is flushed are lost, but such is life.
stream.flush()
} else {
Ok(())
stream.flush()?;
}
})?;
Ok(logger.clear()) as io::Result<()>
})?;
// Clear the log *after* flushing the network buffers, or we're just
// going to resend all the trace messages on the next iteration.
buffer.clear();
}
})?;
},
Request::SetLogFilter(level) => {

View File

@ -1,6 +1,7 @@
#![allow(dead_code)]
use std::mem;
use std::result;
use std::cell::{Cell, RefCell};
use std::vec::Vec;
use std::io::{Read, Write, Result, Error, ErrorKind};
@ -17,7 +18,7 @@ type SocketSet = ::smoltcp::socket::SocketSet<'static, 'static, 'static>;
#[derive(Debug)]
struct WaitRequest {
event: Option<*const (Fn() -> bool + 'static)>,
event: Option<*mut FnMut() -> bool>,
timeout: Option<u64>
}
@ -133,27 +134,23 @@ impl Scheduler {
self.run_idx = (self.run_idx + 1) % self.threads.len();
let result = {
let mut thread = self.threads[self.run_idx].0.borrow_mut();
match thread.waiting_for {
_ if thread.interrupted => {
thread.interrupted = false;
thread.generator.resume(WaitResult::Interrupted)
}
WaitRequest { event: None, timeout: None } =>
thread.generator.resume(WaitResult::Completed),
WaitRequest { timeout: Some(instant), .. } if now >= instant =>
thread.generator.resume(WaitResult::TimedOut),
WaitRequest { event: Some(event), .. } if unsafe { (*event)() } =>
thread.generator.resume(WaitResult::Completed),
_ => {
if self.run_idx == start_idx {
let &mut Thread { ref mut generator, ref mut interrupted, ref waiting_for } =
&mut *self.threads[self.run_idx].0.borrow_mut();
if *interrupted {
*interrupted = false;
generator.resume(WaitResult::Interrupted)
} else if waiting_for.event.is_none() && waiting_for.timeout.is_none() {
generator.resume(WaitResult::Completed)
} else if waiting_for.timeout.map(|instant| now >= instant).unwrap_or(false) {
generator.resume(WaitResult::TimedOut)
} else if waiting_for.event.map(|event| unsafe { (*event)() }).unwrap_or(false) {
generator.resume(WaitResult::Completed)
} else if self.run_idx == start_idx {
// We've checked every thread and none of them are runnable.
break
} else {
continue
}
}
}
};
match result {
@ -225,13 +222,25 @@ impl<'a> Io<'a> {
})
}
pub fn until<F: Fn() -> bool + 'static>(&self, f: F) -> Result<()> {
pub fn until<F: FnMut() -> bool>(&self, mut f: F) -> Result<()> {
let f = unsafe { mem::transmute::<&mut FnMut() -> bool, *mut FnMut() -> bool>(&mut f) };
self.suspend(WaitRequest {
timeout: None,
event: Some(&f as *const _)
event: Some(f)
})
}
pub fn until_ok<T, E, F: FnMut() -> result::Result<T, E>>(&self, mut f: F) -> Result<T> {
let mut value = None;
self.until(|| {
if let Ok(result) = f() {
value = Some(result)
}
value.is_some()
})?;
Ok(value.unwrap())
}
pub fn join(&self, handle: ThreadHandle) -> Result<()> {
self.until(move || handle.terminated())
}
@ -250,7 +259,7 @@ macro_rules! until {
use ::smoltcp::Error as ErrorLower;
// https://github.com/rust-lang/rust/issues/44057
// https://github.com/rust-lang/rust/issues/26264
// type ErrorLower = ::smoltcp::Error;
type TcpSocketBuffer = ::smoltcp::socket::TcpSocketBuffer<'static>;