hm-inference #6
|
@ -20,7 +20,10 @@ use itertools::Itertools;
|
||||||
use parking_lot::{Condvar, Mutex};
|
use parking_lot::{Condvar, Mutex};
|
||||||
use rustpython_parser::ast::Stmt;
|
use rustpython_parser::ast::Stmt;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::{
|
||||||
|
atomic::{AtomicBool, Ordering},
|
||||||
|
Arc,
|
||||||
|
};
|
||||||
use std::thread;
|
use std::thread;
|
||||||
|
|
||||||
mod expr;
|
mod expr;
|
||||||
|
@ -65,6 +68,7 @@ impl WithCall {
|
||||||
pub struct WorkerRegistry {
|
pub struct WorkerRegistry {
|
||||||
sender: Arc<Sender<Option<CodeGenTask>>>,
|
sender: Arc<Sender<Option<CodeGenTask>>>,
|
||||||
receiver: Arc<Receiver<Option<CodeGenTask>>>,
|
receiver: Arc<Receiver<Option<CodeGenTask>>>,
|
||||||
|
panicked: AtomicBool,
|
||||||
task_count: Mutex<usize>,
|
task_count: Mutex<usize>,
|
||||||
thread_count: usize,
|
thread_count: usize,
|
||||||
wait_condvar: Condvar,
|
wait_condvar: Condvar,
|
||||||
|
@ -75,7 +79,7 @@ impl WorkerRegistry {
|
||||||
names: &[&str],
|
names: &[&str],
|
||||||
top_level_ctx: Arc<TopLevelContext>,
|
top_level_ctx: Arc<TopLevelContext>,
|
||||||
f: Arc<WithCall>,
|
f: Arc<WithCall>,
|
||||||
) -> Arc<WorkerRegistry> {
|
) -> (Arc<WorkerRegistry>, Vec<thread::JoinHandle<()>>) {
|
||||||
let (sender, receiver) = unbounded();
|
let (sender, receiver) = unbounded();
|
||||||
let task_count = Mutex::new(0);
|
let task_count = Mutex::new(0);
|
||||||
let wait_condvar = Condvar::new();
|
let wait_condvar = Condvar::new();
|
||||||
|
@ -84,26 +88,44 @@ impl WorkerRegistry {
|
||||||
sender: Arc::new(sender),
|
sender: Arc::new(sender),
|
||||||
receiver: Arc::new(receiver),
|
receiver: Arc::new(receiver),
|
||||||
thread_count: names.len(),
|
thread_count: names.len(),
|
||||||
|
panicked: AtomicBool::new(false),
|
||||||
task_count,
|
task_count,
|
||||||
wait_condvar,
|
wait_condvar,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let mut handles = Vec::new();
|
||||||
for name in names.iter() {
|
for name in names.iter() {
|
||||||
let top_level_ctx = top_level_ctx.clone();
|
let top_level_ctx = top_level_ctx.clone();
|
||||||
let registry = registry.clone();
|
let registry = registry.clone();
|
||||||
|
let registry2 = registry.clone();
|
||||||
let name = name.to_string();
|
let name = name.to_string();
|
||||||
let f = f.clone();
|
let f = f.clone();
|
||||||
thread::spawn(move || {
|
let handle = thread::spawn(move || {
|
||||||
registry.worker_thread(name, top_level_ctx, f);
|
registry.worker_thread(name, top_level_ctx, f);
|
||||||
});
|
});
|
||||||
|
let handle = thread::spawn(move || {
|
||||||
|
if let Err(e) = handle.join() {
|
||||||
|
if let Some(e) = e.downcast_ref::<&'static str>() {
|
||||||
|
eprintln!("Got an error: {}", e);
|
||||||
|
} else {
|
||||||
|
eprintln!("Got an unknown error: {:?}", e);
|
||||||
}
|
}
|
||||||
registry
|
registry2.panicked.store(true, Ordering::SeqCst);
|
||||||
|
registry2.wait_condvar.notify_all();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
handles.push(handle);
|
||||||
|
}
|
||||||
|
(registry, handles)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn wait_tasks_complete(&self) {
|
pub fn wait_tasks_complete(&self, handles: Vec<thread::JoinHandle<()>>) {
|
||||||
{
|
{
|
||||||
let mut count = self.task_count.lock();
|
let mut count = self.task_count.lock();
|
||||||
while *count != 0 {
|
while *count != 0 {
|
||||||
|
if self.panicked.load(Ordering::SeqCst) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
self.wait_condvar.wait(&mut count);
|
self.wait_condvar.wait(&mut count);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -113,9 +135,18 @@ impl WorkerRegistry {
|
||||||
{
|
{
|
||||||
let mut count = self.task_count.lock();
|
let mut count = self.task_count.lock();
|
||||||
while *count != self.thread_count {
|
while *count != self.thread_count {
|
||||||
|
if self.panicked.load(Ordering::SeqCst) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
self.wait_condvar.wait(&mut count);
|
self.wait_condvar.wait(&mut count);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for handle in handles {
|
||||||
|
handle.join().unwrap();
|
||||||
|
}
|
||||||
|
if self.panicked.load(Ordering::SeqCst) {
|
||||||
|
panic!("tasks panicked");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_task(&self, task: CodeGenTask) {
|
pub fn add_task(&self, task: CodeGenTask) {
|
||||||
|
@ -137,8 +168,6 @@ impl WorkerRegistry {
|
||||||
let result = gen_func(&context, builder, module, task, top_level_ctx.clone());
|
let result = gen_func(&context, builder, module, task, top_level_ctx.clone());
|
||||||
builder = result.0;
|
builder = result.0;
|
||||||
module = result.1;
|
module = result.1;
|
||||||
|
|
||||||
println!("{}", *self.task_count.lock());
|
|
||||||
*self.task_count.lock() -= 1;
|
*self.task_count.lock() -= 1;
|
||||||
self.wait_condvar.notify_all();
|
self.wait_condvar.notify_all();
|
||||||
}
|
}
|
||||||
|
|
|
@ -237,7 +237,7 @@ fn test_primitives() {
|
||||||
.trim();
|
.trim();
|
||||||
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
|
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
|
||||||
})));
|
})));
|
||||||
let registry = WorkerRegistry::create_workers(&threads, top_level, f);
|
let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f);
|
||||||
registry.add_task(task);
|
registry.add_task(task);
|
||||||
registry.wait_tasks_complete();
|
registry.wait_tasks_complete(handles);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue