diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index e25b6fe1..4df3ee99 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -20,7 +20,10 @@ use itertools::Itertools; use parking_lot::{Condvar, Mutex}; use rustpython_parser::ast::Stmt; use std::collections::HashMap; -use std::sync::Arc; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; use std::thread; mod expr; @@ -57,7 +60,7 @@ impl WithCall { WithCall { fp } } - pub fn run<'ctx>(&self, m: &Module<'ctx>) { + pub fn run<'ctx>(&self, m: &Module<'ctx>) { (self.fp)(m) } } @@ -65,6 +68,7 @@ impl WithCall { pub struct WorkerRegistry { sender: Arc>>, receiver: Arc>>, + panicked: AtomicBool, task_count: Mutex, thread_count: usize, wait_condvar: Condvar, @@ -75,7 +79,7 @@ impl WorkerRegistry { names: &[&str], top_level_ctx: Arc, f: Arc, - ) -> Arc { + ) -> (Arc, Vec>) { let (sender, receiver) = unbounded(); let task_count = Mutex::new(0); let wait_condvar = Condvar::new(); @@ -84,26 +88,44 @@ impl WorkerRegistry { sender: Arc::new(sender), receiver: Arc::new(receiver), thread_count: names.len(), + panicked: AtomicBool::new(false), task_count, wait_condvar, }); + let mut handles = Vec::new(); for name in names.iter() { let top_level_ctx = top_level_ctx.clone(); let registry = registry.clone(); + let registry2 = registry.clone(); let name = name.to_string(); let f = f.clone(); - thread::spawn(move || { + let handle = thread::spawn(move || { registry.worker_thread(name, top_level_ctx, f); }); + let handle = thread::spawn(move || { + if let Err(e) = handle.join() { + if let Some(e) = e.downcast_ref::<&'static str>() { + eprintln!("Got an error: {}", e); + } else { + eprintln!("Got an unknown error: {:?}", e); + } + registry2.panicked.store(true, Ordering::SeqCst); + registry2.wait_condvar.notify_all(); + } + }); + handles.push(handle); } - registry + (registry, handles) } - pub fn wait_tasks_complete(&self) { + pub fn wait_tasks_complete(&self, handles: Vec>) { { let mut count = self.task_count.lock(); while *count != 0 { + if self.panicked.load(Ordering::SeqCst) { + break; + } self.wait_condvar.wait(&mut count); } } @@ -113,9 +135,18 @@ impl WorkerRegistry { { let mut count = self.task_count.lock(); while *count != self.thread_count { + if self.panicked.load(Ordering::SeqCst) { + break; + } self.wait_condvar.wait(&mut count); } } + for handle in handles { + handle.join().unwrap(); + } + if self.panicked.load(Ordering::SeqCst) { + panic!("tasks panicked"); + } } pub fn add_task(&self, task: CodeGenTask) { @@ -137,8 +168,6 @@ impl WorkerRegistry { let result = gen_func(&context, builder, module, task, top_level_ctx.clone()); builder = result.0; module = result.1; - - println!("{}", *self.task_count.lock()); *self.task_count.lock() -= 1; self.wait_condvar.notify_all(); } diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 54982897..e56c7b45 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -237,7 +237,7 @@ fn test_primitives() { .trim(); assert_eq!(expected, module.print_to_string().to_str().unwrap().trim()); }))); - let registry = WorkerRegistry::create_workers(&threads, top_level, f); + let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f); registry.add_task(task); - registry.wait_tasks_complete(); + registry.wait_tasks_complete(handles); }