hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
2 changed files with 39 additions and 10 deletions
Showing only changes of commit d30918bea0 - Show all commits

View File

@ -20,7 +20,10 @@ use itertools::Itertools;
use parking_lot::{Condvar, Mutex}; use parking_lot::{Condvar, Mutex};
use rustpython_parser::ast::Stmt; use rustpython_parser::ast::Stmt;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use std::thread; use std::thread;
mod expr; mod expr;
@ -65,6 +68,7 @@ impl WithCall {
pub struct WorkerRegistry { pub struct WorkerRegistry {
sender: Arc<Sender<Option<CodeGenTask>>>, sender: Arc<Sender<Option<CodeGenTask>>>,
receiver: Arc<Receiver<Option<CodeGenTask>>>, receiver: Arc<Receiver<Option<CodeGenTask>>>,
panicked: AtomicBool,
task_count: Mutex<usize>, task_count: Mutex<usize>,
thread_count: usize, thread_count: usize,
wait_condvar: Condvar, wait_condvar: Condvar,
@ -75,7 +79,7 @@ impl WorkerRegistry {
names: &[&str], names: &[&str],
top_level_ctx: Arc<TopLevelContext>, top_level_ctx: Arc<TopLevelContext>,
f: Arc<WithCall>, f: Arc<WithCall>,
) -> Arc<WorkerRegistry> { ) -> (Arc<WorkerRegistry>, Vec<thread::JoinHandle<()>>) {
let (sender, receiver) = unbounded(); let (sender, receiver) = unbounded();
let task_count = Mutex::new(0); let task_count = Mutex::new(0);
let wait_condvar = Condvar::new(); let wait_condvar = Condvar::new();
@ -84,26 +88,44 @@ impl WorkerRegistry {
sender: Arc::new(sender), sender: Arc::new(sender),
receiver: Arc::new(receiver), receiver: Arc::new(receiver),
thread_count: names.len(), thread_count: names.len(),
panicked: AtomicBool::new(false),
task_count, task_count,
wait_condvar, wait_condvar,
}); });
let mut handles = Vec::new();
for name in names.iter() { for name in names.iter() {
let top_level_ctx = top_level_ctx.clone(); let top_level_ctx = top_level_ctx.clone();
let registry = registry.clone(); let registry = registry.clone();
let registry2 = registry.clone();
let name = name.to_string(); let name = name.to_string();
let f = f.clone(); let f = f.clone();
thread::spawn(move || { let handle = thread::spawn(move || {
registry.worker_thread(name, top_level_ctx, f); registry.worker_thread(name, top_level_ctx, f);
}); });
let handle = thread::spawn(move || {
if let Err(e) = handle.join() {
if let Some(e) = e.downcast_ref::<&'static str>() {
eprintln!("Got an error: {}", e);
} else {
eprintln!("Got an unknown error: {:?}", e);
} }
registry registry2.panicked.store(true, Ordering::SeqCst);
registry2.wait_condvar.notify_all();
}
});
handles.push(handle);
}
(registry, handles)
} }
pub fn wait_tasks_complete(&self) { pub fn wait_tasks_complete(&self, handles: Vec<thread::JoinHandle<()>>) {
{ {
let mut count = self.task_count.lock(); let mut count = self.task_count.lock();
while *count != 0 { while *count != 0 {
if self.panicked.load(Ordering::SeqCst) {
break;
}
self.wait_condvar.wait(&mut count); self.wait_condvar.wait(&mut count);
} }
} }
@ -113,9 +135,18 @@ impl WorkerRegistry {
{ {
let mut count = self.task_count.lock(); let mut count = self.task_count.lock();
while *count != self.thread_count { while *count != self.thread_count {
if self.panicked.load(Ordering::SeqCst) {
break;
}
self.wait_condvar.wait(&mut count); self.wait_condvar.wait(&mut count);
} }
} }
for handle in handles {
handle.join().unwrap();
}
if self.panicked.load(Ordering::SeqCst) {
panic!("tasks panicked");
}
} }
pub fn add_task(&self, task: CodeGenTask) { pub fn add_task(&self, task: CodeGenTask) {
@ -137,8 +168,6 @@ impl WorkerRegistry {
let result = gen_func(&context, builder, module, task, top_level_ctx.clone()); let result = gen_func(&context, builder, module, task, top_level_ctx.clone());
builder = result.0; builder = result.0;
module = result.1; module = result.1;
println!("{}", *self.task_count.lock());
*self.task_count.lock() -= 1; *self.task_count.lock() -= 1;
self.wait_condvar.notify_all(); self.wait_condvar.notify_all();
} }

View File

@ -237,7 +237,7 @@ fn test_primitives() {
.trim(); .trim();
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim()); assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
}))); })));
let registry = WorkerRegistry::create_workers(&threads, top_level, f); let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(); registry.wait_tasks_complete(handles);
} }