forked from M-Labs/nac3
1
0
Fork 0

nac3standalone: compile multiple functions

This commit is contained in:
pca006132 2021-08-27 16:25:59 +08:00
parent 72aebed559
commit 2223c86d9b
3 changed files with 204 additions and 79 deletions

View File

@ -304,7 +304,11 @@ pub fn gen_func<'ctx>(
.fn_type(&params, false) .fn_type(&params, false)
}; };
let fn_val = module.add_function(&task.symbol_name, fn_type, None); let symbol = &task.symbol_name;
let fn_val = module
.get_function(symbol)
.unwrap_or_else(|| module.add_function(symbol, fn_type, None));
let init_bb = context.append_basic_block(fn_val, "init"); let init_bb = context.append_basic_block(fn_val, "init");
builder.position_at_end(init_bb); builder.position_at_end(init_bb);
let body_bb = context.append_basic_block(fn_val, "body"); let body_bb = context.append_basic_block(fn_val, "body");

View File

@ -22,6 +22,7 @@ mod type_annotation;
use type_annotation::*; use type_annotation::*;
mod helper; mod helper;
#[derive(Clone)]
pub struct FunInstance { pub struct FunInstance {
pub body: Vec<Stmt<Option<Type>>>, pub body: Vec<Stmt<Option<Type>>>,
pub calls: HashMap<CodeLocation, CallId>, pub calls: HashMap<CodeLocation, CallId>,

View File

@ -1,7 +1,11 @@
use std::fs;
use std::time::SystemTime; use std::time::SystemTime;
use std::{collections::HashSet, fs};
use inkwell::{OptimizationLevel, passes::{PassManager, PassManagerBuilder}, targets::*}; use inkwell::{
passes::{PassManager, PassManagerBuilder},
targets::*,
OptimizationLevel,
};
use parking_lot::RwLock; use parking_lot::RwLock;
use rustpython_parser::{ use rustpython_parser::{
ast::{fold::Fold, StmtKind}, ast::{fold::Fold, StmtKind},
@ -12,7 +16,7 @@ use std::{cell::RefCell, collections::HashMap, path::Path, sync::Arc};
use nac3core::{ use nac3core::{
codegen::{CodeGenTask, WithCall, WorkerRegistry}, codegen::{CodeGenTask, WithCall, WorkerRegistry},
symbol_resolver::SymbolResolver, symbol_resolver::SymbolResolver,
toplevel::{DefinitionId, TopLevelComposer, TopLevelContext, TopLevelDef}, toplevel::{DefinitionId, FunInstance, TopLevelComposer, TopLevelContext, TopLevelDef},
typecheck::{ typecheck::{
type_inferencer::{FunctionData, Inferencer}, type_inferencer::{FunctionData, Inferencer},
typedef::{FunSignature, FuncArg, TypeEnum}, typedef::{FunSignature, FuncArg, TypeEnum},
@ -33,25 +37,6 @@ fn main() {
}; };
let start = SystemTime::now(); let start = SystemTime::now();
let statements = match parser::parse_program(&program) {
Ok(mut ast) => {
let first = ast.remove(0);
if let StmtKind::FunctionDef { name, body, .. } = first.node {
if name != "run" {
panic!("Parse error: expected function \"run\" but got {}", name);
}
body
} else {
panic!(
"Parse error: expected function \"run\" but got {:?}",
first.node
);
}
}
Err(err) => {
panic!("Parse error: {}", err);
}
};
let composer = TopLevelComposer::new(); let composer = TopLevelComposer::new();
let mut unifier = composer.unifier.clone(); let mut unifier = composer.unifier.clone();
@ -67,7 +52,7 @@ fn main() {
ret: primitives.none, ret: primitives.none,
vars: HashMap::new(), vars: HashMap::new(),
}))); })));
let def_id = top_level.definitions.read().len(); let output_id = top_level.definitions.read().len();
top_level top_level
.definitions .definitions
.write() .write()
@ -83,59 +68,179 @@ fn main() {
resolver: None, resolver: None,
}))); })));
// dummy resolver...
let resolver = Arc::new(Box::new(basic_symbol_resolver::Resolver { let resolver = Arc::new(Box::new(basic_symbol_resolver::Resolver {
id_to_type: [("output".into(), output_fun)].iter().cloned().collect(), id_to_type: HashMap::new(),
id_to_def: [("output".into(), DefinitionId(def_id))] id_to_def: HashMap::new(),
.iter()
.cloned()
.collect(),
class_names: Default::default(), class_names: Default::default(),
}) as Box<dyn SymbolResolver + Send + Sync>); }) as Box<dyn SymbolResolver + Send + Sync>);
let mut functions = HashMap::new();
let threads = ["test"]; for stmt in parser::parse_program(&program).unwrap().into_iter() {
let signature = FunSignature { if let StmtKind::FunctionDef {
args: vec![], name,
ret: primitives.int32, body,
vars: HashMap::new(), args,
}; returns,
..
let mut function_data = FunctionData { } = stmt.node
resolver: resolver.clone(), {
bound_variables: Vec::new(), let args = args
return_type: Some(primitives.int32), .args
}; .into_iter()
let mut virtual_checks = Vec::new(); .map(|arg| FuncArg {
let mut calls = HashMap::new(); name: arg.node.arg.to_string(),
let mut inferencer = Inferencer { ty: resolver
top_level: &top_level, .parse_type_annotation(
function_data: &mut function_data, &top_level.definitions.read(),
unifier: &mut unifier, &mut unifier,
variable_mapping: Default::default(), &primitives,
primitives: &primitives, &arg.node
virtual_checks: &mut virtual_checks, .annotation
calls: &mut calls, .expect("expected type annotation in parameters"),
defined_identifiers: Default::default(), )
}; .unwrap(),
default_value: None,
})
.collect();
let ret = returns
.map(|r| {
resolver
.parse_type_annotation(
&top_level.definitions.read(),
&mut unifier,
&primitives,
&r,
)
.unwrap()
})
.unwrap_or(primitives.none);
let signature = FunSignature {
args,
ret,
vars: Default::default(),
};
let fun_ty = unifier.add_ty(TypeEnum::TFunc(RefCell::new(signature.clone())));
let id = top_level.definitions.read().len();
top_level
.definitions
.write()
.push(Arc::new(RwLock::new(TopLevelDef::Function {
name: name.clone(),
signature: fun_ty,
var_id: vec![],
instance_to_stmt: HashMap::new(),
instance_to_symbol: HashMap::new(),
resolver: None,
})));
functions.insert(name, (id, body, signature));
} else {
panic!("unsupported statement type");
}
}
let setup_time = SystemTime::now(); let setup_time = SystemTime::now();
println!( println!(
"Setup time: {}ms", "Setup time: {}ms",
setup_time.duration_since(start).unwrap().as_millis() setup_time
.duration_since(start)
.unwrap()
.as_millis()
); );
let statements = statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
let mut identifiers = ["output".to_string()].iter().cloned().collect(); let mut id_to_def: HashMap<_, _> = functions
if !inferencer .iter()
.check_block(&statements, &mut identifiers) .map(|(k, v)| (k.clone(), DefinitionId(v.0)))
.unwrap() .collect();
{ id_to_def.insert("output".into(), DefinitionId(output_id));
panic!("expected return"); let mut id_to_type: HashMap<_, _> = functions
.iter()
.map(|(k, v)| {
(
k.clone(),
unifier.add_ty(TypeEnum::TFunc(RefCell::new(v.2.clone()))),
)
})
.collect();
id_to_type.insert("output".into(), output_fun);
let resolver = Arc::new(Box::new(basic_symbol_resolver::Resolver {
class_names: Default::default(),
id_to_type,
id_to_def,
}) as Box<dyn SymbolResolver + Send + Sync>);
for (_, (id, ast, signature)) in functions.into_iter() {
if let TopLevelDef::Function {
resolver: r,
instance_to_stmt,
..
} = &mut *top_level.definitions.read()[id].write()
{
*r = Some(resolver.clone());
let return_type = if unifier.unioned(primitives.none, signature.ret) {
None
} else {
Some(signature.ret)
};
let mut function_data = FunctionData {
resolver: resolver.clone(),
bound_variables: Vec::new(),
return_type,
};
let mut virtual_checks = Vec::new();
let mut calls = HashMap::new();
let mut identifiers = HashSet::new();
let mut variable_mapping = HashMap::new();
for arg in signature.args.iter() {
identifiers.insert(arg.name.clone());
variable_mapping.insert(arg.name.clone(), arg.ty);
}
let mut inferencer = Inferencer {
top_level: &top_level,
function_data: &mut function_data,
unifier: &mut unifier,
variable_mapping,
primitives: &primitives,
virtual_checks: &mut virtual_checks,
calls: &mut calls,
defined_identifiers: identifiers.clone(),
};
let statements = ast
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
let returned = inferencer
.check_block(&statements, &mut identifiers)
.unwrap();
if return_type.is_some() && !returned {
panic!("expected return");
}
instance_to_stmt.insert(
"".to_string(),
FunInstance {
body: statements,
unifier_id: 0,
calls,
subst: Default::default(),
},
);
}
} }
let inference_time = SystemTime::now();
println!(
"Type inference time: {}ms",
inference_time
.duration_since(setup_time)
.unwrap()
.as_millis()
);
let top_level = Arc::new(TopLevelContext { let top_level = Arc::new(TopLevelContext {
definitions: Arc::new(RwLock::new(std::mem::take( definitions: Arc::new(RwLock::new(std::mem::take(
&mut *top_level.definitions.write(), &mut *top_level.definitions.write(),
@ -146,22 +251,34 @@ fn main() {
)])), )])),
}); });
let inference_time = SystemTime::now(); let instance = {
println!( let defs = top_level.definitions.read();
"Type inference time: {}ms", let mut instance = defs[resolver.get_identifier_def("run").unwrap().0].write();
inference_time.duration_since(setup_time).unwrap().as_millis() if let TopLevelDef::Function {
); instance_to_stmt,
instance_to_symbol,
let unifier = (unifier.get_shared_unifier(), primitives); ..
} = &mut *instance
{
instance_to_symbol.insert("".to_string(), "run".to_string());
instance_to_stmt[""].clone()
} else {
unreachable!()
}
};
let signature = FunSignature {
args: vec![],
ret: primitives.int32,
vars: HashMap::new(),
};
let task = CodeGenTask { let task = CodeGenTask {
subst: Default::default(), subst: Default::default(),
symbol_name: "run".to_string(), symbol_name: "run".to_string(),
body: statements, body: instance.body,
resolver,
unifier,
calls,
signature, signature,
resolver,
unifier: top_level.unifiers.read()[instance.unifier_id].clone(),
calls: instance.calls,
}; };
let f = Arc::new(WithCall::new(Box::new(move |module| { let f = Arc::new(WithCall::new(Box::new(move |module| {
let codegen_time = SystemTime::now(); let codegen_time = SystemTime::now();
@ -172,7 +289,6 @@ fn main() {
.unwrap() .unwrap()
.as_millis() .as_millis()
); );
let builder = PassManagerBuilder::create(); let builder = PassManagerBuilder::create();
builder.set_optimization_level(OptimizationLevel::Aggressive); builder.set_optimization_level(OptimizationLevel::Aggressive);
let passes = PassManager::create(()); let passes = PassManager::create(());
@ -195,6 +311,7 @@ fn main() {
target_machine target_machine
.write_to_file(module, FileType::Object, Path::new("mandelbrot.o")) .write_to_file(module, FileType::Object, Path::new("mandelbrot.o"))
.expect("couldn't write module to file"); .expect("couldn't write module to file");
println!( println!(
"LLVM time: {}ms", "LLVM time: {}ms",
SystemTime::now() SystemTime::now()
@ -202,9 +319,12 @@ fn main() {
.unwrap() .unwrap()
.as_millis() .as_millis()
); );
println!("IR:\n{}", module.print_to_string().to_str().unwrap());
}))); })));
let threads = ["test"];
let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f); let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);
println!("object file is in mandelbrot.o") println!("object file is in mandelbrot.o");
} }