From 2223c86d9ba9ed506f6fc7710189f9406a1ddc81 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Fri, 27 Aug 2021 16:25:59 +0800 Subject: [PATCH] nac3standalone: compile multiple functions --- nac3core/src/codegen/mod.rs | 6 +- nac3core/src/toplevel/mod.rs | 1 + nac3standalone/src/main.rs | 276 +++++++++++++++++++++++++---------- 3 files changed, 204 insertions(+), 79 deletions(-) diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index e8e8e8bae..ccedbe5de 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -304,7 +304,11 @@ pub fn gen_func<'ctx>( .fn_type(¶ms, false) }; - let fn_val = module.add_function(&task.symbol_name, fn_type, None); + let symbol = &task.symbol_name; + let fn_val = module + .get_function(symbol) + .unwrap_or_else(|| module.add_function(symbol, fn_type, None)); + let init_bb = context.append_basic_block(fn_val, "init"); builder.position_at_end(init_bb); let body_bb = context.append_basic_block(fn_val, "body"); diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 8ccc5452e..4be69600c 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -22,6 +22,7 @@ mod type_annotation; use type_annotation::*; mod helper; +#[derive(Clone)] pub struct FunInstance { pub body: Vec>>, pub calls: HashMap, diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 467530602..3378d3862 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -1,7 +1,11 @@ -use std::fs; use std::time::SystemTime; +use std::{collections::HashSet, fs}; -use inkwell::{OptimizationLevel, passes::{PassManager, PassManagerBuilder}, targets::*}; +use inkwell::{ + passes::{PassManager, PassManagerBuilder}, + targets::*, + OptimizationLevel, +}; use parking_lot::RwLock; use rustpython_parser::{ ast::{fold::Fold, StmtKind}, @@ -12,7 +16,7 @@ use std::{cell::RefCell, collections::HashMap, path::Path, sync::Arc}; use nac3core::{ codegen::{CodeGenTask, WithCall, WorkerRegistry}, symbol_resolver::SymbolResolver, - toplevel::{DefinitionId, TopLevelComposer, TopLevelContext, TopLevelDef}, + toplevel::{DefinitionId, FunInstance, TopLevelComposer, TopLevelContext, TopLevelDef}, typecheck::{ type_inferencer::{FunctionData, Inferencer}, typedef::{FunSignature, FuncArg, TypeEnum}, @@ -33,25 +37,6 @@ fn main() { }; let start = SystemTime::now(); - let statements = match parser::parse_program(&program) { - Ok(mut ast) => { - let first = ast.remove(0); - if let StmtKind::FunctionDef { name, body, .. } = first.node { - if name != "run" { - panic!("Parse error: expected function \"run\" but got {}", name); - } - body - } else { - panic!( - "Parse error: expected function \"run\" but got {:?}", - first.node - ); - } - } - Err(err) => { - panic!("Parse error: {}", err); - } - }; let composer = TopLevelComposer::new(); let mut unifier = composer.unifier.clone(); @@ -67,7 +52,7 @@ fn main() { ret: primitives.none, vars: HashMap::new(), }))); - let def_id = top_level.definitions.read().len(); + let output_id = top_level.definitions.read().len(); top_level .definitions .write() @@ -83,59 +68,179 @@ fn main() { resolver: None, }))); + // dummy resolver... let resolver = Arc::new(Box::new(basic_symbol_resolver::Resolver { - id_to_type: [("output".into(), output_fun)].iter().cloned().collect(), - id_to_def: [("output".into(), DefinitionId(def_id))] - .iter() - .cloned() - .collect(), + id_to_type: HashMap::new(), + id_to_def: HashMap::new(), class_names: Default::default(), }) as Box); + let mut functions = HashMap::new(); - let threads = ["test"]; - let signature = FunSignature { - args: vec![], - ret: primitives.int32, - vars: HashMap::new(), - }; - - let mut function_data = FunctionData { - resolver: resolver.clone(), - bound_variables: Vec::new(), - return_type: Some(primitives.int32), - }; - let mut virtual_checks = Vec::new(); - let mut calls = HashMap::new(); - let mut inferencer = Inferencer { - top_level: &top_level, - function_data: &mut function_data, - unifier: &mut unifier, - variable_mapping: Default::default(), - primitives: &primitives, - virtual_checks: &mut virtual_checks, - calls: &mut calls, - defined_identifiers: Default::default(), - }; + for stmt in parser::parse_program(&program).unwrap().into_iter() { + if let StmtKind::FunctionDef { + name, + body, + args, + returns, + .. + } = stmt.node + { + let args = args + .args + .into_iter() + .map(|arg| FuncArg { + name: arg.node.arg.to_string(), + ty: resolver + .parse_type_annotation( + &top_level.definitions.read(), + &mut unifier, + &primitives, + &arg.node + .annotation + .expect("expected type annotation in parameters"), + ) + .unwrap(), + default_value: None, + }) + .collect(); + let ret = returns + .map(|r| { + resolver + .parse_type_annotation( + &top_level.definitions.read(), + &mut unifier, + &primitives, + &r, + ) + .unwrap() + }) + .unwrap_or(primitives.none); + let signature = FunSignature { + args, + ret, + vars: Default::default(), + }; + let fun_ty = unifier.add_ty(TypeEnum::TFunc(RefCell::new(signature.clone()))); + let id = top_level.definitions.read().len(); + top_level + .definitions + .write() + .push(Arc::new(RwLock::new(TopLevelDef::Function { + name: name.clone(), + signature: fun_ty, + var_id: vec![], + instance_to_stmt: HashMap::new(), + instance_to_symbol: HashMap::new(), + resolver: None, + }))); + functions.insert(name, (id, body, signature)); + } else { + panic!("unsupported statement type"); + } + } let setup_time = SystemTime::now(); println!( "Setup time: {}ms", - setup_time.duration_since(start).unwrap().as_millis() + setup_time + .duration_since(start) + .unwrap() + .as_millis() ); - let statements = statements - .into_iter() - .map(|v| inferencer.fold_stmt(v)) - .collect::, _>>() - .unwrap(); - let mut identifiers = ["output".to_string()].iter().cloned().collect(); - if !inferencer - .check_block(&statements, &mut identifiers) - .unwrap() - { - panic!("expected return"); + let mut id_to_def: HashMap<_, _> = functions + .iter() + .map(|(k, v)| (k.clone(), DefinitionId(v.0))) + .collect(); + id_to_def.insert("output".into(), DefinitionId(output_id)); + let mut id_to_type: HashMap<_, _> = functions + .iter() + .map(|(k, v)| { + ( + k.clone(), + unifier.add_ty(TypeEnum::TFunc(RefCell::new(v.2.clone()))), + ) + }) + .collect(); + id_to_type.insert("output".into(), output_fun); + + let resolver = Arc::new(Box::new(basic_symbol_resolver::Resolver { + class_names: Default::default(), + id_to_type, + id_to_def, + }) as Box); + + for (_, (id, ast, signature)) in functions.into_iter() { + if let TopLevelDef::Function { + resolver: r, + instance_to_stmt, + .. + } = &mut *top_level.definitions.read()[id].write() + { + *r = Some(resolver.clone()); + + let return_type = if unifier.unioned(primitives.none, signature.ret) { + None + } else { + Some(signature.ret) + }; + let mut function_data = FunctionData { + resolver: resolver.clone(), + bound_variables: Vec::new(), + return_type, + }; + let mut virtual_checks = Vec::new(); + let mut calls = HashMap::new(); + let mut identifiers = HashSet::new(); + let mut variable_mapping = HashMap::new(); + for arg in signature.args.iter() { + identifiers.insert(arg.name.clone()); + variable_mapping.insert(arg.name.clone(), arg.ty); + } + let mut inferencer = Inferencer { + top_level: &top_level, + function_data: &mut function_data, + unifier: &mut unifier, + variable_mapping, + primitives: &primitives, + virtual_checks: &mut virtual_checks, + calls: &mut calls, + defined_identifiers: identifiers.clone(), + }; + let statements = ast + .into_iter() + .map(|v| inferencer.fold_stmt(v)) + .collect::, _>>() + .unwrap(); + + let returned = inferencer + .check_block(&statements, &mut identifiers) + .unwrap(); + if return_type.is_some() && !returned { + panic!("expected return"); + } + + instance_to_stmt.insert( + "".to_string(), + FunInstance { + body: statements, + unifier_id: 0, + calls, + subst: Default::default(), + }, + ); + } } + let inference_time = SystemTime::now(); + println!( + "Type inference time: {}ms", + inference_time + .duration_since(setup_time) + .unwrap() + .as_millis() + ); + let top_level = Arc::new(TopLevelContext { definitions: Arc::new(RwLock::new(std::mem::take( &mut *top_level.definitions.write(), @@ -146,22 +251,34 @@ fn main() { )])), }); - let inference_time = SystemTime::now(); - println!( - "Type inference time: {}ms", - inference_time.duration_since(setup_time).unwrap().as_millis() - ); - - let unifier = (unifier.get_shared_unifier(), primitives); - + let instance = { + let defs = top_level.definitions.read(); + let mut instance = defs[resolver.get_identifier_def("run").unwrap().0].write(); + if let TopLevelDef::Function { + instance_to_stmt, + instance_to_symbol, + .. + } = &mut *instance + { + instance_to_symbol.insert("".to_string(), "run".to_string()); + instance_to_stmt[""].clone() + } else { + unreachable!() + } + }; + let signature = FunSignature { + args: vec![], + ret: primitives.int32, + vars: HashMap::new(), + }; let task = CodeGenTask { subst: Default::default(), symbol_name: "run".to_string(), - body: statements, - resolver, - unifier, - calls, + body: instance.body, signature, + resolver, + unifier: top_level.unifiers.read()[instance.unifier_id].clone(), + calls: instance.calls, }; let f = Arc::new(WithCall::new(Box::new(move |module| { let codegen_time = SystemTime::now(); @@ -172,7 +289,6 @@ fn main() { .unwrap() .as_millis() ); - let builder = PassManagerBuilder::create(); builder.set_optimization_level(OptimizationLevel::Aggressive); let passes = PassManager::create(()); @@ -195,6 +311,7 @@ fn main() { target_machine .write_to_file(module, FileType::Object, Path::new("mandelbrot.o")) .expect("couldn't write module to file"); + println!( "LLVM time: {}ms", SystemTime::now() @@ -202,9 +319,12 @@ fn main() { .unwrap() .as_millis() ); + println!("IR:\n{}", module.print_to_string().to_str().unwrap()); + }))); + let threads = ["test"]; let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f); registry.add_task(task); registry.wait_tasks_complete(handles); - println!("object file is in mandelbrot.o") + println!("object file is in mandelbrot.o"); }