diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index dbdb9dd51..1e0d4b374 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -15,7 +15,7 @@ use inkwell::{ use nac3core::typecheck::type_inferencer::PrimitiveStore; use nac3core::{ - codegen::{CodeGenTask, WithCall, WorkerRegistry, GenCall}, + codegen::{CodeGenTask, WithCall, WorkerRegistry}, symbol_resolver::SymbolResolver, toplevel::{composer::TopLevelComposer, TopLevelContext, TopLevelDef}, typecheck::typedef::{FunSignature, FuncArg}, @@ -216,10 +216,9 @@ impl Nac3 { .write_to_file(module, FileType::Object, Path::new(&format!("{}.o", module.get_name().to_str().unwrap()))) .expect("couldn't write module to file"); }))); - let external_codegen = Arc::new(GenCall::new(Box::new(|_, _, _, _| unimplemented!()), HashSet::new())); let thread_names: Vec = (0..4).map(|i| format!("module{}", i)).collect(); let threads: Vec<_> = thread_names.iter().map(|s| s.as_str()).collect(); - let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level.clone(), f, external_codegen); + let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level.clone(), f); registry.add_task(task); registry.wait_tasks_complete(handles); diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 0ac2f856e..34407bc01 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -108,18 +108,17 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { fun: (&FunSignature, DefinitionId), params: Vec<(Option, BasicValueEnum<'ctx>)>, ) -> Option> { - if self.external_codegen.def_list.contains(&fun.1) { - let external_codegen = self.external_codegen.clone(); - return external_codegen.run(self, obj, fun, params) - } - let key = self.get_subst_key(obj.map(|a| a.0), fun.0, None); let definition = self.top_level.definitions.read().get(fun.1.0).cloned().unwrap(); let mut task = None; + let key = self.get_subst_key(obj.map(|a| a.0), fun.0, None); let symbol = { // make sure this lock guard is dropped at the end of this scope... let def = definition.read(); match &*def { - TopLevelDef::Function { instance_to_symbol, .. } => { + TopLevelDef::Function { instance_to_symbol, codegen_callback, .. } => { + if let Some(callback) = codegen_callback { + return callback.run(self, obj, fun, params); + } instance_to_symbol.get(&key).cloned() } TopLevelDef::Class { methods, .. } => { diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 2befa6c9f..e01c130ba 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1,6 +1,6 @@ use crate::{ symbol_resolver::SymbolResolver, - toplevel::{DefinitionId, TopLevelContext, TopLevelDef}, + toplevel::{TopLevelContext, TopLevelDef}, typecheck::{ type_inferencer::{CodeLocation, PrimitiveStore}, typedef::{CallId, FunSignature, SharedUnifier, Type, TypeEnum, Unifier}, @@ -13,13 +13,13 @@ use inkwell::{ context::Context, module::Module, types::{BasicType, BasicTypeEnum}, - values::{BasicValueEnum, PointerValue}, + values::PointerValue, AddressSpace, }; use itertools::Itertools; use parking_lot::{Condvar, Mutex}; use rustpython_parser::ast::{Stmt, StrRef}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -49,7 +49,6 @@ pub struct CodeGenContext<'ctx, 'a> { // where continue and break should go to respectively // the first one is the test_bb, and the second one is bb after the loop pub loop_bb: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>, - pub external_codegen: Arc, } type Fp = Box; @@ -68,42 +67,6 @@ impl WithCall { } } -type GenCallCallback = Box< - dyn for<'ctx, 'a> Fn( - &mut CodeGenContext<'ctx, 'a>, - Option<(Type, BasicValueEnum)>, - (&FunSignature, DefinitionId), - Vec<(Option, BasicValueEnum<'ctx>)>, - ) -> Option> - + Send - + Sync, ->; - -pub struct GenCall { - def_list: HashSet, - fp: GenCallCallback, -} - -impl GenCall { - pub fn new(fp: GenCallCallback, def_list: HashSet) -> GenCall { - GenCall { def_list, fp } - } - - pub fn run<'ctx, 'a>( - &self, - ctx: &mut CodeGenContext<'ctx, 'a>, - obj: Option<(Type, BasicValueEnum<'ctx>)>, - fun: (&FunSignature, DefinitionId), - args: Vec<(Option, BasicValueEnum<'ctx>)>, - ) -> Option> { - (self.fp)(ctx, obj, fun, args) - } - - pub fn need_external_codegen(&self, id: DefinitionId) -> bool { - self.def_list.contains(&id) - } -} - pub struct WorkerRegistry { sender: Arc>>, receiver: Arc>>, @@ -118,7 +81,6 @@ impl WorkerRegistry { names: &[&str], top_level_ctx: Arc, f: Arc, - external_codegen: Arc, ) -> (Arc, Vec>) { let (sender, receiver) = unbounded(); let task_count = Mutex::new(0); @@ -140,9 +102,8 @@ impl WorkerRegistry { let registry2 = registry.clone(); let name = name.to_string(); let f = f.clone(); - let external_codegen = external_codegen.clone(); let handle = thread::spawn(move || { - registry.worker_thread(name, top_level_ctx, f, external_codegen); + registry.worker_thread(name, top_level_ctx, f); }); let handle = thread::spawn(move || { if let Err(e) = handle.join() { @@ -200,22 +161,13 @@ impl WorkerRegistry { module_name: String, top_level_ctx: Arc, f: Arc, - external_codegen: Arc, ) { let context = Context::create(); let mut builder = context.create_builder(); let mut module = context.create_module(&module_name); while let Some(task) = self.receiver.recv().unwrap() { - let result = gen_func( - &context, - self, - builder, - module, - task, - top_level_ctx.clone(), - external_codegen.clone(), - ); + let result = gen_func(&context, self, builder, module, task, top_level_ctx.clone()); builder = result.0; module = result.1; *self.task_count.lock() -= 1; @@ -298,7 +250,6 @@ pub fn gen_func<'ctx>( module: Module<'ctx>, task: CodeGenTask, top_level_ctx: Arc, - external_codegen: Arc, ) -> (Builder<'ctx>, Module<'ctx>) { // unwrap_or(0) is for unit tests without using rayon let (mut unifier, primitives) = { @@ -396,7 +347,6 @@ pub fn gen_func<'ctx>( builder, module, unifier, - external_codegen, }; let mut returned = false; diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index ad524a1cf..3f34c6642 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -1,12 +1,21 @@ -use crate::{codegen::{CodeGenTask, GenCall, WithCall, WorkerRegistry}, location::Location, symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ +use crate::{ + codegen::{CodeGenTask, WithCall, WorkerRegistry}, + location::Location, + symbol_resolver::{SymbolResolver, SymbolValue}, + toplevel::{ composer::TopLevelComposer, DefinitionId, FunInstance, TopLevelContext, TopLevelDef, - }, typecheck::{ + }, + typecheck::{ type_inferencer::{FunctionData, Inferencer, PrimitiveStore}, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, - }}; + }, +}; use indoc::indoc; use parking_lot::RwLock; -use rustpython_parser::{ast::{StrRef, fold::Fold}, parser::parse_program}; +use rustpython_parser::{ + ast::{fold::Fold, StrRef}, + parser::parse_program, +}; use std::cell::RefCell; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -56,12 +65,6 @@ fn test_primitives() { let top_level = Arc::new(composer.make_top_level_context()); unifier.top_level = Some(top_level.clone()); - // let resolver = Arc::new(Mutex::new(Resolver { - // id_to_type: HashMap::new(), - // id_to_def: RwLock::new(HashMap::new()), - // class_names: Default::default(), - // }) as Mutex); - let resolver = Arc::new(Box::new(Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()), @@ -109,7 +112,7 @@ fn test_primitives() { let top_level = Arc::new(TopLevelContext { definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))), unifiers: Arc::new(RwLock::new(vec![(unifier.get_shared_unifier(), primitives)])), - personality_symbol: None + personality_symbol: None, }); let unifier = (unifier.get_shared_unifier(), primitives); @@ -182,8 +185,7 @@ fn test_primitives() { .trim(); assert_eq!(expected, module.print_to_string().to_str().unwrap().trim()); }))); - let external_codegen = Arc::new(GenCall::new(Box::new(|_, _, _, _| unimplemented!()), HashSet::new())); - let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f, external_codegen); + let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f); registry.add_task(task); registry.wait_tasks_complete(handles); } @@ -223,6 +225,7 @@ fn test_simple_call() { instance_to_stmt: HashMap::new(), instance_to_symbol: HashMap::new(), resolver: None, + codegen_callback: None, }))); let resolver = Box::new(Resolver { @@ -298,7 +301,7 @@ fn test_simple_call() { let top_level = Arc::new(TopLevelContext { definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))), unifiers: Arc::new(RwLock::new(vec![(unifier.get_shared_unifier(), primitives)])), - personality_symbol: None + personality_symbol: None, }); let unifier = (unifier.get_shared_unifier(), primitives); @@ -347,8 +350,7 @@ fn test_simple_call() { .trim(); assert_eq!(expected, module.print_to_string().to_str().unwrap().trim()); }))); - let external_codegen = Arc::new(GenCall::new(Box::new(|_, _, _, _| unimplemented!()), HashSet::new())); - let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f, external_codegen); + let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f); registry.add_task(task); registry.wait_tasks_complete(handles); } diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index da533e42b..2fc102e7c 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -100,6 +100,7 @@ impl TopLevelComposer { instance_to_symbol: [("".into(), name.into())].iter().cloned().collect(), var_id: Default::default(), resolver: None, + codegen_callback: None, })), None, )); diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 411ec670e..c65f74d83 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -122,6 +122,7 @@ impl TopLevelComposer { instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver, + codegen_callback: None, } } diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index a415e50a0..3f9ba3daf 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -7,6 +7,7 @@ use std::{ sync::Arc, }; +use super::codegen::CodeGenContext; use super::typecheck::type_inferencer::PrimitiveStore; use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier}; use crate::{ @@ -16,6 +17,7 @@ use crate::{ use itertools::{izip, Itertools}; use parking_lot::RwLock; use rustpython_parser::ast::{self, Stmt, StrRef}; +use inkwell::values::BasicValueEnum; #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)] pub struct DefinitionId(pub usize); @@ -28,6 +30,43 @@ use type_annotation::*; #[cfg(test)] mod test; +type GenCallCallback = Box< + dyn for<'ctx, 'a> Fn( + &mut CodeGenContext<'ctx, 'a>, + Option<(Type, BasicValueEnum)>, + (&FunSignature, DefinitionId), + Vec<(Option, BasicValueEnum<'ctx>)>, + ) -> Option> + + Send + + Sync, +>; + +pub struct GenCall { + fp: GenCallCallback, +} + +impl GenCall { + pub fn new(fp: GenCallCallback) -> GenCall { + GenCall { fp } + } + + pub fn run<'ctx, 'a>( + &self, + ctx: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, BasicValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, BasicValueEnum<'ctx>)>, + ) -> Option> { + (self.fp)(ctx, obj, fun, args) + } +} + +impl Debug for GenCall { + fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Ok(()) + } +} + #[derive(Clone, Debug)] pub struct FunInstance { pub body: Arc>>>, @@ -78,6 +117,8 @@ pub enum TopLevelDef { instance_to_stmt: HashMap, // symbol resolver of the module defined the class resolver: Option>>, + // custom codegen callback + codegen_callback: Option> }, } diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index f36a6f6e0..ed86b1ab6 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -6,7 +6,7 @@ use rustpython_parser::parser; use std::{collections::HashMap, path::Path, sync::Arc, time::SystemTime}; use nac3core::{ - codegen::{CodeGenTask, WithCall, GenCall, WorkerRegistry}, + codegen::{CodeGenTask, WithCall, WorkerRegistry}, symbol_resolver::SymbolResolver, toplevel::{composer::TopLevelComposer, TopLevelDef}, typecheck::typedef::{FunSignature, FuncArg}, @@ -144,19 +144,9 @@ fn main() { // println!("IR:\n{}", module.print_to_string().to_str().unwrap()); }))); - let external_codegen = Arc::new(GenCall::new(Box::new( - // example implementation that does sitofp: - // note that a proper implementation may want to check the definition ID - // |ctx, _, _, args| { - // let arg = args[0].1.into_int_value(); - // let val = ctx.builder.build_signed_int_to_float(arg, ctx.ctx.f64_type(), "sitofp").into(); - // Some(val) - // } - |_, _, _, _| unimplemented!() - ), Default::default())); let threads: Vec = (0..threads).map(|i| format!("module{}", i)).collect(); let threads: Vec<_> = threads.iter().map(|s| s.as_str()).collect(); - let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f, external_codegen); + let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f); registry.add_task(task); registry.wait_tasks_complete(handles);