1
0
forked from M-Labs/nac3

nac3core: better impl of #24

This commit is contained in:
pca006132 2021-09-30 17:07:48 +08:00
parent 928b5bafb5
commit f0fdfe42cb
8 changed files with 75 additions and 92 deletions

View File

@ -15,7 +15,7 @@ use inkwell::{
use nac3core::typecheck::type_inferencer::PrimitiveStore; use nac3core::typecheck::type_inferencer::PrimitiveStore;
use nac3core::{ use nac3core::{
codegen::{CodeGenTask, WithCall, WorkerRegistry, GenCall}, codegen::{CodeGenTask, WithCall, WorkerRegistry},
symbol_resolver::SymbolResolver, symbol_resolver::SymbolResolver,
toplevel::{composer::TopLevelComposer, TopLevelContext, TopLevelDef}, toplevel::{composer::TopLevelComposer, TopLevelContext, TopLevelDef},
typecheck::typedef::{FunSignature, FuncArg}, typecheck::typedef::{FunSignature, FuncArg},
@ -216,10 +216,9 @@ impl Nac3 {
.write_to_file(module, FileType::Object, Path::new(&format!("{}.o", module.get_name().to_str().unwrap()))) .write_to_file(module, FileType::Object, Path::new(&format!("{}.o", module.get_name().to_str().unwrap())))
.expect("couldn't write module to file"); .expect("couldn't write module to file");
}))); })));
let external_codegen = Arc::new(GenCall::new(Box::new(|_, _, _, _| unimplemented!()), HashSet::new()));
let thread_names: Vec<String> = (0..4).map(|i| format!("module{}", i)).collect(); let thread_names: Vec<String> = (0..4).map(|i| format!("module{}", i)).collect();
let threads: Vec<_> = thread_names.iter().map(|s| s.as_str()).collect(); let threads: Vec<_> = thread_names.iter().map(|s| s.as_str()).collect();
let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level.clone(), f, external_codegen); let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level.clone(), f);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);

View File

