hm-inference #6
|
@ -20,7 +20,10 @@ use itertools::Itertools;
|
|||
use parking_lot::{Condvar, Mutex};
|
||||
use rustpython_parser::ast::Stmt;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use std::thread;
|
||||
|
||||
mod expr;
|
||||
|
@ -57,7 +60,7 @@ impl WithCall {
|
|||
WithCall { fp }
|
||||
}
|
||||
|
||||
pub fn run<'ctx>(&self, m: &Module<'ctx>) {
|
||||
pub fn run<'ctx>(&self, m: &Module<'ctx>) {
|
||||
(self.fp)(m)
|
||||
}
|
||||
}
|
||||
|
@ -65,6 +68,7 @@ impl WithCall {
|
|||
pub struct WorkerRegistry {
|
||||
sender: Arc<Sender<Option<CodeGenTask>>>,
|
||||
receiver: Arc<Receiver<Option<CodeGenTask>>>,
|
||||
panicked: AtomicBool,
|
||||
task_count: Mutex<usize>,
|
||||
thread_count: usize,
|
||||
wait_condvar: Condvar,
|
||||
|
@ -75,7 +79,7 @@ impl WorkerRegistry {
|
|||
names: &[&str],
|
||||
top_level_ctx: Arc<TopLevelContext>,
|
||||
f: Arc<WithCall>,
|
||||
) -> Arc<WorkerRegistry> {
|
||||
) -> (Arc<WorkerRegistry>, Vec<thread::JoinHandle<()>>) {
|
||||
let (sender, receiver) = unbounded();
|
||||
let task_count = Mutex::new(0);
|
||||
let wait_condvar = Condvar::new();
|
||||
|
@ -84,26 +88,44 @@ impl WorkerRegistry {
|
|||
sender: Arc::new(sender),
|
||||
receiver: Arc::new(receiver),
|
||||
thread_count: names.len(),
|
||||
panicked: AtomicBool::new(false),
|
||||
task_count,
|
||||
wait_condvar,
|
||||
});
|
||||
|
||||
let mut handles = Vec::new();
|
||||
for name in names.iter() {
|
||||
let top_level_ctx = top_level_ctx.clone();
|
||||
let registry = registry.clone();
|
||||
let registry2 = registry.clone();
|
||||
let name = name.to_string();
|
||||
let f = f.clone();
|
||||
thread::spawn(move || {
|
||||
let handle = thread::spawn(move || {
|
||||
registry.worker_thread(name, top_level_ctx, f);
|
||||
});
|
||||
let handle = thread::spawn(move || {
|
||||
if let Err(e) = handle.join() {
|
||||
if let Some(e) = e.downcast_ref::<&'static str>() {
|
||||
eprintln!("Got an error: {}", e);
|
||||
} else {
|
||||
eprintln!("Got an unknown error: {:?}", e);
|
||||
}
|
||||
registry2.panicked.store(true, Ordering::SeqCst);
|
||||
registry2.wait_condvar.notify_all();
|
||||
}
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
registry
|
||||
(registry, handles)
|
||||
}
|
||||
|
||||
pub fn wait_tasks_complete(&self) {
|
||||
pub fn wait_tasks_complete(&self, handles: Vec<thread::JoinHandle<()>>) {
|
||||
{
|
||||
let mut count = self.task_count.lock();
|
||||
while *count != 0 {
|
||||
if self.panicked.load(Ordering::SeqCst) {
|
||||
break;
|
||||
}
|
||||
self.wait_condvar.wait(&mut count);
|
||||
}
|
||||
}
|
||||
|
@ -113,9 +135,18 @@ impl WorkerRegistry {
|
|||
{
|
||||
let mut count = self.task_count.lock();
|
||||
while *count != self.thread_count {
|
||||
if self.panicked.load(Ordering::SeqCst) {
|
||||
break;
|
||||
}
|
||||
self.wait_condvar.wait(&mut count);
|
||||
}
|
||||
}
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
}
|
||||
if self.panicked.load(Ordering::SeqCst) {
|
||||
panic!("tasks panicked");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_task(&self, task: CodeGenTask) {
|
||||
|
@ -137,8 +168,6 @@ impl WorkerRegistry {
|
|||
let result = gen_func(&context, builder, module, task, top_level_ctx.clone());
|
||||
builder = result.0;
|
||||
module = result.1;
|
||||
|
||||
println!("{}", *self.task_count.lock());
|
||||
*self.task_count.lock() -= 1;
|
||||
self.wait_condvar.notify_all();
|
||||
}
|
||||
|
|
|
@ -237,7 +237,7 @@ fn test_primitives() {
|
|||
.trim();
|
||||
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
|
||||
})));
|
||||
let registry = WorkerRegistry::create_workers(&threads, top_level, f);
|
||||
let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f);
|
||||
registry.add_task(task);
|
||||
registry.wait_tasks_complete();
|
||||
registry.wait_tasks_complete(handles);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue