diff --git a/nac3standalone/demo/run_demo.sh b/nac3standalone/demo/run_demo.sh index 5151e93..f4185e7 100755 --- a/nac3standalone/demo/run_demo.sh +++ b/nac3standalone/demo/run_demo.sh @@ -9,5 +9,5 @@ fi rm -f *.o ../../target/release/nac3standalone $1 -clang -Wall -O2 -o $1.elf demo.c module*.o -lm +clang -Wall -O2 -o $1.elf demo.c module.o -lm ./$1.elf diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 05f131d..154a10e 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -1,16 +1,16 @@ use inkwell::{ passes::{PassManager, PassManagerBuilder}, targets::*, - OptimizationLevel, + OptimizationLevel, memory_buffer::MemoryBuffer, }; use std::{borrow::Borrow, collections::HashMap, env, fs, path::Path, sync::Arc, time::SystemTime}; -use parking_lot::RwLock; +use parking_lot::{RwLock, Mutex}; use nac3parser::{ast::{Expr, ExprKind, StmtKind}, parser}; use nac3core::{ codegen::{ concrete_type::ConcreteTypeStore, CodeGenTask, DefaultCodeGenerator, WithCall, - WorkerRegistry, + WorkerRegistry, irrt::load_irrt, }, symbol_resolver::SymbolResolver, toplevel::{ @@ -264,36 +264,14 @@ fn main() { calls: instance.calls, id: 0, }; + + let membuffers: Arc>>> = Default::default(); + let membuffer = membuffers.clone(); + let f = Arc::new(WithCall::new(Box::new(move |module| { - let builder = PassManagerBuilder::create(); - builder.set_optimization_level(OptimizationLevel::Aggressive); - let passes = PassManager::create(()); - builder.set_inliner_with_threshold(255); - builder.populate_module_pass_manager(&passes); - passes.run_on(module); - - let triple = TargetMachine::get_default_triple(); - let target = - Target::from_triple(&triple).expect("couldn't create target from target triple"); - let target_machine = target - .create_target_machine( - &triple, - "", - "", - OptimizationLevel::Default, - RelocMode::Default, - CodeModel::Default, - ) - .expect("couldn't create target machine"); - target_machine - .write_to_file( - module, - FileType::Object, - Path::new(&format!("{}.o", module.get_name().to_str().unwrap())), - ) - .expect("couldn't write module to file"); - - // println!("IR:\n{}", module.print_to_string().to_str().unwrap()); + let buffer = module.write_bitcode_to_memory(); + let buffer = buffer.as_slice().into(); + membuffer.lock().push(buffer); }))); let threads = (0..threads) .map(|i| Box::new(DefaultCodeGenerator::new(format!("module{}", i), 64))) @@ -302,6 +280,56 @@ fn main() { registry.add_task(task); registry.wait_tasks_complete(handles); + let buffers = membuffers.lock(); + let context = inkwell::context::Context::create(); + let main = context + .create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main")) + .unwrap(); + for buffer in buffers.iter().skip(1) { + let other = context + .create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main")) + .unwrap(); + + main.link_in_module(other).unwrap(); + } + main.link_in_module(load_irrt(&context)).unwrap(); + + let mut function_iter = main.get_first_function(); + while let Some(func) = function_iter { + if func.count_basic_blocks() > 0 && func.get_name().to_str().unwrap() != "run" { + func.set_linkage(inkwell::module::Linkage::Private); + } + function_iter = func.get_next_function(); + } + + let builder = PassManagerBuilder::create(); + builder.set_optimization_level(OptimizationLevel::Aggressive); + let passes = PassManager::create(()); + builder.set_inliner_with_threshold(255); + builder.populate_module_pass_manager(&passes); + passes.run_on(&main); + + let triple = TargetMachine::get_default_triple(); + let target = + Target::from_triple(&triple).expect("couldn't create target from target triple"); + let target_machine = target + .create_target_machine( + &triple, + "", + "", + OptimizationLevel::Default, + RelocMode::Default, + CodeModel::Default, + ) + .expect("couldn't create target machine"); + target_machine + .write_to_file( + &main, + FileType::Object, + Path::new("module.o"), + ) + .expect("couldn't write module to file"); + let final_time = SystemTime::now(); println!( "codegen time (including LLVM): {}ms",