diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 5569b309..0ac2f856 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -108,6 +108,10 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { fun: (&FunSignature, DefinitionId), params: Vec<(Option, BasicValueEnum<'ctx>)>, ) -> Option> { + if self.external_codegen.def_list.contains(&fun.1) { + let external_codegen = self.external_codegen.clone(); + return external_codegen.run(self, obj, fun, params) + } let key = self.get_subst_key(obj.map(|a| a.0), fun.0, None); let definition = self.top_level.definitions.read().get(fun.1.0).cloned().unwrap(); let mut task = None; diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index e01c130b..2befa6c9 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1,6 +1,6 @@ use crate::{ symbol_resolver::SymbolResolver, - toplevel::{TopLevelContext, TopLevelDef}, + toplevel::{DefinitionId, TopLevelContext, TopLevelDef}, typecheck::{ type_inferencer::{CodeLocation, PrimitiveStore}, typedef::{CallId, FunSignature, SharedUnifier, Type, TypeEnum, Unifier}, @@ -13,13 +13,13 @@ use inkwell::{ context::Context, module::Module, types::{BasicType, BasicTypeEnum}, - values::PointerValue, + values::{BasicValueEnum, PointerValue}, AddressSpace, }; use itertools::Itertools; use parking_lot::{Condvar, Mutex}; use rustpython_parser::ast::{Stmt, StrRef}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -49,6 +49,7 @@ pub struct CodeGenContext<'ctx, 'a> { // where continue and break should go to respectively // the first one is the test_bb, and the second one is bb after the loop pub loop_bb: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>, + pub external_codegen: Arc, } type Fp = Box; @@ -67,6 +68,42 @@ impl WithCall { } } +type GenCallCallback = Box< + dyn for<'ctx, 'a> Fn( + &mut CodeGenContext<'ctx, 'a>, + Option<(Type, BasicValueEnum)>, + (&FunSignature, DefinitionId), + Vec<(Option, BasicValueEnum<'ctx>)>, + ) -> Option> + + Send + + Sync, +>; + +pub struct GenCall { + def_list: HashSet, + fp: GenCallCallback, +} + +impl GenCall { + pub fn new(fp: GenCallCallback, def_list: HashSet) -> GenCall { + GenCall { def_list, fp } + } + + pub fn run<'ctx, 'a>( + &self, + ctx: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, BasicValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, BasicValueEnum<'ctx>)>, + ) -> Option> { + (self.fp)(ctx, obj, fun, args) + } + + pub fn need_external_codegen(&self, id: DefinitionId) -> bool { + self.def_list.contains(&id) + } +} + pub struct WorkerRegistry { sender: Arc>>, receiver: Arc>>, @@ -81,6 +118,7 @@ impl WorkerRegistry { names: &[&str], top_level_ctx: Arc, f: Arc, + external_codegen: Arc, ) -> (Arc, Vec>) { let (sender, receiver) = unbounded(); let task_count = Mutex::new(0); @@ -102,8 +140,9 @@ impl WorkerRegistry { let registry2 = registry.clone(); let name = name.to_string(); let f = f.clone(); + let external_codegen = external_codegen.clone(); let handle = thread::spawn(move || { - registry.worker_thread(name, top_level_ctx, f); + registry.worker_thread(name, top_level_ctx, f, external_codegen); }); let handle = thread::spawn(move || { if let Err(e) = handle.join() { @@ -161,13 +200,22 @@ impl WorkerRegistry { module_name: String, top_level_ctx: Arc, f: Arc, + external_codegen: Arc, ) { let context = Context::create(); let mut builder = context.create_builder(); let mut module = context.create_module(&module_name); while let Some(task) = self.receiver.recv().unwrap() { - let result = gen_func(&context, self, builder, module, task, top_level_ctx.clone()); + let result = gen_func( + &context, + self, + builder, + module, + task, + top_level_ctx.clone(), + external_codegen.clone(), + ); builder = result.0; module = result.1; *self.task_count.lock() -= 1; @@ -250,6 +298,7 @@ pub fn gen_func<'ctx>( module: Module<'ctx>, task: CodeGenTask, top_level_ctx: Arc, + external_codegen: Arc, ) -> (Builder<'ctx>, Module<'ctx>) { // unwrap_or(0) is for unit tests without using rayon let (mut unifier, primitives) = { @@ -347,6 +396,7 @@ pub fn gen_func<'ctx>( builder, module, unifier, + external_codegen, }; let mut returned = false; diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index be4f5d94..ad524a1c 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -1,15 +1,9 @@ -use crate::{ - codegen::{CodeGenTask, WithCall, WorkerRegistry}, - location::Location, - symbol_resolver::{SymbolResolver, SymbolValue}, - toplevel::{ +use crate::{codegen::{CodeGenTask, GenCall, WithCall, WorkerRegistry}, location::Location, symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ composer::TopLevelComposer, DefinitionId, FunInstance, TopLevelContext, TopLevelDef, - }, - typecheck::{ + }, typecheck::{ type_inferencer::{FunctionData, Inferencer, PrimitiveStore}, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, - }, -}; + }}; use indoc::indoc; use parking_lot::RwLock; use rustpython_parser::{ast::{StrRef, fold::Fold}, parser::parse_program}; @@ -188,7 +182,8 @@ fn test_primitives() { .trim(); assert_eq!(expected, module.print_to_string().to_str().unwrap().trim()); }))); - let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f); + let external_codegen = Arc::new(GenCall::new(Box::new(|_, _, _, _| unimplemented!()), HashSet::new())); + let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f, external_codegen); registry.add_task(task); registry.wait_tasks_complete(handles); } @@ -352,7 +347,8 @@ fn test_simple_call() { .trim(); assert_eq!(expected, module.print_to_string().to_str().unwrap().trim()); }))); - let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f); + let external_codegen = Arc::new(GenCall::new(Box::new(|_, _, _, _| unimplemented!()), HashSet::new())); + let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f, external_codegen); registry.add_task(task); registry.wait_tasks_complete(handles); } diff --git a/nac3embedded/src/lib.rs b/nac3embedded/src/lib.rs index d00c903e..1df6f236 100644 --- a/nac3embedded/src/lib.rs +++ b/nac3embedded/src/lib.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::path::Path; use std::process::Command; @@ -14,7 +14,7 @@ use inkwell::{ use nac3core::typecheck::type_inferencer::PrimitiveStore; use nac3core::{ - codegen::{CodeGenTask, WithCall, WorkerRegistry}, + codegen::{CodeGenTask, WithCall, WorkerRegistry, GenCall}, symbol_resolver::SymbolResolver, toplevel::{composer::TopLevelComposer, TopLevelContext, TopLevelDef}, typecheck::typedef::{FunSignature, FuncArg}, @@ -59,10 +59,10 @@ impl Nac3 { Box::new(Resolver(internal_resolver.clone())) as Box ); Nac3 { - primitive: primitive, - internal_resolver: internal_resolver, - resolver: resolver, - composer: composer, + primitive, + internal_resolver, + resolver, + composer, top_level: None } } @@ -168,9 +168,10 @@ impl Nac3 { .write_to_file(module, FileType::Object, Path::new(&format!("{}.o", module.get_name().to_str().unwrap()))) .expect("couldn't write module to file"); }))); + let external_codegen = Arc::new(GenCall::new(Box::new(|_, _, _, _| unimplemented!()), HashSet::new())); let thread_names: Vec = (0..4).map(|i| format!("module{}", i)).collect(); let threads: Vec<_> = thread_names.iter().map(|s| s.as_str()).collect(); - let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level.clone(), f); + let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level.clone(), f, external_codegen); registry.add_task(task); registry.wait_tasks_complete(handles); diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index fbf47d3a..f36a6f6e 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -1,16 +1,12 @@ -use std::fs; use std::env; -use inkwell::{ - passes::{PassManager, PassManagerBuilder}, - targets::*, - OptimizationLevel, -}; +use std::fs; +use inkwell::{OptimizationLevel, passes::{PassManager, PassManagerBuilder}, targets::*}; use nac3core::typecheck::type_inferencer::PrimitiveStore; use rustpython_parser::parser; use std::{collections::HashMap, path::Path, sync::Arc, time::SystemTime}; use nac3core::{ - codegen::{CodeGenTask, WithCall, WorkerRegistry}, + codegen::{CodeGenTask, WithCall, GenCall, WorkerRegistry}, symbol_resolver::SymbolResolver, toplevel::{composer::TopLevelComposer, TopLevelDef}, typecheck::typedef::{FunSignature, FuncArg}, @@ -148,9 +144,19 @@ fn main() { // println!("IR:\n{}", module.print_to_string().to_str().unwrap()); }))); + let external_codegen = Arc::new(GenCall::new(Box::new( + // example implementation that does sitofp: + // note that a proper implementation may want to check the definition ID + // |ctx, _, _, args| { + // let arg = args[0].1.into_int_value(); + // let val = ctx.builder.build_signed_int_to_float(arg, ctx.ctx.f64_type(), "sitofp").into(); + // Some(val) + // } + |_, _, _, _| unimplemented!() + ), Default::default())); let threads: Vec = (0..threads).map(|i| format!("module{}", i)).collect(); let threads: Vec<_> = threads.iter().map(|s| s.as_str()).collect(); - let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f); + let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f, external_codegen); registry.add_task(task); registry.wait_tasks_complete(handles);