@ -108,18 +108,17 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
fun: (&FunSignature, DefinitionId), fun: (&FunSignature, DefinitionId),
params: Vec<(Option<StrRef>, BasicValueEnum<'ctx>)>, params: Vec<(Option<StrRef>, BasicValueEnum<'ctx>)>,
) -> Option<BasicValueEnum<'ctx>> { ) -> Option<BasicValueEnum<'ctx>> {
if self.external_codegen.def_list.contains(&fun.1) {
let external_codegen = self.external_codegen.clone();
return external_codegen.run(self, obj, fun, params)
}
let key = self.get_subst_key(obj.map(|a| a.0), fun.0, None);
let definition = self.top_level.definitions.read().get(fun.1.0).cloned().unwrap(); let definition = self.top_level.definitions.read().get(fun.1.0).cloned().unwrap();
let mut task = None; let mut task = None;
let key = self.get_subst_key(obj.map(|a| a.0), fun.0, None);
let symbol = { let symbol = {
// make sure this lock guard is dropped at the end of this scope... // make sure this lock guard is dropped at the end of this scope...
let def = definition.read(); let def = definition.read();
match &*def { match &*def {
TopLevelDef::Function { instance_to_symbol, .. } => { TopLevelDef::Function { instance_to_symbol, codegen_callback, .. } => {
if let Some(callback) = codegen_callback {
return callback.run(self, obj, fun, params);
}
instance_to_symbol.get(&key).cloned() instance_to_symbol.get(&key).cloned()
} }
TopLevelDef::Class { methods, .. } => { TopLevelDef::Class { methods, .. } => {

View File

@ -1,6 +1,6 @@
use crate::{ use crate::{
symbol_resolver::SymbolResolver, symbol_resolver::SymbolResolver,
toplevel::{DefinitionId, TopLevelContext, TopLevelDef}, toplevel::{TopLevelContext, TopLevelDef},
typecheck::{ typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore}, type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FunSignature, SharedUnifier, Type, TypeEnum, Unifier}, typedef::{CallId, FunSignature, SharedUnifier, Type, TypeEnum, Unifier},
@ -13,13 +13,13 @@ use inkwell::{
context::Context, context::Context,
module::Module, module::Module,
types::{BasicType, BasicTypeEnum}, types::{BasicType, BasicTypeEnum},
values::{BasicValueEnum, PointerValue}, values::PointerValue,
AddressSpace, AddressSpace,
}; };
use itertools::Itertools; use itertools::Itertools;
use parking_lot::{Condvar, Mutex}; use parking_lot::{Condvar, Mutex};
use rustpython_parser::ast::{Stmt, StrRef}; use rustpython_parser::ast::{Stmt, StrRef};
use std::collections::{HashMap, HashSet}; use std::collections::HashMap;
use std::sync::{ use std::sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
Arc, Arc,
@ -49,7 +49,6 @@ pub struct CodeGenContext<'ctx, 'a> {
// where continue and break should go to respectively // where continue and break should go to respectively
// the first one is the test_bb, and the second one is bb after the loop // the first one is the test_bb, and the second one is bb after the loop
pub loop_bb: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>, pub loop_bb: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>,
pub external_codegen: Arc<GenCall>,
} }
type Fp = Box<dyn Fn(&Module) + Send + Sync>; type Fp = Box<dyn Fn(&Module) + Send + Sync>;
@ -68,42 +67,6 @@ impl WithCall {
} }
} }
type GenCallCallback = Box<
dyn for<'ctx, 'a> Fn(
&mut CodeGenContext<'ctx, 'a>,
Option<(Type, BasicValueEnum)>,
(&FunSignature, DefinitionId),
Vec<(Option<StrRef>, BasicValueEnum<'ctx>)>,
) -> Option<BasicValueEnum<'ctx>>
+ Send
+ Sync,
>;
pub struct GenCall {
def_list: HashSet<DefinitionId>,
fp: GenCallCallback,
}
impl GenCall {
pub fn new(fp: GenCallCallback, def_list: HashSet<DefinitionId>) -> GenCall {
GenCall { def_list, fp }
}
pub fn run<'ctx, 'a>(
&self,
ctx: &mut CodeGenContext<'ctx, 'a>,
obj: Option<(Type, BasicValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, BasicValueEnum<'ctx>)>,
) -> Option<BasicValueEnum<'ctx>> {
(self.fp)(ctx, obj, fun, args)
}
pub fn need_external_codegen(&self, id: DefinitionId) -> bool {
self.def_list.contains(&id)
}
}
pub struct WorkerRegistry { pub struct WorkerRegistry {
sender: Arc<Sender<Option<CodeGenTask>>>, sender: Arc<Sender<Option<CodeGenTask>>>,
receiver: Arc<Receiver<Option<CodeGenTask>>>, receiver: Arc<Receiver<Option<CodeGenTask>>>,
@ -118,7 +81,6 @@ impl WorkerRegistry {
names: &[&str], names: &[&str],
top_level_ctx: Arc<TopLevelContext>, top_level_ctx: Arc<TopLevelContext>,
f: Arc<WithCall>, f: Arc<WithCall>,
external_codegen: Arc<GenCall>,
) -> (Arc<WorkerRegistry>, Vec<thread::JoinHandle<()>>) { ) -> (Arc<WorkerRegistry>, Vec<thread::JoinHandle<()>>) {
let (sender, receiver) = unbounded(); let (sender, receiver) = unbounded();
let task_count = Mutex::new(0); let task_count = Mutex::new(0);
@ -140,9 +102,8 @@ impl WorkerRegistry {
let registry2 = registry.clone(); let registry2 = registry.clone();
let name = name.to_string(); let name = name.to_string();
let f = f.clone(); let f = f.clone();
let external_codegen = external_codegen.clone();
let handle = thread::spawn(move || { let handle = thread::spawn(move || {
registry.worker_thread(name, top_level_ctx, f, external_codegen); registry.worker_thread(name, top_level_ctx, f);
}); });
let handle = thread::spawn(move || { let handle = thread::spawn(move || {
if let Err(e) = handle.join() { if let Err(e) = handle.join() {
@ -200,22 +161,13 @@ impl WorkerRegistry {
module_name: String, module_name: String,
top_level_ctx: Arc<TopLevelContext>, top_level_ctx: Arc<TopLevelContext>,
f: Arc<WithCall>, f: Arc<WithCall>,
external_codegen: Arc<GenCall>,
) { ) {
let context = Context::create(); let context = Context::create();
let mut builder = context.create_builder(); let mut builder = context.create_builder();
let mut module = context.create_module(&module_name); let mut module = context.create_module(&module_name);
while let Some(task) = self.receiver.recv().unwrap() { while let Some(task) = self.receiver.recv().unwrap() {
let result = gen_func( let result = gen_func(&context, self, builder, module, task, top_level_ctx.clone());
&context,
self,
builder,
module,
task,
top_level_ctx.clone(),
external_codegen.clone(),
);
builder = result.0; builder = result.0;
module = result.1; module = result.1;
*self.task_count.lock() -= 1; *self.task_count.lock() -= 1;
@ -298,7 +250,6 @@ pub fn gen_func<'ctx>(
module: Module<'ctx>, module: Module<'ctx>,
task: CodeGenTask, task: CodeGenTask,
top_level_ctx: Arc<TopLevelContext>, top_level_ctx: Arc<TopLevelContext>,
external_codegen: Arc<GenCall>,
) -> (Builder<'ctx>, Module<'ctx>) { ) -> (Builder<'ctx>, Module<'ctx>) {
// unwrap_or(0) is for unit tests without using rayon // unwrap_or(0) is for unit tests without using rayon
let (mut unifier, primitives) = { let (mut unifier, primitives) = {
@ -396,7 +347,6 @@ pub fn gen_func<'ctx>(
builder, builder,
module, module,
unifier, unifier,
external_codegen,
}; };
let mut returned = false; let mut returned = false;

View File

@ -1,12 +1,21 @@
use crate::{codegen::{CodeGenTask, GenCall, WithCall, WorkerRegistry}, location::Location, symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ use crate::{
codegen::{CodeGenTask, WithCall, WorkerRegistry},
location::Location,
symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{
composer::TopLevelComposer, DefinitionId, FunInstance, TopLevelContext, TopLevelDef, composer::TopLevelComposer, DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
}, typecheck::{ },
typecheck::{
type_inferencer::{FunctionData, Inferencer, PrimitiveStore}, type_inferencer::{FunctionData, Inferencer, PrimitiveStore},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
}}; },
};
use indoc::indoc; use indoc::indoc;
use parking_lot::RwLock; use parking_lot::RwLock;
use rustpython_parser::{ast::{StrRef, fold::Fold}, parser::parse_program}; use rustpython_parser::{
ast::{fold::Fold, StrRef},
parser::parse_program,
};
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
@ -56,12 +65,6 @@ fn test_primitives() {
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
unifier.top_level = Some(top_level.clone()); unifier.top_level = Some(top_level.clone());
// let resolver = Arc::new(Mutex::new(Resolver {
// id_to_type: HashMap::new(),
// id_to_def: RwLock::new(HashMap::new()),
// class_names: Default::default(),
// }) as Mutex<dyn SymbolResolver + Send + Sync>);
let resolver = Arc::new(Box::new(Resolver { let resolver = Arc::new(Box::new(Resolver {
id_to_type: HashMap::new(), id_to_type: HashMap::new(),
id_to_def: RwLock::new(HashMap::new()), id_to_def: RwLock::new(HashMap::new()),
@ -109,7 +112,7 @@ fn test_primitives() {
let top_level = Arc::new(TopLevelContext { let top_level = Arc::new(TopLevelContext {
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))), definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
unifiers: Arc::new(RwLock::new(vec![(unifier.get_shared_unifier(), primitives)])), unifiers: Arc::new(RwLock::new(vec![(unifier.get_shared_unifier(), primitives)])),
personality_symbol: None personality_symbol: None,
}); });
let unifier = (unifier.get_shared_unifier(), primitives); let unifier = (unifier.get_shared_unifier(), primitives);
@ -182,8 +185,7 @@ fn test_primitives() {
.trim(); .trim();
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim()); assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
}))); })));
let external_codegen = Arc::new(GenCall::new(Box::new(|_, _, _, _| unimplemented!()), HashSet::new())); let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f);
let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f, external_codegen);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);
} }
@ -223,6 +225,7 @@ fn test_simple_call() {
instance_to_stmt: HashMap::new(), instance_to_stmt: HashMap::new(),
instance_to_symbol: HashMap::new(), instance_to_symbol: HashMap::new(),
resolver: None, resolver: None,
codegen_callback: None,
}))); })));
let resolver = Box::new(Resolver { let resolver = Box::new(Resolver {
@ -298,7 +301,7 @@ fn test_simple_call() {
let top_level = Arc::new(TopLevelContext { let top_level = Arc::new(TopLevelContext {
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))), definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
unifiers: Arc::new(RwLock::new(vec![(unifier.get_shared_unifier(), primitives)])), unifiers: Arc::new(RwLock::new(vec![(unifier.get_shared_unifier(), primitives)])),
personality_symbol: None personality_symbol: None,
}); });
let unifier = (unifier.get_shared_unifier(), primitives); let unifier = (unifier.get_shared_unifier(), primitives);
@ -347,8 +350,7 @@ fn test_simple_call() {
.trim(); .trim();
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim()); assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
}))); })));
let external_codegen = Arc::new(GenCall::new(Box::new(|_, _, _, _| unimplemented!()), HashSet::new())); let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f);
let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f, external_codegen);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);
} }

