From 298e0b19086184c61d84d757c6642fbba8154199 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Bourdeauducq?= Date: Fri, 6 Sep 2024 19:14:06 +0800 Subject: [PATCH] compile and run code --- Cargo.lock | 1 + Cargo.toml | 1 + src/basic_symbol_resolver.rs | 93 +++++++++++++++++ src/main.rs | 197 +++++++++++++++++++++++++++++++++-- 4 files changed, 286 insertions(+), 6 deletions(-) create mode 100644 src/basic_symbol_resolver.rs diff --git a/Cargo.lock b/Cargo.lock index 0f347b1..8167e9d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -687,6 +687,7 @@ dependencies = [ "eframe", "egui_extras", "nac3core", + "parking_lot", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index fd45ac8..a6ac262 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] eframe = "0.28" egui_extras = { version = "0.28", features = ["syntect"]} +parking_lot = "0.12" [dependencies.nac3core] git = "https://git.m-labs.hk/M-Labs/nac3" diff --git a/src/basic_symbol_resolver.rs b/src/basic_symbol_resolver.rs new file mode 100644 index 0000000..8481e82 --- /dev/null +++ b/src/basic_symbol_resolver.rs @@ -0,0 +1,93 @@ +use nac3core::nac3parser::ast::{self, StrRef}; +use nac3core::{ + codegen::CodeGenContext, + symbol_resolver::{SymbolResolver, SymbolValue, ValueEnum}, + toplevel::{DefinitionId, TopLevelDef}, + typecheck::{ + type_inferencer::PrimitiveStore, + typedef::{Type, Unifier}, + }, +}; +use parking_lot::{Mutex, RwLock}; +use std::collections::HashSet; +use std::{collections::HashMap, sync::Arc}; + +pub struct ResolverInternal { + pub id_to_type: Mutex>, + pub id_to_def: Mutex>, + pub module_globals: Mutex>, + pub str_store: Mutex>, +} + +impl ResolverInternal { + pub fn add_id_def(&self, id: StrRef, def: DefinitionId) { + self.id_to_def.lock().insert(id, def); + } + + pub fn add_id_type(&self, id: StrRef, ty: Type) { + self.id_to_type.lock().insert(id, ty); + } + + pub fn add_module_global(&self, id: StrRef, val: SymbolValue) { + self.module_globals.lock().insert(id, val); + } +} + +pub struct Resolver(pub Arc); + +impl SymbolResolver for Resolver { + fn get_default_param_value(&self, expr: &ast::Expr) -> Option { + match &expr.node { + ast::ExprKind::Name { id, .. } => self.0.module_globals.lock().get(id).cloned(), + _ => unimplemented!("other type of expr not supported at {}", expr.location), + } + } + + fn get_symbol_type( + &self, + _: &mut Unifier, + _: &[Arc>], + _: &PrimitiveStore, + str: StrRef, + ) -> Result { + self.0 + .id_to_type + .lock() + .get(&str) + .copied() + .ok_or(format!("cannot get type of {str}")) + } + + fn get_symbol_value<'ctx>( + &self, + _: StrRef, + _: &mut CodeGenContext<'ctx, '_>, + ) -> Option> { + unimplemented!() + } + + fn get_identifier_def(&self, id: StrRef) -> Result> { + self.0 + .id_to_def + .lock() + .get(&id) + .copied() + .ok_or_else(|| HashSet::from([format!("Undefined identifier `{id}`")])) + } + + fn get_string_id(&self, s: &str) -> i32 { + let mut str_store = self.0.str_store.lock(); + if let Some(id) = str_store.get(s) { + *id + } else { + let id = i32::try_from(str_store.len()) + .expect("Symbol resolver string store size exceeds max capacity (i32::MAX)"); + str_store.insert(s.to_string(), id); + id + } + } + + fn get_exception_id(&self, _: usize) -> usize { + unimplemented!() + } +} diff --git a/src/main.rs b/src/main.rs index c11069d..b39aa0d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,196 @@ -use eframe::egui; -use nac3core::nac3parser::parser; +use std::collections::HashMap; +use std::num::NonZeroUsize; +use std::sync::Arc; +use parking_lot::Mutex; + +use eframe::egui; + +use nac3core::codegen; +use nac3core::inkwell; +use nac3core::nac3parser; +use nac3core::toplevel; +use nac3core::toplevel::composer; +use nac3core::typecheck::{type_inferencer, typedef}; + +mod basic_symbol_resolver; +use basic_symbol_resolver::{Resolver, ResolverInternal}; + +fn run(code: &String) { + let llvm_options = codegen::CodeGenLLVMOptions { + opt_level: inkwell::OptimizationLevel::Default, + target: codegen::CodeGenTargetMachineOptions::from_host(), + }; + let context = inkwell::context::Context::create(); + let size_t = context + .ptr_sized_int_type( + &llvm_options + .target + .create_target_machine(llvm_options.opt_level) + .map(|tm| tm.get_target_data()) + .unwrap(), + None, + ) + .get_bit_width(); + let primitive: type_inferencer::PrimitiveStore = + composer::TopLevelComposer::make_primitives(size_t).0; + let (mut composer, builtins_def, builtins_ty) = composer::TopLevelComposer::new( + vec![], + vec![], + composer::ComposerConfig::default(), + size_t, + ); + let internal_resolver: Arc = ResolverInternal { + id_to_type: builtins_ty.into(), + id_to_def: builtins_def.into(), + module_globals: Mutex::default(), + str_store: Mutex::default(), + } + .into(); + let resolver = Arc::new(Resolver(internal_resolver.clone())) + as Arc; + let irrt = codegen::irrt::load_irrt(&context, resolver.as_ref()); + + let parser_result = + match nac3parser::parser::parse_program(code.as_str(), String::from("cell1").into()) { + Ok(parser_result) => parser_result, + Err(err) => { + eprintln!("parse error: {}", err); + return; + } + }; + for stmt in parser_result { + match composer.register_top_level(stmt, Some(resolver.clone()), "__main__", true) { + Ok((name, def_id, ty)) => { + internal_resolver.add_id_def(name, def_id); + if let Some(ty) = ty { + internal_resolver.add_id_type(name, ty); + } + } + Err(err) => { + eprintln!("composer error: {}", err); + return; + } + } + } + + let signature = typedef::FunSignature { + args: vec![], + ret: primitive.int32, + vars: typedef::VarMap::new(), + }; + let mut store = codegen::concrete_type::ConcreteTypeStore::new(); + let mut cache = HashMap::new(); + let signature = store.from_signature(&mut composer.unifier, &primitive, &signature, &mut cache); + let signature = store.add_cty(signature); + + if let Err(errors) = composer.start_analysis(true) { + let error_count = errors.len(); + eprintln!("{error_count} error(s) occurred during top level analysis."); + + for (error_i, error) in errors.iter().enumerate() { + let error_num = error_i + 1; + eprintln!("=========== ERROR {error_num}/{error_count} ============"); + eprintln!("{error}"); + } + eprintln!("=================================="); + return; + } + + let top_level = Arc::new(composer.make_top_level_context()); + + let run_id_def = match resolver.get_identifier_def("run".into()) { + Ok(run_id_def) => run_id_def, + Err(_) => { + eprintln!("no run() entry point"); + return; + } + }; + let instance = { + let defs = top_level.definitions.read(); + let mut instance = defs[run_id_def.0].write(); + let toplevel::TopLevelDef::Function { + instance_to_stmt, + instance_to_symbol, + .. + } = &mut *instance + else { + unreachable!() + }; + instance_to_symbol.insert(String::new(), "run".to_string()); + instance_to_stmt[""].clone() + }; + + let task = codegen::CodeGenTask { + subst: Vec::default(), + symbol_name: "run".to_string(), + body: instance.body, + signature, + resolver, + store, + unifier_index: instance.unifier_id, + calls: instance.calls, + id: 0, + }; + + let nthreads = if inkwell::support::is_multithreaded() { + std::thread::available_parallelism() + .map(NonZeroUsize::get) + .unwrap_or(1usize) + } else { + 1 + }; + + let membuffers: Arc>>> = Arc::default(); + let membuffer = membuffers.clone(); + + let f = Arc::new(codegen::WithCall::new(Box::new(move |module| { + let buffer = module.write_bitcode_to_memory(); + let buffer = buffer.as_slice().into(); + membuffer.lock().push(buffer); + }))); + let threads = (0..nthreads) + .map(|i| { + Box::new(codegen::DefaultCodeGenerator::new( + format!("module{i}"), + size_t, + )) + }) + .collect(); + let (registry, handles) = + codegen::WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f); + registry.add_task(task); + registry.wait_tasks_complete(handles); + + // Link all modules together into `main` + let buffers = membuffers.lock(); + let main = context + .create_module_from_ir( + inkwell::memory_buffer::MemoryBuffer::create_from_memory_range(&buffers[0], "main"), + ) + .unwrap(); + for buffer in buffers.iter().skip(1) { + let other = context + .create_module_from_ir( + inkwell::memory_buffer::MemoryBuffer::create_from_memory_range(buffer, "main"), + ) + .unwrap(); + main.link_in_module(other).unwrap(); + } + main.link_in_module(irrt).unwrap(); + + let execution_engine = main + .create_jit_execution_engine(llvm_options.opt_level) + .unwrap(); + type Run = unsafe extern "C" fn() -> i32; + let run: inkwell::execution_engine::JitFunction = + unsafe { execution_engine.get_function("run").unwrap() }; + println!("{}", unsafe { run.call() }); +} fn main() -> eframe::Result { + inkwell::targets::Target::initialize_all(&inkwell::targets::InitializationConfig::default()); + let options = eframe::NativeOptions { viewport: egui::ViewportBuilder::default().with_inner_size([1024.0, 768.0]), ..Default::default() @@ -13,10 +201,7 @@ fn main() -> eframe::Result { eframe::run_simple_native("Cells", options, move |ctx, _frame| { let submit_key = egui::KeyboardShortcut::new(egui::Modifiers::CTRL, egui::Key::Enter); if ctx.input_mut(|i| i.consume_shortcut(&submit_key)) { - match parser::parse_program(code.as_str(), String::from("cell1").into()) { - Ok(parser_result) => println!("{:?}", parser_result), - Err(err) => println!("parse error: {}", err), - } + run(&code); } egui::CentralPanel::default().show(ctx, |ui| {