1
0
forked from M-Labs/nac3

nac3standalone, nac3core: can use top level composer to compile and run mandelbrot

This commit is contained in:
ychenfo 2021-09-19 16:19:16 +08:00
parent 1b0f3d07cc
commit 2b74895b71
5 changed files with 156 additions and 292 deletions

View File

@ -56,7 +56,7 @@ fn test_primitives() {
"};
let statements = parse_program(source).unwrap();
let composer = TopLevelComposer::new();
let composer: TopLevelComposer = Default::default();
let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context());
@ -205,7 +205,7 @@ fn test_simple_call() {
"};
let statements_2 = parse_program(source_2).unwrap();
let composer = TopLevelComposer::new();
let composer: TopLevelComposer = Default::default();
let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context());

View File

@ -1,3 +1,5 @@
use std::cell::RefCell;
use rustpython_parser::ast::fold::Fold;
use crate::typecheck::type_inferencer::{FunctionData, Inferencer};
@ -20,22 +22,22 @@ pub struct TopLevelComposer {
pub defined_function_name: HashSet<String>,
// get the class def id of a class method
pub method_class: HashMap<DefinitionId, DefinitionId>,
pub built_in_num: usize,
}
impl Default for TopLevelComposer {
fn default() -> Self {
Self::new()
Self::new(vec![]).0
}
}
impl TopLevelComposer {
/// return a composer and things to make a "primitive" symbol resolver, so that the symbol
/// resolver can later figure out primitive type definitions when passed a primitive type name
pub fn new() -> Self {
pub fn new(builtins: Vec<(String, FunSignature)>) -> (Self, HashMap<String, DefinitionId>, HashMap<String, Type>) {
let primitives = Self::make_primitives();
TopLevelComposer {
definition_ast_list: {
let mut definition_ast_list = {
let top_level_def_list = vec![
Arc::new(RwLock::new(Self::make_top_level_class_def(0, None, "int32"))),
Arc::new(RwLock::new(Self::make_top_level_class_def(1, None, "int64"))),
@ -45,10 +47,10 @@ impl TopLevelComposer {
];
let ast_list: Vec<Option<ast::Stmt<()>>> = vec![None, None, None, None, None];
izip!(top_level_def_list, ast_list).collect_vec()
},
primitives_ty: primitives.0,
unifier: primitives.1,
keyword_list: HashSet::from_iter(vec![
};
let primitives_ty = primitives.0;
let mut unifier = primitives.1;
let keyword_list: HashSet<String> = HashSet::from_iter(vec![
"Generic".into(),
"virtual".into(),
"list".into(),
@ -62,12 +64,53 @@ impl TopLevelComposer {
"self".into(),
"Kernel".into(),
"KernelImmutable".into(),
]),
defined_class_method_name: Default::default(),
defined_class_name: Default::default(),
defined_function_name: Default::default(),
method_class: Default::default(),
]);
let mut defined_class_method_name: HashSet<String> = Default::default();
let mut defined_class_name: HashSet<String> = Default::default();
let mut defined_function_name: HashSet<String> = Default::default();
let method_class: HashMap<DefinitionId, DefinitionId> = Default::default();
let mut built_in_id: HashMap<String, DefinitionId> = Default::default();
let mut built_in_ty: HashMap<String, Type> = Default::default();
for (name, sig) in builtins {
let fun_sig = unifier.add_ty(TypeEnum::TFunc(RefCell::new(sig)));
built_in_ty.insert(name.clone(), fun_sig);
built_in_id.insert(name.clone(), DefinitionId(definition_ast_list.len()));
definition_ast_list.push((
Arc::new(RwLock::new(TopLevelDef::Function {
name: name.clone(),
signature: fun_sig,
instance_to_stmt: HashMap::new(),
instance_to_symbol: [("".to_string(), name.clone())]
.iter()
.cloned()
.collect(),
var_id: Default::default(),
resolver: None,
})),
None
));
defined_class_method_name.insert(name.clone());
defined_class_name.insert(name.clone());
defined_function_name.insert(name);
}
(
TopLevelComposer {
built_in_num: definition_ast_list.len(),
definition_ast_list,
primitives_ty,
unifier,
keyword_list,
defined_class_method_name,
defined_class_name,
defined_function_name,
method_class,
},
built_in_id,
built_in_ty,
)
}
pub fn make_top_level_context(&self) -> TopLevelContext {
@ -275,7 +318,7 @@ impl TopLevelComposer {
let primitives_store = &self.primitives_ty;
// skip 5 to skip analyzing the primitives
for (class_def, class_ast) in def_list.iter().skip(5) {
for (class_def, class_ast) in def_list.iter().skip(self.built_in_num) {
// only deal with class def here
let mut class_def = class_def.write();
let (class_bases_ast, class_def_type_vars, class_resolver) = {
@ -376,7 +419,7 @@ impl TopLevelComposer {
// first, only push direct parent into the list
// skip 5 to skip analyzing the primitives
for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(5) {
for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(self.built_in_num) {
let mut class_def = class_def.write();
let (class_def_id, class_bases, class_ancestors, class_resolver, class_type_vars) = {
if let TopLevelDef::Class { ancestors, resolver, object_id, type_vars, .. } =
@ -440,7 +483,7 @@ impl TopLevelComposer {
// second, get all ancestors
let mut ancestors_store: HashMap<DefinitionId, Vec<TypeAnnotation>> = Default::default();
// skip 5 to skip analyzing the primitives
for (class_def, _) in self.definition_ast_list.iter().skip(5) {
for (class_def, _) in self.definition_ast_list.iter().skip(self.built_in_num) {
let class_def = class_def.read();
let (class_ancestors, class_id) = {
if let TopLevelDef::Class { ancestors, object_id, .. } = class_def.deref() {
@ -462,7 +505,7 @@ impl TopLevelComposer {
// insert the ancestors to the def list
// skip 5 to skip analyzing the primitives
for (class_def, _) in self.definition_ast_list.iter_mut().skip(5) {
for (class_def, _) in self.definition_ast_list.iter_mut().skip(self.built_in_num) {
let mut class_def = class_def.write();
let (class_ancestors, class_id, class_type_vars) = {
if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } =
@ -495,7 +538,7 @@ impl TopLevelComposer {
let mut type_var_to_concrete_def: HashMap<Type, TypeAnnotation> = HashMap::new();
// skip 5 to skip analyzing the primitives
for (class_def, class_ast) in def_ast_list.iter().skip(5) {
for (class_def, class_ast) in def_ast_list.iter().skip(self.built_in_num) {
if matches!(&*class_def.read(), TopLevelDef::Class { .. }) {
Self::analyze_single_class_methods_fields(
class_def.clone(),
@ -516,7 +559,7 @@ impl TopLevelComposer {
loop {
let mut finished = true;
for (class_def, _) in def_ast_list.iter().skip(5) {
for (class_def, _) in def_ast_list.iter().skip(self.built_in_num) {
let mut class_def = class_def.write();
if let TopLevelDef::Class { ancestors, .. } = class_def.deref() {
// if the length of the ancestor is equal to the current depth
@ -575,7 +618,7 @@ impl TopLevelComposer {
let primitives_store = &self.primitives_ty;
// skip 5 to skip analyzing the primitives
for (function_def, function_ast) in def_list.iter().skip(5) {
for (function_def, function_ast) in def_list.iter().skip(self.built_in_num) {
let mut function_def = function_def.write();
let function_def = function_def.deref_mut();
let function_ast = if let Some(x) = function_ast.as_ref() {
@ -1118,7 +1161,7 @@ impl TopLevelComposer {
/// step 5, analyze and call type inferecer to fill the `instance_to_stmt` of topleveldef::function
fn analyze_function_instance(&mut self) -> Result<(), String> {
for (id, (def, ast)) in self.definition_ast_list.iter().enumerate() {
for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.built_in_num) {
let mut function_def = def.write();
if let TopLevelDef::Function {
instance_to_stmt,

View File

@ -88,7 +88,7 @@ impl SymbolResolver for Resolver {
"register"
)]
fn test_simple_register(source: Vec<&str>) {
let mut composer = TopLevelComposer::new();
let mut composer: TopLevelComposer = Default::default();
for s in source {
let ast = parse_program(s).unwrap();
@ -126,7 +126,7 @@ fn test_simple_register(source: Vec<&str>) {
"function compose"
)]
fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&str>) {
let mut composer = TopLevelComposer::new();
let mut composer: TopLevelComposer = Default::default();
let internal_resolver = Arc::new(ResolverInternal {
id_to_def: Default::default(),
@ -151,7 +151,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
composer.start_analysis(true).unwrap();
for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() {
for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate() {
let def = &*def.read();
if let TopLevelDef::Function { signature, name, .. } = def {
let ty_str =
@ -770,7 +770,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
)]
fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
let print = false;
let mut composer = TopLevelComposer::new();
let mut composer: TopLevelComposer = Default::default();
let internal_resolver = make_internal_resolver_with_tvar(
vec![
@ -816,7 +816,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
}
} else {
// skip 5 to skip primitives
for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() {
for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate() {
let def = &*def.read();
if print {
@ -942,7 +942,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
)]
fn test_inference(source: Vec<&str>, res: Vec<&str>) {
let print = true;
let mut composer = TopLevelComposer::new();
let mut composer: TopLevelComposer = Default::default();
let internal_resolver = make_internal_resolver_with_tvar(
vec![
@ -989,7 +989,7 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
} else {
// skip 5 to skip primitives
let mut stringify_folder = TypeToStringFolder { unifier: &mut composer.unifier};
for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() {
for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate() {
let def = &*def.read();
if let TopLevelDef::Function { instance_to_stmt, name, .. } = def {

View File

@ -7,18 +7,34 @@ use nac3core::{
typedef::{Type, Unifier},
},
};
use std::collections::HashMap;
use parking_lot::Mutex;
use std::{collections::HashMap, sync::Arc};
#[derive(Clone)]
pub struct Resolver {
pub id_to_type: HashMap<String, Type>,
pub id_to_def: HashMap<String, DefinitionId>,
pub class_names: HashMap<String, Type>,
pub struct ResolverInternal {
pub id_to_type: Mutex<HashMap<String, Type>>,
pub id_to_def: Mutex<HashMap<String, DefinitionId>>,
pub class_names: Mutex<HashMap<String, Type>>,
}
impl ResolverInternal {
pub fn add_id_def(&self, id: String, def: DefinitionId) {
self.id_to_def.lock().insert(id, def);
}
pub fn add_id_type(&self, id: String, ty: Type) {
self.id_to_type.lock().insert(id, ty);
}
}
pub struct Resolver(pub Arc<ResolverInternal>);
impl SymbolResolver for Resolver {
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
self.id_to_type.get(str).cloned()
let ret = self.0.id_to_type.lock().get(str).cloned();
if ret.is_none() {
// println!("unknown here resolver {}", str);
}
ret
}
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> {
@ -30,6 +46,6 @@ impl SymbolResolver for Resolver {
}
fn get_identifier_def(&self, id: &str) -> Option<DefinitionId> {
self.id_to_def.get(id).cloned()
self.0.id_to_def.lock().get(id).cloned()
}
}

View File

@ -6,24 +6,20 @@ use inkwell::{
targets::*,
OptimizationLevel,
};
use nac3core::typecheck::type_inferencer::PrimitiveStore;
use parking_lot::RwLock;
use rustpython_parser::{
ast::{fold::Fold, StmtKind},
parser,
};
use std::{cell::RefCell, collections::HashMap, path::Path, sync::Arc};
use rustpython_parser::parser;
use std::{collections::HashMap, path::Path, sync::Arc};
use nac3core::{
codegen::{CodeGenTask, WithCall, WorkerRegistry},
symbol_resolver::SymbolResolver,
toplevel::{DefinitionId, FunInstance, composer::TopLevelComposer, TopLevelContext, TopLevelDef},
typecheck::{
type_inferencer::{FunctionData, Inferencer},
typedef::{FunSignature, FuncArg, TypeEnum},
},
toplevel::{composer::TopLevelComposer, TopLevelDef},
typecheck::typedef::{FunSignature, FuncArg},
};
mod basic_symbol_resolver;
use basic_symbol_resolver::*;
fn main() {
Target::initialize_all(&InitializationConfig::default());
@ -36,220 +32,43 @@ fn main() {
}
};
let start = SystemTime::now();
let composer = TopLevelComposer::new();
let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context());
unifier.top_level = Some(top_level.clone());
let output_fun = unifier.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature {
let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0;
let (mut composer, builtins_def, builtins_ty) = TopLevelComposer::new(vec![
("output".into(), FunSignature {
args: vec![FuncArg {
name: "c".into(),
ty: primitives.int32,
ty: primitive.int32,
default_value: None,
}],
ret: primitives.none,
ret: primitive.none,
vars: HashMap::new(),
})));
let output_id = top_level.definitions.read().len();
top_level
.definitions
.write()
.push(Arc::new(RwLock::new(TopLevelDef::Function {
name: "output".into(),
signature: output_fun,
instance_to_stmt: HashMap::new(),
instance_to_symbol: [("".to_string(), "output".to_string())]
.iter()
.cloned()
.collect(),
var_id: Default::default(),
resolver: None,
})));
})
]);
// dummy resolver...
let resolver = Arc::new(Box::new(basic_symbol_resolver::Resolver {
id_to_type: HashMap::new(),
id_to_def: HashMap::new(),
let internal_resolver: Arc<ResolverInternal> = ResolverInternal {
id_to_type: builtins_ty.into(),
id_to_def: builtins_def.into(),
class_names: Default::default(),
}) as Box<dyn SymbolResolver + Send + Sync>);
let mut functions = HashMap::new();
}.into();
let resolver = Arc::new(
Box::new(Resolver(internal_resolver.clone())) as Box<dyn SymbolResolver + Send + Sync>
);
for stmt in parser::parse_program(&program).unwrap().into_iter() {
if let StmtKind::FunctionDef {
name,
body,
args,
returns,
..
} = stmt.node
{
let args = args
.args
.into_iter()
.map(|arg| FuncArg {
name: arg.node.arg.to_string(),
ty: resolver
.parse_type_annotation(
&top_level.definitions.read(),
&mut unifier,
&primitives,
&arg.node
.annotation
.expect("expected type annotation in parameters"),
)
.unwrap(),
default_value: None,
})
.collect();
let ret = returns
.map(|r| {
resolver
.parse_type_annotation(
&top_level.definitions.read(),
&mut unifier,
&primitives,
&r,
)
.unwrap()
})
.unwrap_or(primitives.none);
let signature = FunSignature {
args,
ret,
vars: Default::default(),
};
let fun_ty = unifier.add_ty(TypeEnum::TFunc(RefCell::new(signature.clone())));
let id = top_level.definitions.read().len();
top_level
.definitions
.write()
.push(Arc::new(RwLock::new(TopLevelDef::Function {
name: name.clone(),
signature: fun_ty,
var_id: vec![],
instance_to_stmt: HashMap::new(),
instance_to_symbol: HashMap::new(),
resolver: None,
})));
functions.insert(name, (id, body, signature));
} else {
panic!("unsupported statement type");
let (name, def_id, ty) = composer.register_top_level(
stmt,
Some(resolver.clone()),
"__main__".into(),
).unwrap();
internal_resolver.add_id_def(name.clone(), def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(name, ty);
}
}
let setup_time = SystemTime::now();
println!(
"Setup time: {}ms",
setup_time
.duration_since(start)
.unwrap()
.as_millis()
);
composer.start_analysis(true).unwrap();
let mut id_to_def: HashMap<_, _> = functions
.iter()
.map(|(k, v)| (k.clone(), DefinitionId(v.0)))
.collect();
id_to_def.insert("output".into(), DefinitionId(output_id));
let mut id_to_type: HashMap<_, _> = functions
.iter()
.map(|(k, v)| {
(
k.clone(),
unifier.add_ty(TypeEnum::TFunc(RefCell::new(v.2.clone()))),
)
})
.collect();
id_to_type.insert("output".into(), output_fun);
let resolver = Arc::new(Box::new(basic_symbol_resolver::Resolver {
class_names: Default::default(),
id_to_type,
id_to_def,
}) as Box<dyn SymbolResolver + Send + Sync>);
for (_, (id, ast, signature)) in functions.into_iter() {
if let TopLevelDef::Function {
resolver: r,
instance_to_stmt,
..
} = &mut *top_level.definitions.read()[id].write()
{
*r = Some(resolver.clone());
let return_type = if unifier.unioned(primitives.none, signature.ret) {
None
} else {
Some(signature.ret)
};
let mut function_data = FunctionData {
resolver: resolver.clone(),
bound_variables: Vec::new(),
return_type,
};
let mut virtual_checks = Vec::new();
let mut calls = HashMap::new();
let mut identifiers = HashSet::new();
let mut variable_mapping = HashMap::new();
for arg in signature.args.iter() {
identifiers.insert(arg.name.clone());
variable_mapping.insert(arg.name.clone(), arg.ty);
}
let mut inferencer = Inferencer {
top_level: &top_level,
function_data: &mut function_data,
unifier: &mut unifier,
variable_mapping,
primitives: &primitives,
virtual_checks: &mut virtual_checks,
calls: &mut calls,
defined_identifiers: identifiers.clone(),
};
let statements = ast
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
let returned = inferencer
.check_block(&statements, &mut identifiers)
.unwrap();
if return_type.is_some() && !returned {
panic!("expected return");
}
instance_to_stmt.insert(
"".to_string(),
FunInstance {
body: statements,
unifier_id: 0,
calls,
subst: Default::default(),
},
);
}
}
let inference_time = SystemTime::now();
println!(
"Type inference time: {}ms",
inference_time
.duration_since(setup_time)
.unwrap()
.as_millis()
);
let top_level = Arc::new(TopLevelContext {
definitions: Arc::new(RwLock::new(std::mem::take(
&mut *top_level.definitions.write(),
))),
unifiers: Arc::new(RwLock::new(vec![(
unifier.get_shared_unifier(),
primitives,
)])),
});
let top_level = Arc::new(composer.make_top_level_context());
let instance = {
let defs = top_level.definitions.read();
@ -268,9 +87,10 @@ fn main() {
};
let signature = FunSignature {
args: vec![],
ret: primitives.int32,
ret: primitive.int32,
vars: HashMap::new(),
};
let task = CodeGenTask {
subst: Default::default(),
symbol_name: "run".to_string(),
@ -281,14 +101,6 @@ fn main() {
calls: instance.calls,
};
let f = Arc::new(WithCall::new(Box::new(move |module| {
let codegen_time = SystemTime::now();
println!(
"Code generation time: {}ms",
codegen_time
.duration_since(inference_time)
.unwrap()
.as_millis()
);
let builder = PassManagerBuilder::create();
builder.set_optimization_level(OptimizationLevel::Aggressive);
let passes = PassManager::create(());
@ -312,13 +124,6 @@ fn main() {
.write_to_file(module, FileType::Object, Path::new("mandelbrot.o"))
.expect("couldn't write module to file");
println!(
"LLVM time: {}ms",
SystemTime::now()
.duration_since(codegen_time)
.unwrap()
.as_millis()
);
println!("IR:\n{}", module.print_to_string().to_str().unwrap());
})));