hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
2 changed files with 39 additions and 10 deletions
Showing only changes of commit d30918bea0 - Show all commits

View File

@ -20,7 +20,10 @@ use itertools::Itertools;
use parking_lot::{Condvar, Mutex};
use rustpython_parser::ast::Stmt;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use std::thread;
mod expr;
@ -65,6 +68,7 @@ impl WithCall {
pub struct WorkerRegistry {
sender: Arc<Sender<Option<CodeGenTask>>>,
receiver: Arc<Receiver<Option<CodeGenTask>>>,
panicked: AtomicBool,
task_count: Mutex<usize>,
thread_count: usize,
wait_condvar: Condvar,
@ -75,7 +79,7 @@ impl WorkerRegistry {
names: &[&str],
top_level_ctx: Arc<TopLevelContext>,
f: Arc<WithCall>,
) -> Arc<WorkerRegistry> {
) -> (Arc<WorkerRegistry>, Vec<thread::JoinHandle<()>>) {
let (sender, receiver) = unbounded();
let task_count = Mutex::new(0);
let wait_condvar = Condvar::new();
@ -84,26 +88,44 @@ impl WorkerRegistry {
sender: Arc::new(sender),
receiver: Arc::new(receiver),
thread_count: names.len(),
panicked: AtomicBool::new(false),
task_count,
wait_condvar,
});
let mut handles = Vec::new();
for name in names.iter() {
let top_level_ctx = top_level_ctx.clone();
let registry = registry.clone();
let registry2 = registry.clone();
let name = name.to_string();
let f = f.clone();
thread::spawn(move || {
let handle = thread::spawn(move || {
registry.worker_thread(name, top_level_ctx, f);
});
let handle = thread::spawn(move || {
if let Err(e) = handle.join() {
if let Some(e) = e.downcast_ref::<&'static str>() {
eprintln!("Got an error: {}", e);
} else {
eprintln!("Got an unknown error: {:?}", e);
}
registry
registry2.panicked.store(true, Ordering::SeqCst);
registry2.wait_condvar.notify_all();
}
});
handles.push(handle);
}
(registry, handles)
}
pub fn wait_tasks_complete(&self) {
pub fn wait_tasks_complete(&self, handles: Vec<thread::JoinHandle<()>>) {
{
let mut count = self.task_count.lock();
while *count != 0 {
if self.panicked.load(Ordering::SeqCst) {
break;
}
self.wait_condvar.wait(&mut count);
}
}
@ -113,9 +135,18 @@ impl WorkerRegistry {
{
let mut count = self.task_count.lock();
while *count != self.thread_count {
if self.panicked.load(Ordering::SeqCst) {
break;
}
self.wait_condvar.wait(&mut count);
}
}
for handle in handles {
handle.join().unwrap();
}
if self.panicked.load(Ordering::SeqCst) {
panic!("tasks panicked");
}
}
pub fn add_task(&self, task: CodeGenTask) {
@ -137,8 +168,6 @@ impl WorkerRegistry {
let result = gen_func(&context, builder, module, task, top_level_ctx.clone());
builder = result.0;
module = result.1;
println!("{}", *self.task_count.lock());
*self.task_count.lock() -= 1;
self.wait_condvar.notify_all();
}

View File

@ -237,7 +237,7 @@ fn test_primitives() {
.trim();
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
})));
let registry = WorkerRegistry::create_workers(&threads, top_level, f);
let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f);
registry.add_task(task);
registry.wait_tasks_complete();
registry.wait_tasks_complete(handles);
}