From 0d5fd1e83d78cb88ddd33fe55e1170269d10dd7a Mon Sep 17 00:00:00 2001 From: whitequark Date: Fri, 20 Apr 2018 15:26:00 +0000 Subject: [PATCH] runtime: fix race condition in log extraction code paths (#979). The core device used to panic if certain combinations of borrows of the log buffer happened. Now they all use .try_borrow_mut(). --- artiq/firmware/liblogger_artiq/lib.rs | 59 +++++++++++++++++++-------- artiq/firmware/runtime/mgmt.rs | 58 ++++++++++++++------------ artiq/firmware/runtime/sched.rs | 57 +++++++++++++++----------- 3 files changed, 105 insertions(+), 69 deletions(-) diff --git a/artiq/firmware/liblogger_artiq/lib.rs b/artiq/firmware/liblogger_artiq/lib.rs index 2d7dd531b..8ccfdc6fb 100644 --- a/artiq/firmware/liblogger_artiq/lib.rs +++ b/artiq/firmware/liblogger_artiq/lib.rs @@ -5,12 +5,43 @@ extern crate log_buffer; #[macro_use] extern crate board; -use core::cell::{Cell, RefCell}; +use core::cell::{Cell, RefCell, RefMut}; use core::fmt::Write; use log::{Log, LevelFilter}; use log_buffer::LogBuffer; use board::clock; +pub struct LogBufferRef<'a> { + buffer: RefMut<'a, LogBuffer<&'static mut [u8]>>, + old_log_level: LevelFilter +} + +impl<'a> LogBufferRef<'a> { + fn new(buffer: RefMut<'a, LogBuffer<&'static mut [u8]>>) -> LogBufferRef<'a> { + let old_log_level = log::max_level(); + log::set_max_level(LevelFilter::Off); + LogBufferRef { buffer, old_log_level } + } + + pub fn is_empty(&mut self) -> bool { + self.buffer.extract().len() == 0 + } + + pub fn clear(&mut self) { + self.buffer.clear() + } + + pub fn extract(&mut self) -> &str { + self.buffer.extract() + } +} + +impl<'a> Drop for LogBufferRef<'a> { + fn drop(&mut self) { + log::set_max_level(self.old_log_level) + } +} + pub struct BufferLogger { buffer: RefCell>, uart_filter: Cell @@ -40,20 +71,11 @@ impl BufferLogger { f(unsafe { &*LOGGER }) } - pub fn clear(&self) { - self.buffer.borrow_mut().clear() - } - - pub fn is_empty(&self) -> bool { - self.buffer.borrow_mut().extract().len() == 0 - } - - pub fn extract R>(&self, f: F) -> R { - let old_log_level = log::max_level(); - log::set_max_level(LevelFilter::Off); - let result = f(self.buffer.borrow_mut().extract()); - log::set_max_level(old_log_level); - result + pub fn buffer<'a>(&'a self) -> Result, ()> { + self.buffer + .try_borrow_mut() + .map(LogBufferRef::new) + .map_err(|_| ()) } pub fn uart_log_level(&self) -> LevelFilter { @@ -79,9 +101,10 @@ impl Log for BufferLogger { let seconds = timestamp / 1_000_000; let micros = timestamp % 1_000_000; - writeln!(self.buffer.borrow_mut(), - "[{:6}.{:06}s] {:>5}({}): {}", seconds, micros, - record.level(), record.target(), record.args()).unwrap(); + if let Ok(mut buffer) = self.buffer.try_borrow_mut() { + writeln!(buffer, "[{:6}.{:06}s] {:>5}({}): {}", seconds, micros, + record.level(), record.target(), record.args()).unwrap(); + } if record.level() <= self.uart_filter.get() { println!("[{:6}.{:06}s] {:>5}({}): {}", seconds, micros, diff --git a/artiq/firmware/runtime/mgmt.rs b/artiq/firmware/runtime/mgmt.rs index ee3cbe5fb..8877f6221 100644 --- a/artiq/firmware/runtime/mgmt.rs +++ b/artiq/firmware/runtime/mgmt.rs @@ -27,44 +27,48 @@ fn worker(io: &Io, stream: &mut TcpStream) -> io::Result<()> { match Request::read_from(stream)? { Request::GetLog => { BufferLogger::with(|logger| { - logger.extract(|log| { - Reply::LogContent(log).write_to(stream) - }) + let mut buffer = io.until_ok(|| logger.buffer())?; + Reply::LogContent(buffer.extract()).write_to(stream) })?; }, Request::ClearLog => { - BufferLogger::with(|logger| - logger.clear()); + BufferLogger::with(|logger| -> io::Result<()> { + let mut buffer = io.until_ok(|| logger.buffer())?; + Ok(buffer.clear()) + })?; + Reply::Success.write_to(stream)?; }, Request::PullLog => { - loop { - io.until(|| BufferLogger::with(|logger| !logger.is_empty()))?; - - BufferLogger::with(|logger| { + BufferLogger::with(|logger| -> io::Result<()> { + loop { + // Do this *before* acquiring the buffer, since that sets the log level + // to OFF. let log_level = log::max_level(); - logger.extract(|log| { - stream.write_string(log)?; - if log_level == LevelFilter::Trace { - // Hold exclusive access over the logger until we get positive - // acknowledgement; otherwise we get an infinite loop of network - // trace messages being transmitted and causing more network - // trace messages to be emitted. - // - // Any messages unrelated to this management socket that arrive - // while it is flushed are lost, but such is life. - stream.flush() - } else { - Ok(()) - } - })?; + let mut buffer = io.until_ok(|| logger.buffer())?; + if buffer.is_empty() { continue } - Ok(logger.clear()) as io::Result<()> - })?; - } + stream.write_string(buffer.extract())?; + + if log_level == LevelFilter::Trace { + // Hold exclusive access over the logger until we get positive + // acknowledgement; otherwise we get an infinite loop of network + // trace messages being transmitted and causing more network + // trace messages to be emitted. + // + // Any messages unrelated to this management socket that arrive + // while it is flushed are lost, but such is life. + stream.flush()?; + } + + // Clear the log *after* flushing the network buffers, or we're just + // going to resend all the trace messages on the next iteration. + buffer.clear(); + } + })?; }, Request::SetLogFilter(level) => { diff --git a/artiq/firmware/runtime/sched.rs b/artiq/firmware/runtime/sched.rs index 82738e9c4..eeb31ceca 100644 --- a/artiq/firmware/runtime/sched.rs +++ b/artiq/firmware/runtime/sched.rs @@ -1,6 +1,7 @@ #![allow(dead_code)] use std::mem; +use std::result; use std::cell::{Cell, RefCell}; use std::vec::Vec; use std::io::{Read, Write, Result, Error, ErrorKind}; @@ -17,7 +18,7 @@ type SocketSet = ::smoltcp::socket::SocketSet<'static, 'static, 'static>; #[derive(Debug)] struct WaitRequest { - event: Option<*const (Fn() -> bool + 'static)>, + event: Option<*mut FnMut() -> bool>, timeout: Option } @@ -133,26 +134,22 @@ impl Scheduler { self.run_idx = (self.run_idx + 1) % self.threads.len(); let result = { - let mut thread = self.threads[self.run_idx].0.borrow_mut(); - match thread.waiting_for { - _ if thread.interrupted => { - thread.interrupted = false; - thread.generator.resume(WaitResult::Interrupted) - } - WaitRequest { event: None, timeout: None } => - thread.generator.resume(WaitResult::Completed), - WaitRequest { timeout: Some(instant), .. } if now >= instant => - thread.generator.resume(WaitResult::TimedOut), - WaitRequest { event: Some(event), .. } if unsafe { (*event)() } => - thread.generator.resume(WaitResult::Completed), - _ => { - if self.run_idx == start_idx { - // We've checked every thread and none of them are runnable. - break - } else { - continue - } - } + let &mut Thread { ref mut generator, ref mut interrupted, ref waiting_for } = + &mut *self.threads[self.run_idx].0.borrow_mut(); + if *interrupted { + *interrupted = false; + generator.resume(WaitResult::Interrupted) + } else if waiting_for.event.is_none() && waiting_for.timeout.is_none() { + generator.resume(WaitResult::Completed) + } else if waiting_for.timeout.map(|instant| now >= instant).unwrap_or(false) { + generator.resume(WaitResult::TimedOut) + } else if waiting_for.event.map(|event| unsafe { (*event)() }).unwrap_or(false) { + generator.resume(WaitResult::Completed) + } else if self.run_idx == start_idx { + // We've checked every thread and none of them are runnable. + break + } else { + continue } }; @@ -225,13 +222,25 @@ impl<'a> Io<'a> { }) } - pub fn until bool + 'static>(&self, f: F) -> Result<()> { + pub fn until bool>(&self, mut f: F) -> Result<()> { + let f = unsafe { mem::transmute::<&mut FnMut() -> bool, *mut FnMut() -> bool>(&mut f) }; self.suspend(WaitRequest { timeout: None, - event: Some(&f as *const _) + event: Some(f) }) } + pub fn until_ok result::Result>(&self, mut f: F) -> Result { + let mut value = None; + self.until(|| { + if let Ok(result) = f() { + value = Some(result) + } + value.is_some() + })?; + Ok(value.unwrap()) + } + pub fn join(&self, handle: ThreadHandle) -> Result<()> { self.until(move || handle.terminated()) } @@ -250,7 +259,7 @@ macro_rules! until { use ::smoltcp::Error as ErrorLower; -// https://github.com/rust-lang/rust/issues/44057 +// https://github.com/rust-lang/rust/issues/26264 // type ErrorLower = ::smoltcp::Error; type TcpSocketBuffer = ::smoltcp::socket::TcpSocketBuffer<'static>;