forked from M-Labs/nac3
1
0
Fork 0

threadpool for parallel code generation

This commit is contained in:
pca006132 2021-08-13 14:48:46 +08:00
parent cb01c79603
commit e2adf82229
4 changed files with 184 additions and 80 deletions

View File

@ -6,6 +6,7 @@ use crate::{
typedef::{FunSignature, Type, TypeEnum, Unifier}, typedef::{FunSignature, Type, TypeEnum, Unifier},
}, },
}; };
use crossbeam::channel::{unbounded, Receiver, Sender};
use inkwell::{ use inkwell::{
basic_block::BasicBlock, basic_block::BasicBlock,
builder::Builder, builder::Builder,
@ -16,9 +17,11 @@ use inkwell::{
AddressSpace, AddressSpace,
}; };
use itertools::Itertools; use itertools::Itertools;
use parking_lot::{Condvar, Mutex};
use rustpython_parser::ast::Stmt; use rustpython_parser::ast::Stmt;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::thread;
mod expr; mod expr;
mod stmt; mod stmt;
@ -43,6 +46,112 @@ pub struct CodeGenContext<'ctx, 'a> {
pub loop_bb: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>, pub loop_bb: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>,
} }
type Fp = Box<dyn Fn(&Module) + Send + Sync>;
pub struct WithCall {
fp: Fp,
}
impl WithCall {
pub fn new(fp: Fp) -> WithCall {
WithCall { fp }
}
pub fn run<'ctx>(&self, m: &Module<'ctx>) {
(self.fp)(m)
}
}
pub struct WorkerRegistry {
sender: Arc<Sender<Option<CodeGenTask>>>,
receiver: Arc<Receiver<Option<CodeGenTask>>>,
task_count: Mutex<usize>,
thread_count: usize,
wait_condvar: Condvar,
}
impl WorkerRegistry {
pub fn create_workers(
names: &[&str],
top_level_ctx: Arc<TopLevelContext>,
f: Arc<WithCall>,
) -> Arc<WorkerRegistry> {
let (sender, receiver) = unbounded();
let task_count = Mutex::new(0);
let wait_condvar = Condvar::new();
let registry = Arc::new(WorkerRegistry {
sender: Arc::new(sender),
receiver: Arc::new(receiver),
thread_count: names.len(),
task_count,
wait_condvar,
});
for name in names.iter() {
let top_level_ctx = top_level_ctx.clone();
let registry = registry.clone();
let name = name.to_string();
let f = f.clone();
thread::spawn(move || {
registry.worker_thread(name, top_level_ctx, f);
});
}
registry
}
pub fn wait_tasks_complete(&self) {
{
let mut count = self.task_count.lock();
while *count != 0 {
self.wait_condvar.wait(&mut count);
}
}
for _ in 0..self.thread_count {
self.sender.send(None).unwrap();
}
{
let mut count = self.task_count.lock();
while *count != self.thread_count {
self.wait_condvar.wait(&mut count);
}
}
}
pub fn add_task(&self, task: CodeGenTask) {
*self.task_count.lock() += 1;
self.sender.send(Some(task)).unwrap();
}
fn worker_thread(
&self,
module_name: String,
top_level_ctx: Arc<TopLevelContext>,
f: Arc<WithCall>,
) {
let context = Context::create();
let mut builder = context.create_builder();
let mut module = context.create_module(&module_name);
while let Some(task) = self.receiver.recv().unwrap() {
let result = gen_func(&context, builder, module, task, top_level_ctx.clone());
builder = result.0;
module = result.1;
println!("{}", *self.task_count.lock());
*self.task_count.lock() -= 1;
self.wait_condvar.notify_all();
}
// do whatever...
let mut lock = self.task_count.lock();
module.verify().unwrap();
f.run(&module);
*lock += 1;
self.wait_condvar.notify_all();
}
}
pub struct CodeGenTask { pub struct CodeGenTask {
pub subst: Vec<(Type, Type)>, pub subst: Vec<(Type, Type)>,
pub symbol_name: String, pub symbol_name: String,

View File

@ -1,5 +1,6 @@
use super::{gen_func, CodeGenTask}; use super::{CodeGenTask, WorkerRegistry};
use crate::{ use crate::{
codegen::WithCall,
location::Location, location::Location,
symbol_resolver::{SymbolResolver, SymbolValue}, symbol_resolver::{SymbolResolver, SymbolValue},
top_level::{DefinitionId, TopLevelContext}, top_level::{DefinitionId, TopLevelContext},
@ -10,7 +11,6 @@ use crate::{
}, },
}; };
use indoc::indoc; use indoc::indoc;
use inkwell::context::Context;
use parking_lot::RwLock; use parking_lot::RwLock;
use rustpython_parser::{ast::fold::Fold, parser::parse_program}; use rustpython_parser::{ast::fold::Fold, parser::parse_program};
use std::collections::HashMap; use std::collections::HashMap;
@ -109,7 +109,7 @@ impl TestEnvironment {
top_level: TopLevelContext { top_level: TopLevelContext {
definitions: Default::default(), definitions: Default::default(),
unifiers: Default::default(), unifiers: Default::default(),
conetexts: Default::default(), // conetexts: Default::default(),
}, },
function_data: FunctionData { function_data: FunctionData {
resolver, resolver,
@ -140,10 +140,7 @@ impl TestEnvironment {
#[test] #[test]
fn test_primitives() { fn test_primitives() {
let mut env = TestEnvironment::basic_test_env(); let mut env = TestEnvironment::basic_test_env();
let context = Context::create(); let threads = ["test"];
let module = context.create_module("test");
let builder = context.create_builder();
let signature = FunSignature { let signature = FunSignature {
args: vec![ args: vec![
FuncArg { name: "a".to_string(), ty: env.primitives.int32, default_value: None }, FuncArg { name: "a".to_string(), ty: env.primitives.int32, default_value: None },
@ -170,9 +167,8 @@ fn test_primitives() {
let top_level = Arc::new(TopLevelContext { let top_level = Arc::new(TopLevelContext {
definitions: Default::default(), definitions: Default::default(),
unifiers: Arc::new(RwLock::new(vec![(env.unifier.get_shared_unifier(), env.primitives)])), unifiers: Arc::new(RwLock::new(vec![(env.unifier.get_shared_unifier(), env.primitives)])),
conetexts: Default::default(), // conetexts: Default::default(),
}); });
let task = CodeGenTask { let task = CodeGenTask {
subst: Default::default(), subst: Default::default(),
symbol_name: "testing".to_string(), symbol_name: "testing".to_string(),
@ -182,65 +178,66 @@ fn test_primitives() {
signature, signature,
}; };
let module = gen_func(&context, builder, module, task, top_level); let f = Arc::new(WithCall::new(Box::new(|module| {
// the following IR is equivalent to // the following IR is equivalent to
// ``` // ```
// ; ModuleID = 'test.ll' // ; ModuleID = 'test.ll'
// source_filename = "test" // source_filename = "test"
// //
// ; Function Attrs: norecurse nounwind readnone // ; Function Attrs: norecurse nounwind readnone
// define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 { // define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 {
// init: // init:
// %add = add i32 %1, %0 // %add = add i32 %1, %0
// %cmp = icmp eq i32 %add, 1 // %cmp = icmp eq i32 %add, 1
// %ifexpr = select i1 %cmp, i32 %0, i32 0 // %ifexpr = select i1 %cmp, i32 %0, i32 0
// ret i32 %ifexpr // ret i32 %ifexpr
// } // }
// //
// attributes #0 = { norecurse nounwind readnone } // attributes #0 = { norecurse nounwind readnone }
// ``` // ```
// after O2 optimization // after O2 optimization
let expected = indoc! {" let expected = indoc! {"
; ModuleID = 'test' ; ModuleID = 'test'
source_filename = \"test\" source_filename = \"test\"
define i32 @testing(i32 %0, i32 %1) { define i32 @testing(i32 %0, i32 %1) {
init: init:
%a = alloca i32 %a = alloca i32
store i32 %0, i32* %a store i32 %0, i32* %a
%b = alloca i32 %b = alloca i32
store i32 %1, i32* %b store i32 %1, i32* %b
%tmp = alloca i32 %tmp = alloca i32
%tmp4 = alloca i32 %tmp4 = alloca i32
br label %body br label %body
body: ; preds = %init body: ; preds = %init
%load = load i32, i32* %a %load = load i32, i32* %a
%load1 = load i32, i32* %b %load1 = load i32, i32* %b
%add = add i32 %load, %load1 %add = add i32 %load, %load1
store i32 %add, i32* %tmp store i32 %add, i32* %tmp
%load2 = load i32, i32* %tmp %load2 = load i32, i32* %tmp
%cmp = icmp eq i32 %load2, 1 %cmp = icmp eq i32 %load2, 1
br i1 %cmp, label %then, label %else br i1 %cmp, label %then, label %else
then: ; preds = %body then: ; preds = %body
%load3 = load i32, i32* %a %load3 = load i32, i32* %a
br label %cont br label %cont
else: ; preds = %body else: ; preds = %body
br label %cont br label %cont
cont: ; preds = %else, %then cont: ; preds = %else, %then
%ifexpr = phi i32 [ %load3, %then ], [ 0, %else ] %ifexpr = phi i32 [ %load3, %then ], [ 0, %else ]
store i32 %ifexpr, i32* %tmp4 store i32 %ifexpr, i32* %tmp4
%load5 = load i32, i32* %tmp4 %load5 = load i32, i32* %tmp4
ret i32 %load5 ret i32 %load5
} }
"} "}
.trim(); .trim();
let ir = module.1.print_to_string().to_string(); assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
println!("src:\n{}", source); })));
println!("IR:\n{}", ir); let registry = WorkerRegistry::create_workers(&threads, top_level, f);
assert_eq!(expected, ir.trim()); registry.add_task(task);
registry.wait_tasks_complete();
} }

View File

@ -4,7 +4,6 @@ use std::{collections::HashMap, sync::Arc};
use super::typecheck::type_inferencer::PrimitiveStore; use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{SharedUnifier, Type, TypeEnum, Unifier}; use super::typecheck::typedef::{SharedUnifier, Type, TypeEnum, Unifier};
use crate::symbol_resolver::SymbolResolver; use crate::symbol_resolver::SymbolResolver;
use inkwell::context::Context;
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use rustpython_parser::ast::{self, Stmt}; use rustpython_parser::ast::{self, Stmt};
@ -54,17 +53,16 @@ pub enum TopLevelDef {
pub struct TopLevelContext { pub struct TopLevelContext {
pub definitions: Arc<RwLock<Vec<RwLock<TopLevelDef>>>>, pub definitions: Arc<RwLock<Vec<RwLock<TopLevelDef>>>>,
pub unifiers: Arc<RwLock<Vec<(SharedUnifier, PrimitiveStore)>>>, pub unifiers: Arc<RwLock<Vec<(SharedUnifier, PrimitiveStore)>>>,
pub conetexts: Arc<RwLock<Vec<Mutex<Context>>>>,
} }
// like adding some info on top of the TopLevelDef for // like adding some info on top of the TopLevelDef for
// later parsing the class bases, method, and function sigatures // later parsing the class bases, method, and function sigatures
pub struct TopLevelDefInfo { pub struct TopLevelDefInfo {
// the definition entry // the definition entry
def: TopLevelDef, def: TopLevelDef,
// the entry in the top_level unifier // the entry in the top_level unifier
ty: Type, ty: Type,
// the ast submitted by applications, primitives and // the ast submitted by applications, primitives and
// class methods will have None value here // class methods will have None value here
ast: Option<ast::Stmt<()>>, ast: Option<ast::Stmt<()>>,
} }
@ -118,7 +116,7 @@ impl TopLevelComposer {
(primitives, unifier) (primitives, unifier)
} }
/// return a composer and things to make a "primitive" symbol resolver, so that the symbol /// return a composer and things to make a "primitive" symbol resolver, so that the symbol
/// resolver can later figure out primitive type definitions when passed a primitive type name /// resolver can later figure out primitive type definitions when passed a primitive type name
pub fn new() -> (Vec<(String, DefinitionId, Type)>, Self) { pub fn new() -> (Vec<(String, DefinitionId, Type)>, Self) {
let primitives = Self::make_primitives(); let primitives = Self::make_primitives();
@ -150,7 +148,7 @@ impl TopLevelComposer {
ty: primitives.0.none, ty: primitives.0.none,
}, },
]; ];
let composer = TopLevelComposer { let composer = TopLevelComposer {
definition_list: definition_list.into(), definition_list: definition_list.into(),
primitives: primitives.0, primitives: primitives.0,
unifier: primitives.1, unifier: primitives.1,
@ -219,7 +217,7 @@ impl TopLevelComposer {
ast: None, ast: None,
ty, ty,
}); });
// parse class def body and register class methods into the def list // parse class def body and register class methods into the def list
// module's symbol resolver would not know the name of the class methods, // module's symbol resolver would not know the name of the class methods,
// thus cannot return their definition_id? so we have to manage it ourselves // thus cannot return their definition_id? so we have to manage it ourselves
@ -228,7 +226,7 @@ impl TopLevelComposer {
if let ast::StmtKind::FunctionDef { name, .. } = &b.node { if let ast::StmtKind::FunctionDef { name, .. } = &b.node {
let fun_name = Self::name_mangling(class_name.clone(), name); let fun_name = Self::name_mangling(class_name.clone(), name);
let def_id = def_list.len(); let def_id = def_list.len();
// add to unifier // add to unifier
let ty = self.unifier.add_ty(TypeEnum::TFunc( let ty = self.unifier.add_ty(TypeEnum::TFunc(
crate::typecheck::typedef::FunSignature { crate::typecheck::typedef::FunSignature {
@ -266,21 +264,21 @@ impl TopLevelComposer {
// move the ast to the entry of the class in the def_list // move the ast to the entry of the class in the def_list
def_list.get_mut(class_def_id).unwrap().ast = Some(ast); def_list.get_mut(class_def_id).unwrap().ast = Some(ast);
// return // return
Ok((class_name, DefinitionId(class_def_id), ty)) Ok((class_name, DefinitionId(class_def_id), ty))
}, },
ast::StmtKind::FunctionDef { name, .. } => { ast::StmtKind::FunctionDef { name, .. } => {
let fun_name = name.to_string(); let fun_name = name.to_string();
// add to the unifier // add to the unifier
let ty = self.unifier.add_ty(TypeEnum::TFunc(crate::typecheck::typedef::FunSignature { let ty = self.unifier.add_ty(TypeEnum::TFunc(crate::typecheck::typedef::FunSignature {
args: Default::default(), args: Default::default(),
ret: self.primitives.none, ret: self.primitives.none,
vars: Default::default(), vars: Default::default(),
})); }));
// add to the definition list // add to the definition list
let mut def_list = self.definition_list.write(); let mut def_list = self.definition_list.write();
def_list.push(TopLevelDefInfo { def_list.push(TopLevelDefInfo {
@ -333,7 +331,7 @@ impl TopLevelComposer {
let (params, let (params,
fields fields
) = if let TypeEnum::TObj { ) = if let TypeEnum::TObj {
// FIXME: this params is immutable, and what // FIXME: this params is immutable, and what
// should the key be, get the original typevar's var_id? // should the key be, get the original typevar's var_id?
params, params,
fields, fields,
@ -346,7 +344,7 @@ impl TopLevelComposer {
// into the `bases` ast node // into the `bases` ast node
for b in bases { for b in bases {
match &b.node { match &b.node {
// typevars bounded to the class, only support things like `class A(Generic[T, V])`, // typevars bounded to the class, only support things like `class A(Generic[T, V])`,
// things like `class A(Generic[T, V, ImportedModule.T])` is not supported // things like `class A(Generic[T, V, ImportedModule.T])` is not supported
// i.e. only simple names are allowed in the subscript // i.e. only simple names are allowed in the subscript
// should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params // should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params
@ -401,7 +399,7 @@ impl TopLevelComposer {
ast::ExprKind::Subscript {value, slice, ..} => { ast::ExprKind::Subscript {value, slice, ..} => {
unimplemented!() unimplemented!()
}, */ }, */
// base class is possible in other cases, we parse for thr base class // base class is possible in other cases, we parse for thr base class
_ => return Err("not supported".into()) _ => return Err("not supported".into())
} }

View File

@ -100,7 +100,7 @@ impl TestEnvironment {
top_level: TopLevelContext { top_level: TopLevelContext {
definitions: Default::default(), definitions: Default::default(),
unifiers: Default::default(), unifiers: Default::default(),
conetexts: Default::default(), // conetexts: Default::default(),
}, },
unifier, unifier,
function_data: FunctionData { function_data: FunctionData {
@ -259,7 +259,7 @@ impl TestEnvironment {
let top_level = TopLevelContext { let top_level = TopLevelContext {
definitions: Arc::new(RwLock::new(top_level_defs)), definitions: Arc::new(RwLock::new(top_level_defs)),
unifiers: Default::default(), unifiers: Default::default(),
conetexts: Default::default(), // conetexts: Default::default(),
}; };
let resolver = Arc::new(Resolver { let resolver = Arc::new(Resolver {