diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index b9c4f662d..d54b3c51e 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -56,7 +56,7 @@ fn test_primitives() { "}; let statements = parse_program(source).unwrap(); - let composer = TopLevelComposer::new(); + let composer: TopLevelComposer = Default::default(); let mut unifier = composer.unifier.clone(); let primitives = composer.primitives_ty; let top_level = Arc::new(composer.make_top_level_context()); @@ -205,7 +205,7 @@ fn test_simple_call() { "}; let statements_2 = parse_program(source_2).unwrap(); - let composer = TopLevelComposer::new(); + let composer: TopLevelComposer = Default::default(); let mut unifier = composer.unifier.clone(); let primitives = composer.primitives_ty; let top_level = Arc::new(composer.make_top_level_context()); diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index b18449d14..15523d1d9 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1,3 +1,5 @@ +use std::cell::RefCell; + use rustpython_parser::ast::fold::Fold; use crate::typecheck::type_inferencer::{FunctionData, Inferencer}; @@ -20,54 +22,95 @@ pub struct TopLevelComposer { pub defined_function_name: HashSet, // get the class def id of a class method pub method_class: HashMap, + pub built_in_num: usize, } impl Default for TopLevelComposer { fn default() -> Self { - Self::new() + Self::new(vec![]).0 } } impl TopLevelComposer { /// return a composer and things to make a "primitive" symbol resolver, so that the symbol /// resolver can later figure out primitive type definitions when passed a primitive type name - pub fn new() -> Self { + pub fn new(builtins: Vec<(String, FunSignature)>) -> (Self, HashMap, HashMap) { let primitives = Self::make_primitives(); - TopLevelComposer { - definition_ast_list: { - let top_level_def_list = vec![ - Arc::new(RwLock::new(Self::make_top_level_class_def(0, None, "int32"))), - Arc::new(RwLock::new(Self::make_top_level_class_def(1, None, "int64"))), - Arc::new(RwLock::new(Self::make_top_level_class_def(2, None, "float"))), - Arc::new(RwLock::new(Self::make_top_level_class_def(3, None, "bool"))), - Arc::new(RwLock::new(Self::make_top_level_class_def(4, None, "none"))), - ]; - let ast_list: Vec>> = vec![None, None, None, None, None]; - izip!(top_level_def_list, ast_list).collect_vec() - }, - primitives_ty: primitives.0, - unifier: primitives.1, - keyword_list: HashSet::from_iter(vec![ - "Generic".into(), - "virtual".into(), - "list".into(), - "tuple".into(), - "int32".into(), - "int64".into(), - "float".into(), - "bool".into(), - "none".into(), - "None".into(), - "self".into(), - "Kernel".into(), - "KernelImmutable".into(), - ]), - defined_class_method_name: Default::default(), - defined_class_name: Default::default(), - defined_function_name: Default::default(), - method_class: Default::default(), + let mut definition_ast_list = { + let top_level_def_list = vec![ + Arc::new(RwLock::new(Self::make_top_level_class_def(0, None, "int32"))), + Arc::new(RwLock::new(Self::make_top_level_class_def(1, None, "int64"))), + Arc::new(RwLock::new(Self::make_top_level_class_def(2, None, "float"))), + Arc::new(RwLock::new(Self::make_top_level_class_def(3, None, "bool"))), + Arc::new(RwLock::new(Self::make_top_level_class_def(4, None, "none"))), + ]; + let ast_list: Vec>> = vec![None, None, None, None, None]; + izip!(top_level_def_list, ast_list).collect_vec() + }; + let primitives_ty = primitives.0; + let mut unifier = primitives.1; + let keyword_list: HashSet = HashSet::from_iter(vec![ + "Generic".into(), + "virtual".into(), + "list".into(), + "tuple".into(), + "int32".into(), + "int64".into(), + "float".into(), + "bool".into(), + "none".into(), + "None".into(), + "self".into(), + "Kernel".into(), + "KernelImmutable".into(), + ]); + let mut defined_class_method_name: HashSet = Default::default(); + let mut defined_class_name: HashSet = Default::default(); + let mut defined_function_name: HashSet = Default::default(); + let method_class: HashMap = Default::default(); + + let mut built_in_id: HashMap = Default::default(); + let mut built_in_ty: HashMap = Default::default(); + + for (name, sig) in builtins { + let fun_sig = unifier.add_ty(TypeEnum::TFunc(RefCell::new(sig))); + built_in_ty.insert(name.clone(), fun_sig); + built_in_id.insert(name.clone(), DefinitionId(definition_ast_list.len())); + definition_ast_list.push(( + Arc::new(RwLock::new(TopLevelDef::Function { + name: name.clone(), + signature: fun_sig, + instance_to_stmt: HashMap::new(), + instance_to_symbol: [("".to_string(), name.clone())] + .iter() + .cloned() + .collect(), + var_id: Default::default(), + resolver: None, + })), + None + )); + defined_class_method_name.insert(name.clone()); + defined_class_name.insert(name.clone()); + defined_function_name.insert(name); } + + ( + TopLevelComposer { + built_in_num: definition_ast_list.len(), + definition_ast_list, + primitives_ty, + unifier, + keyword_list, + defined_class_method_name, + defined_class_name, + defined_function_name, + method_class, + }, + built_in_id, + built_in_ty, + ) } pub fn make_top_level_context(&self) -> TopLevelContext { @@ -275,7 +318,7 @@ impl TopLevelComposer { let primitives_store = &self.primitives_ty; // skip 5 to skip analyzing the primitives - for (class_def, class_ast) in def_list.iter().skip(5) { + for (class_def, class_ast) in def_list.iter().skip(self.built_in_num) { // only deal with class def here let mut class_def = class_def.write(); let (class_bases_ast, class_def_type_vars, class_resolver) = { @@ -376,7 +419,7 @@ impl TopLevelComposer { // first, only push direct parent into the list // skip 5 to skip analyzing the primitives - for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(5) { + for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(self.built_in_num) { let mut class_def = class_def.write(); let (class_def_id, class_bases, class_ancestors, class_resolver, class_type_vars) = { if let TopLevelDef::Class { ancestors, resolver, object_id, type_vars, .. } = @@ -440,7 +483,7 @@ impl TopLevelComposer { // second, get all ancestors let mut ancestors_store: HashMap> = Default::default(); // skip 5 to skip analyzing the primitives - for (class_def, _) in self.definition_ast_list.iter().skip(5) { + for (class_def, _) in self.definition_ast_list.iter().skip(self.built_in_num) { let class_def = class_def.read(); let (class_ancestors, class_id) = { if let TopLevelDef::Class { ancestors, object_id, .. } = class_def.deref() { @@ -462,7 +505,7 @@ impl TopLevelComposer { // insert the ancestors to the def list // skip 5 to skip analyzing the primitives - for (class_def, _) in self.definition_ast_list.iter_mut().skip(5) { + for (class_def, _) in self.definition_ast_list.iter_mut().skip(self.built_in_num) { let mut class_def = class_def.write(); let (class_ancestors, class_id, class_type_vars) = { if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } = @@ -495,7 +538,7 @@ impl TopLevelComposer { let mut type_var_to_concrete_def: HashMap = HashMap::new(); // skip 5 to skip analyzing the primitives - for (class_def, class_ast) in def_ast_list.iter().skip(5) { + for (class_def, class_ast) in def_ast_list.iter().skip(self.built_in_num) { if matches!(&*class_def.read(), TopLevelDef::Class { .. }) { Self::analyze_single_class_methods_fields( class_def.clone(), @@ -516,7 +559,7 @@ impl TopLevelComposer { loop { let mut finished = true; - for (class_def, _) in def_ast_list.iter().skip(5) { + for (class_def, _) in def_ast_list.iter().skip(self.built_in_num) { let mut class_def = class_def.write(); if let TopLevelDef::Class { ancestors, .. } = class_def.deref() { // if the length of the ancestor is equal to the current depth @@ -575,7 +618,7 @@ impl TopLevelComposer { let primitives_store = &self.primitives_ty; // skip 5 to skip analyzing the primitives - for (function_def, function_ast) in def_list.iter().skip(5) { + for (function_def, function_ast) in def_list.iter().skip(self.built_in_num) { let mut function_def = function_def.write(); let function_def = function_def.deref_mut(); let function_ast = if let Some(x) = function_ast.as_ref() { @@ -1118,7 +1161,7 @@ impl TopLevelComposer { /// step 5, analyze and call type inferecer to fill the `instance_to_stmt` of topleveldef::function fn analyze_function_instance(&mut self) -> Result<(), String> { - for (id, (def, ast)) in self.definition_ast_list.iter().enumerate() { + for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.built_in_num) { let mut function_def = def.write(); if let TopLevelDef::Function { instance_to_stmt, diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 6c749b157..0578bd855 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -88,7 +88,7 @@ impl SymbolResolver for Resolver { "register" )] fn test_simple_register(source: Vec<&str>) { - let mut composer = TopLevelComposer::new(); + let mut composer: TopLevelComposer = Default::default(); for s in source { let ast = parse_program(s).unwrap(); @@ -126,7 +126,7 @@ fn test_simple_register(source: Vec<&str>) { "function compose" )] fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&str>) { - let mut composer = TopLevelComposer::new(); + let mut composer: TopLevelComposer = Default::default(); let internal_resolver = Arc::new(ResolverInternal { id_to_def: Default::default(), @@ -151,7 +151,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s composer.start_analysis(true).unwrap(); - for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() { + for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate() { let def = &*def.read(); if let TopLevelDef::Function { signature, name, .. } = def { let ty_str = @@ -770,7 +770,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s )] fn test_analyze(source: Vec<&str>, res: Vec<&str>) { let print = false; - let mut composer = TopLevelComposer::new(); + let mut composer: TopLevelComposer = Default::default(); let internal_resolver = make_internal_resolver_with_tvar( vec![ @@ -816,7 +816,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { } } else { // skip 5 to skip primitives - for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() { + for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate() { let def = &*def.read(); if print { @@ -942,7 +942,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { )] fn test_inference(source: Vec<&str>, res: Vec<&str>) { let print = true; - let mut composer = TopLevelComposer::new(); + let mut composer: TopLevelComposer = Default::default(); let internal_resolver = make_internal_resolver_with_tvar( vec![ @@ -989,7 +989,7 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) { } else { // skip 5 to skip primitives let mut stringify_folder = TypeToStringFolder { unifier: &mut composer.unifier}; - for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() { + for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate() { let def = &*def.read(); if let TopLevelDef::Function { instance_to_stmt, name, .. } = def { diff --git a/nac3standalone/src/basic_symbol_resolver.rs b/nac3standalone/src/basic_symbol_resolver.rs index 8b0a16d2f..f8fe760ea 100644 --- a/nac3standalone/src/basic_symbol_resolver.rs +++ b/nac3standalone/src/basic_symbol_resolver.rs @@ -7,18 +7,34 @@ use nac3core::{ typedef::{Type, Unifier}, }, }; -use std::collections::HashMap; +use parking_lot::Mutex; +use std::{collections::HashMap, sync::Arc}; -#[derive(Clone)] -pub struct Resolver { - pub id_to_type: HashMap, - pub id_to_def: HashMap, - pub class_names: HashMap, +pub struct ResolverInternal { + pub id_to_type: Mutex>, + pub id_to_def: Mutex>, + pub class_names: Mutex>, } +impl ResolverInternal { + pub fn add_id_def(&self, id: String, def: DefinitionId) { + self.id_to_def.lock().insert(id, def); + } + + pub fn add_id_type(&self, id: String, ty: Type) { + self.id_to_type.lock().insert(id, ty); + } +} + +pub struct Resolver(pub Arc); + impl SymbolResolver for Resolver { fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option { - self.id_to_type.get(str).cloned() + let ret = self.0.id_to_type.lock().get(str).cloned(); + if ret.is_none() { + // println!("unknown here resolver {}", str); + } + ret } fn get_symbol_value(&self, _: &str) -> Option { @@ -30,6 +46,6 @@ impl SymbolResolver for Resolver { } fn get_identifier_def(&self, id: &str) -> Option { - self.id_to_def.get(id).cloned() + self.0.id_to_def.lock().get(id).cloned() } } diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index f807dabd9..07b994157 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -6,24 +6,20 @@ use inkwell::{ targets::*, OptimizationLevel, }; +use nac3core::typecheck::type_inferencer::PrimitiveStore; use parking_lot::RwLock; -use rustpython_parser::{ - ast::{fold::Fold, StmtKind}, - parser, -}; -use std::{cell::RefCell, collections::HashMap, path::Path, sync::Arc}; +use rustpython_parser::parser; +use std::{collections::HashMap, path::Path, sync::Arc}; use nac3core::{ codegen::{CodeGenTask, WithCall, WorkerRegistry}, symbol_resolver::SymbolResolver, - toplevel::{DefinitionId, FunInstance, composer::TopLevelComposer, TopLevelContext, TopLevelDef}, - typecheck::{ - type_inferencer::{FunctionData, Inferencer}, - typedef::{FunSignature, FuncArg, TypeEnum}, - }, + toplevel::{composer::TopLevelComposer, TopLevelDef}, + typecheck::typedef::{FunSignature, FuncArg}, }; mod basic_symbol_resolver; +use basic_symbol_resolver::*; fn main() { Target::initialize_all(&InitializationConfig::default()); @@ -36,220 +32,43 @@ fn main() { } }; - let start = SystemTime::now(); + let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0; + let (mut composer, builtins_def, builtins_ty) = TopLevelComposer::new(vec![ + ("output".into(), FunSignature { + args: vec![FuncArg { + name: "c".into(), + ty: primitive.int32, + default_value: None, + }], + ret: primitive.none, + vars: HashMap::new(), + }) + ]); - let composer = TopLevelComposer::new(); - let mut unifier = composer.unifier.clone(); - let primitives = composer.primitives_ty; - let top_level = Arc::new(composer.make_top_level_context()); - unifier.top_level = Some(top_level.clone()); - let output_fun = unifier.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { - args: vec![FuncArg { - name: "c".into(), - ty: primitives.int32, - default_value: None, - }], - ret: primitives.none, - vars: HashMap::new(), - }))); - let output_id = top_level.definitions.read().len(); - top_level - .definitions - .write() - .push(Arc::new(RwLock::new(TopLevelDef::Function { - name: "output".into(), - signature: output_fun, - instance_to_stmt: HashMap::new(), - instance_to_symbol: [("".to_string(), "output".to_string())] - .iter() - .cloned() - .collect(), - var_id: Default::default(), - resolver: None, - }))); - - // dummy resolver... - let resolver = Arc::new(Box::new(basic_symbol_resolver::Resolver { - id_to_type: HashMap::new(), - id_to_def: HashMap::new(), + let internal_resolver: Arc = ResolverInternal { + id_to_type: builtins_ty.into(), + id_to_def: builtins_def.into(), class_names: Default::default(), - }) as Box); - let mut functions = HashMap::new(); + }.into(); + let resolver = Arc::new( + Box::new(Resolver(internal_resolver.clone())) as Box + ); for stmt in parser::parse_program(&program).unwrap().into_iter() { - if let StmtKind::FunctionDef { - name, - body, - args, - returns, - .. - } = stmt.node - { - let args = args - .args - .into_iter() - .map(|arg| FuncArg { - name: arg.node.arg.to_string(), - ty: resolver - .parse_type_annotation( - &top_level.definitions.read(), - &mut unifier, - &primitives, - &arg.node - .annotation - .expect("expected type annotation in parameters"), - ) - .unwrap(), - default_value: None, - }) - .collect(); - let ret = returns - .map(|r| { - resolver - .parse_type_annotation( - &top_level.definitions.read(), - &mut unifier, - &primitives, - &r, - ) - .unwrap() - }) - .unwrap_or(primitives.none); - let signature = FunSignature { - args, - ret, - vars: Default::default(), - }; - let fun_ty = unifier.add_ty(TypeEnum::TFunc(RefCell::new(signature.clone()))); - let id = top_level.definitions.read().len(); - top_level - .definitions - .write() - .push(Arc::new(RwLock::new(TopLevelDef::Function { - name: name.clone(), - signature: fun_ty, - var_id: vec![], - instance_to_stmt: HashMap::new(), - instance_to_symbol: HashMap::new(), - resolver: None, - }))); - functions.insert(name, (id, body, signature)); - } else { - panic!("unsupported statement type"); + let (name, def_id, ty) = composer.register_top_level( + stmt, + Some(resolver.clone()), + "__main__".into(), + ).unwrap(); + internal_resolver.add_id_def(name.clone(), def_id); + if let Some(ty) = ty { + internal_resolver.add_id_type(name, ty); } } - let setup_time = SystemTime::now(); - println!( - "Setup time: {}ms", - setup_time - .duration_since(start) - .unwrap() - .as_millis() - ); + composer.start_analysis(true).unwrap(); - let mut id_to_def: HashMap<_, _> = functions - .iter() - .map(|(k, v)| (k.clone(), DefinitionId(v.0))) - .collect(); - id_to_def.insert("output".into(), DefinitionId(output_id)); - let mut id_to_type: HashMap<_, _> = functions - .iter() - .map(|(k, v)| { - ( - k.clone(), - unifier.add_ty(TypeEnum::TFunc(RefCell::new(v.2.clone()))), - ) - }) - .collect(); - id_to_type.insert("output".into(), output_fun); - - let resolver = Arc::new(Box::new(basic_symbol_resolver::Resolver { - class_names: Default::default(), - id_to_type, - id_to_def, - }) as Box); - - for (_, (id, ast, signature)) in functions.into_iter() { - if let TopLevelDef::Function { - resolver: r, - instance_to_stmt, - .. - } = &mut *top_level.definitions.read()[id].write() - { - *r = Some(resolver.clone()); - - let return_type = if unifier.unioned(primitives.none, signature.ret) { - None - } else { - Some(signature.ret) - }; - let mut function_data = FunctionData { - resolver: resolver.clone(), - bound_variables: Vec::new(), - return_type, - }; - let mut virtual_checks = Vec::new(); - let mut calls = HashMap::new(); - let mut identifiers = HashSet::new(); - let mut variable_mapping = HashMap::new(); - for arg in signature.args.iter() { - identifiers.insert(arg.name.clone()); - variable_mapping.insert(arg.name.clone(), arg.ty); - } - let mut inferencer = Inferencer { - top_level: &top_level, - function_data: &mut function_data, - unifier: &mut unifier, - variable_mapping, - primitives: &primitives, - virtual_checks: &mut virtual_checks, - calls: &mut calls, - defined_identifiers: identifiers.clone(), - }; - let statements = ast - .into_iter() - .map(|v| inferencer.fold_stmt(v)) - .collect::, _>>() - .unwrap(); - - let returned = inferencer - .check_block(&statements, &mut identifiers) - .unwrap(); - if return_type.is_some() && !returned { - panic!("expected return"); - } - - instance_to_stmt.insert( - "".to_string(), - FunInstance { - body: statements, - unifier_id: 0, - calls, - subst: Default::default(), - }, - ); - } - } - - let inference_time = SystemTime::now(); - println!( - "Type inference time: {}ms", - inference_time - .duration_since(setup_time) - .unwrap() - .as_millis() - ); - - let top_level = Arc::new(TopLevelContext { - definitions: Arc::new(RwLock::new(std::mem::take( - &mut *top_level.definitions.write(), - ))), - unifiers: Arc::new(RwLock::new(vec![( - unifier.get_shared_unifier(), - primitives, - )])), - }); + let top_level = Arc::new(composer.make_top_level_context()); let instance = { let defs = top_level.definitions.read(); @@ -268,9 +87,10 @@ fn main() { }; let signature = FunSignature { args: vec![], - ret: primitives.int32, + ret: primitive.int32, vars: HashMap::new(), }; + let task = CodeGenTask { subst: Default::default(), symbol_name: "run".to_string(), @@ -281,14 +101,6 @@ fn main() { calls: instance.calls, }; let f = Arc::new(WithCall::new(Box::new(move |module| { - let codegen_time = SystemTime::now(); - println!( - "Code generation time: {}ms", - codegen_time - .duration_since(inference_time) - .unwrap() - .as_millis() - ); let builder = PassManagerBuilder::create(); builder.set_optimization_level(OptimizationLevel::Aggressive); let passes = PassManager::create(()); @@ -312,13 +124,6 @@ fn main() { .write_to_file(module, FileType::Object, Path::new("mandelbrot.o")) .expect("couldn't write module to file"); - println!( - "LLVM time: {}ms", - SystemTime::now() - .duration_since(codegen_time) - .unwrap() - .as_millis() - ); println!("IR:\n{}", module.print_to_string().to_str().unwrap()); })));