View File

@ -100,6 +100,7 @@ impl TopLevelComposer {
instance_to_symbol: [("".into(), name.into())].iter().cloned().collect(), instance_to_symbol: [("".into(), name.into())].iter().cloned().collect(),
var_id: Default::default(), var_id: Default::default(),
resolver: None, resolver: None,
codegen_callback: None,
})), })),
None, None,
)); ));

View File

@ -122,6 +122,7 @@ impl TopLevelComposer {
instance_to_symbol: Default::default(), instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(), instance_to_stmt: Default::default(),
resolver, resolver,
codegen_callback: None,
} }
} }

View File

@ -7,6 +7,7 @@ use std::{
sync::Arc, sync::Arc,
}; };
use super::codegen::CodeGenContext;
use super::typecheck::type_inferencer::PrimitiveStore; use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier}; use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier};
use crate::{ use crate::{
@ -16,6 +17,7 @@ use crate::{
use itertools::{izip, Itertools}; use itertools::{izip, Itertools};
use parking_lot::RwLock; use parking_lot::RwLock;
use rustpython_parser::ast::{self, Stmt, StrRef}; use rustpython_parser::ast::{self, Stmt, StrRef};
use inkwell::values::BasicValueEnum;
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)] #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)]
pub struct DefinitionId(pub usize); pub struct DefinitionId(pub usize);
@ -28,6 +30,43 @@ use type_annotation::*;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
type GenCallCallback = Box<
dyn for<'ctx, 'a> Fn(
&mut CodeGenContext<'ctx, 'a>,
Option<(Type, BasicValueEnum)>,
(&FunSignature, DefinitionId),
Vec<(Option<StrRef>, BasicValueEnum<'ctx>)>,
) -> Option<BasicValueEnum<'ctx>>
+ Send
+ Sync,
>;
pub struct GenCall {
fp: GenCallCallback,
}
impl GenCall {
pub fn new(fp: GenCallCallback) -> GenCall {
GenCall { fp }
}
pub fn run<'ctx, 'a>(
&self,
ctx: &mut CodeGenContext<'ctx, 'a>,
obj: Option<(Type, BasicValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, BasicValueEnum<'ctx>)>,
) -> Option<BasicValueEnum<'ctx>> {
(self.fp)(ctx, obj, fun, args)
}
}
impl Debug for GenCall {
fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Ok(())
}
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct FunInstance { pub struct FunInstance {
pub body: Arc<Vec<Stmt<Option<Type>>>>, pub body: Arc<Vec<Stmt<Option<Type>>>>,
@ -78,6 +117,8 @@ pub enum TopLevelDef {
instance_to_stmt: HashMap<String, FunInstance>, instance_to_stmt: HashMap<String, FunInstance>,
// symbol resolver of the module defined the class // symbol resolver of the module defined the class
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>, resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
// custom codegen callback
codegen_callback: Option<Arc<GenCall>>
}, },
} }

View File

@ -6,7 +6,7 @@ use rustpython_parser::parser;
use std::{collections::HashMap, path::Path, sync::Arc, time::SystemTime}; use std::{collections::HashMap, path::Path, sync::Arc, time::SystemTime};
use nac3core::{ use nac3core::{
codegen::{CodeGenTask, WithCall, GenCall, WorkerRegistry}, codegen::{CodeGenTask, WithCall, WorkerRegistry},
symbol_resolver::SymbolResolver, symbol_resolver::SymbolResolver,
toplevel::{composer::TopLevelComposer, TopLevelDef}, toplevel::{composer::TopLevelComposer, TopLevelDef},
typecheck::typedef::{FunSignature, FuncArg}, typecheck::typedef::{FunSignature, FuncArg},
@ -144,19 +144,9 @@ fn main() {
// println!("IR:\n{}", module.print_to_string().to_str().unwrap()); // println!("IR:\n{}", module.print_to_string().to_str().unwrap());
}))); })));
let external_codegen = Arc::new(GenCall::new(Box::new(
// example implementation that does sitofp:
// note that a proper implementation may want to check the definition ID
// |ctx, _, _, args| {
// let arg = args[0].1.into_int_value();
// let val = ctx.builder.build_signed_int_to_float(arg, ctx.ctx.f64_type(), "sitofp").into();
// Some(val)
// }
|_, _, _, _| unimplemented!()
), Default::default()));
let threads: Vec<String> = (0..threads).map(|i| format!("module{}", i)).collect(); let threads: Vec<String> = (0..threads).map(|i| format!("module{}", i)).collect();
let threads: Vec<_> = threads.iter().map(|s| s.as_str()).collect(); let threads: Vec<_> = threads.iter().map(|s| s.as_str()).collect();
let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f, external_codegen); let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);