From 54b4572c5fa04119c4838bffc3f8726c4e8bcc7c Mon Sep 17 00:00:00 2001 From: ychenfo Date: Mon, 6 Sep 2021 19:23:04 +0800 Subject: [PATCH] nac3core: allow interior mutability to dyn symbolresolver, add add_id_def to symbolresolver trait, remove primitive from top level def list --- nac3core/src/codegen/expr.rs | 2 +- nac3core/src/codegen/mod.rs | 4 +- nac3core/src/codegen/test.rs | 18 +++-- nac3core/src/symbol_resolver.rs | 1 + nac3core/src/toplevel/helper.rs | 4 +- nac3core/src/toplevel/mod.rs | 54 ++++++-------- nac3core/src/toplevel/test.rs | 70 ++++++++++++++++--- nac3core/src/toplevel/type_annotation.rs | 7 +- nac3core/src/typecheck/function_check.rs | 2 +- nac3core/src/typecheck/type_inferencer/mod.rs | 22 +++--- .../src/typecheck/type_inferencer/test.rs | 12 ++-- nac3standalone/src/basic_symbol_resolver.rs | 4 ++ nac3standalone/src/main.rs | 8 +-- 13 files changed, 132 insertions(+), 76 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 013785e5..cf2590af 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -609,7 +609,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { ExprKind::Call { func, args, keywords } => { if let ExprKind::Name { id, .. } = &func.as_ref().node { // TODO: handle primitive casts and function pointers - let fun = self.resolver.get_identifier_def(&id).expect("Unknown identifier"); + let fun = self.resolver.lock().get_identifier_def(&id).expect("Unknown identifier"); let mut params = args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec(); let kw_iter = keywords.iter().map(|kw| { diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index ccedbe5d..79c96997 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -38,7 +38,7 @@ pub struct CodeGenContext<'ctx, 'a> { pub module: Module<'ctx>, pub top_level: &'a TopLevelContext, pub unifier: Unifier, - pub resolver: Arc>, + pub resolver: Arc>>, pub var_assignment: HashMap>, pub type_cache: HashMap>, pub primitives: PrimitiveStore, @@ -190,7 +190,7 @@ pub struct CodeGenTask { pub body: Vec>>, pub calls: HashMap, pub unifier: (SharedUnifier, PrimitiveStore), - pub resolver: Arc>, + pub resolver: Arc>>, } fn get_llvm_type<'ctx>( diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 6f7f7b1a..dd61ae11 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -9,7 +9,7 @@ use crate::{ }, }; use indoc::indoc; -use parking_lot::RwLock; +use parking_lot::{Mutex, RwLock}; use rustpython_parser::{ast::fold::Fold, parser::parse_program}; use std::cell::RefCell; use std::collections::{HashMap, HashSet}; @@ -43,6 +43,10 @@ impl SymbolResolver for Resolver { fn get_identifier_def(&self, id: &str) -> Option { self.id_to_def.read().get(id).cloned() } + + fn add_id_def(&mut self, _: String, _: DefinitionId) { + unimplemented!() + } } #[test] @@ -60,11 +64,17 @@ fn test_primitives() { let top_level = Arc::new(composer.make_top_level_context()); unifier.top_level = Some(top_level.clone()); - let resolver = Arc::new(Box::new(Resolver { + // let resolver = Arc::new(Mutex::new(Resolver { + // id_to_type: HashMap::new(), + // id_to_def: RwLock::new(HashMap::new()), + // class_names: Default::default(), + // }) as Mutex); + + let resolver = Arc::new(Mutex::new(Box::new(Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()), class_names: Default::default(), - }) as Box); + }) as Box)); let threads = ["test"]; let signature = FunSignature { @@ -226,7 +236,7 @@ fn test_simple_call() { class_names: Default::default(), }); resolver.add_id_def("foo".to_string(), DefinitionId(foo_id)); - let resolver = Arc::new(resolver as Box); + let resolver = Arc::new(Mutex::new(resolver as Box)); if let TopLevelDef::Function { resolver: r, .. } = &mut *top_level.definitions.read()[foo_id].write() diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 2efb707b..ecf281e1 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -34,6 +34,7 @@ pub trait SymbolResolver { fn get_identifier_def(&self, str: &str) -> Option; fn get_symbol_value(&self, str: &str) -> Option; fn get_symbol_location(&self, str: &str) -> Option; + fn add_id_def(&mut self, id: String, def_id: DefinitionId); // handle function call etc. } diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index cc2c3277..012f2041 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -37,7 +37,7 @@ impl TopLevelComposer { /// when first regitering, the type_vars, fields, methods, ancestors are invalid pub fn make_top_level_class_def( index: usize, - resolver: Option>>, + resolver: Option>>>, name: &str, ) -> TopLevelDef { TopLevelDef::Class { @@ -55,7 +55,7 @@ impl TopLevelComposer { pub fn make_top_level_function_def( name: String, ty: Type, - resolver: Option>>, + resolver: Option>>>, ) -> TopLevelDef { TopLevelDef::Function { name, diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 86f9f420..cf24c926 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -1,10 +1,4 @@ -use std::{ - borrow::BorrowMut, - collections::{HashMap, HashSet}, - iter::FromIterator, - ops::{Deref, DerefMut}, - sync::Arc, -}; +use std::{borrow::{Borrow, BorrowMut}, collections::{HashMap, HashSet}, iter::FromIterator, ops::{Deref, DerefMut}, sync::Arc}; use super::typecheck::type_inferencer::PrimitiveStore; use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier}; @@ -13,7 +7,7 @@ use crate::{ typecheck::{type_inferencer::CodeLocation, typedef::CallId}, }; use itertools::{izip, Itertools}; -use parking_lot::RwLock; +use parking_lot::{Mutex, RwLock}; use rustpython_parser::ast::{self, Stmt}; #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] @@ -48,7 +42,7 @@ pub enum TopLevelDef { // ancestor classes, including itself. ancestors: Vec, // symbol resolver of the module defined the class, none if it is built-in type - resolver: Option>>, + resolver: Option>>>, }, Function { // prefix for symbol, should be unique globally, and not ending with numbers @@ -69,7 +63,7 @@ pub enum TopLevelDef { /// rigid type variables that would be substituted when the function is instantiated. instance_to_stmt: HashMap, // symbol resolver of the module defined the class - resolver: Option>>, + resolver: Option>>>, }, Initializer { class_id: DefinitionId, @@ -108,18 +102,8 @@ impl TopLevelComposer { pub fn new() -> Self { let primitives = Self::make_primitives(); - let top_level_def_list = vec![ - Arc::new(RwLock::new(Self::make_top_level_class_def(0, None, "int32"))), - Arc::new(RwLock::new(Self::make_top_level_class_def(1, None, "int64"))), - Arc::new(RwLock::new(Self::make_top_level_class_def(2, None, "float"))), - Arc::new(RwLock::new(Self::make_top_level_class_def(3, None, "bool"))), - Arc::new(RwLock::new(Self::make_top_level_class_def(4, None, "none"))), - ]; - - let ast_list: Vec>> = vec![None, None, None, None, None]; - TopLevelComposer { - definition_ast_list: izip!(top_level_def_list, ast_list).collect_vec(), + definition_ast_list: Default::default(), primitives_ty: primitives.0, unifier: primitives.1, // class_method_to_def_id: Default::default(), @@ -162,7 +146,7 @@ impl TopLevelComposer { pub fn register_top_level( &mut self, ast: ast::Stmt<()>, - resolver: Option>>, + resolver: Option>>>, ) -> Result<(String, DefinitionId), String> { let defined_class_name = &mut self.defined_class_name; let defined_class_method_name = &mut self.defined_class_method_name; @@ -373,7 +357,7 @@ impl TopLevelComposer { let type_vars = type_var_list .into_iter() .map(|e| { - class_resolver.parse_type_annotation( + class_resolver.lock().parse_type_annotation( &temp_def_list, unifier, primitives_store, @@ -530,15 +514,17 @@ impl TopLevelComposer { let mut type_var_to_concrete_def: HashMap = HashMap::new(); for (class_def, class_ast) in def_ast_list { - Self::analyze_single_class_methods_fields( - class_def.clone(), - &class_ast.as_ref().unwrap().node, - &temp_def_list, - unifier, - primitives, - &mut type_var_to_concrete_def, - &self.keyword_list, - )? + if matches!(&*class_def.read(), TopLevelDef::Class { .. }) { + Self::analyze_single_class_methods_fields( + class_def.clone(), + &class_ast.as_ref().unwrap().node, + &temp_def_list, + unifier, + primitives, + &mut type_var_to_concrete_def, + &self.keyword_list, + )? + } } // handle the inheritanced methods and fields @@ -615,7 +601,7 @@ impl TopLevelComposer { let mut defined_paramter_name: HashSet = HashSet::new(); let have_unique_fuction_parameter_name = args.args.iter().all(|x| { defined_paramter_name.insert(x.node.arg.clone()) - && keyword_list.contains(&x.node.arg) + && !keyword_list.contains(&x.node.arg) && "self" != x.node.arg }); if !have_unique_fuction_parameter_name { @@ -643,7 +629,7 @@ impl TopLevelComposer { primitives_store, annotation, )?; - + let type_vars_within = get_type_var_contained_in_type_annotation(&type_annotation) .into_iter() diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 3dd260ce..36aabe05 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -8,12 +8,9 @@ use crate::{ }, }; use indoc::indoc; -use parking_lot::RwLock; +use parking_lot::{Mutex, RwLock}; use rustpython_parser::{ast::fold::Fold, parser::parse_program}; -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, -}; +use std::{borrow::BorrowMut, collections::{HashMap, HashSet}, sync::Arc}; use test_case::test_case; use super::TopLevelComposer; @@ -24,12 +21,6 @@ struct Resolver { class_names: HashMap, } -impl Resolver { - pub fn add_id_def(&mut self, id: String, def: DefinitionId) { - self.id_to_def.insert(id, def); - } -} - impl SymbolResolver for Resolver { fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option { self.id_to_type.get(str).cloned() @@ -46,8 +37,14 @@ impl SymbolResolver for Resolver { fn get_identifier_def(&self, id: &str) -> Option { self.id_to_def.get(id).cloned() } + + fn add_id_def(&mut self, id: String, def: DefinitionId) { + self.id_to_def.insert(id, def); + } } + + #[test_case( vec![ indoc! {" @@ -89,3 +86,54 @@ fn test_simple_register(source: Vec<&str>) { composer.register_top_level(ast, None).unwrap(); } } + +#[test_case( + vec![ + indoc! {" + def fun(a: int) -> int: + return a + "}, + // indoc! {" + // class A: + // def __init__(self): + // self.a: int = 3 + // "}, + // indoc! {" + // class B: + // def __init__(self): + // self.b: float = 4.3 + + // def fun(self): + // self.b = self.b + 3.0 + // "}, + // indoc! {" + // def foo(a: float): + // a + 1.0 + // "}, + // indoc! {" + // class C(B): + // def __init__(self): + // self.c: int = 4 + // self.a: bool = True + // "} + ] +)] +fn test_simple_analyze(source: Vec<&str>) { + let mut composer = TopLevelComposer::new(); + + let resolver = Arc::new(Mutex::new(Box::new(Resolver { + id_to_def: Default::default(), + id_to_type: Default::default(), + class_names: Default::default(), + }) as Box)); + + for s in source { + let ast = parse_program(s).unwrap(); + let ast = ast[0].clone(); + + let (id, def_id) = composer.register_top_level(ast, Some(resolver.clone())).unwrap(); + resolver.lock().add_id_def(id, def_id); + } + + composer.start_analysis().unwrap(); +} diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 960e982d..a38d99d4 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -19,7 +19,7 @@ pub enum TypeAnnotation { } pub fn parse_ast_to_type_annotation_kinds( - resolver: &Box, + resolver: &Mutex>, top_level_defs: &[Arc>], unifier: &mut Unifier, primitives: &PrimitiveStore, @@ -33,7 +33,7 @@ pub fn parse_ast_to_type_annotation_kinds( "bool" => Ok(TypeAnnotation::PrimitiveKind(primitives.bool)), "None" => Ok(TypeAnnotation::PrimitiveKind(primitives.none)), x => { - if let Some(obj_id) = resolver.get_identifier_def(x) { + if let Some(obj_id) = resolver.lock().get_identifier_def(x) { let def = top_level_defs[obj_id.0].read(); if let TopLevelDef::Class { type_vars, .. } = &*def { // also check param number here @@ -47,7 +47,7 @@ pub fn parse_ast_to_type_annotation_kinds( } else { Err("function cannot be used as a type".into()) } - } else if let Some(ty) = resolver.get_symbol_type(unifier, primitives, id) { + } else if let Some(ty) = resolver.lock().get_symbol_type(unifier, primitives, id) { if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() { Ok(TypeAnnotation::TypeVarKind(ty)) } else { @@ -120,6 +120,7 @@ pub fn parse_ast_to_type_annotation_kinds( return Err("keywords cannot be class name".into()); } let obj_id = resolver + .lock() .get_identifier_def(id) .ok_or_else(|| "unknown class name".to_string())?; let def = top_level_defs[obj_id.0].read(); diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index 03306615..4988a3dc 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -57,7 +57,7 @@ impl<'a> Inferencer<'a> { match &expr.node { ExprKind::Name { id, .. } => { if !defined_identifiers.contains(id) { - if self.function_data.resolver.get_identifier_def(id).is_some() { + if self.function_data.resolver.lock().get_identifier_def(id).is_some() { defined_identifiers.insert(id.clone()); } else { return Err(format!( diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index c5c3134e..8ff02004 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -7,6 +7,7 @@ use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier}; use super::{magic_methods::*, typedef::CallId}; use crate::{symbol_resolver::SymbolResolver, toplevel::TopLevelContext}; use itertools::izip; +use parking_lot::Mutex; use rustpython_parser::ast::{ self, fold::{self, Fold}, @@ -38,7 +39,7 @@ pub struct PrimitiveStore { } pub struct FunctionData { - pub resolver: Arc>, + pub resolver: Arc>>, pub return_type: Option, pub bound_variables: Vec, } @@ -85,7 +86,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { return Err(format!("declaration without definition is not yet supported, at {}", node.location)) }; let top_level_defs = self.top_level.definitions.read(); - let annotation_type = self.function_data.resolver.parse_type_annotation( + let annotation_type = self.function_data.resolver.lock().parse_type_annotation( top_level_defs.as_slice(), self.unifier, &self.primitives, @@ -160,7 +161,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?), ast::ExprKind::Name { id, .. } => { if !self.defined_identifiers.contains(id) { - if self.function_data.resolver.get_identifier_def(id.as_str()).is_some() { + if self.function_data.resolver.lock().get_identifier_def(id.as_str()).is_some() { self.defined_identifiers.insert(id.clone()); } else { return Err(format!( @@ -400,7 +401,7 @@ impl<'a> Inferencer<'a> { let arg0 = self.fold_expr(args.remove(0))?; let ty = if let Some(arg) = args.pop() { let top_level_defs = self.top_level.definitions.read(); - self.function_data.resolver.parse_type_annotation( + self.function_data.resolver.lock().parse_type_annotation( top_level_defs.as_slice(), self.unifier, self.primitives, @@ -478,13 +479,14 @@ impl<'a> Inferencer<'a> { if let Some(ty) = self.variable_mapping.get(id) { Ok(*ty) } else { - Ok(self - .function_data - .resolver - .get_symbol_type(self.unifier, self.primitives, id) + let resolver = self.function_data.resolver.lock(); + let variable_mapping = &mut self.variable_mapping; + let unifier = &mut self.unifier; + Ok(resolver + .get_symbol_type(unifier, self.primitives, id) .unwrap_or_else(|| { - let ty = self.unifier.get_fresh_var().0; - self.variable_mapping.insert(id.to_string(), ty); + let ty = unifier.get_fresh_var().0; + variable_mapping.insert(id.to_string(), ty); ty })) } diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index c3be237b..f8946c81 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -33,6 +33,10 @@ impl SymbolResolver for Resolver { fn get_identifier_def(&self, id: &str) -> Option { self.id_to_def.get(id).cloned() } + + fn add_id_def(&mut self, _: String, _: DefinitionId) { + unimplemented!() + } } struct TestEnvironment { @@ -92,11 +96,11 @@ impl TestEnvironment { let mut identifier_mapping = HashMap::new(); identifier_mapping.insert("None".into(), none); - let resolver = Arc::new(Box::new(Resolver { + let resolver = Arc::new(Mutex::new(Box::new(Resolver { id_to_type: identifier_mapping.clone(), id_to_def: Default::default(), class_names: Default::default(), - }) as Box); + }) as Box)); TestEnvironment { top_level: TopLevelContext { @@ -275,7 +279,7 @@ impl TestEnvironment { unifiers: Default::default(), }; - let resolver = Arc::new(Box::new(Resolver { + let resolver = Arc::new(Mutex::new(Box::new(Resolver { id_to_type: identifier_mapping.clone(), id_to_def: [ ("Foo".into(), DefinitionId(5)), @@ -286,7 +290,7 @@ impl TestEnvironment { .cloned() .collect(), class_names, - }) as Box); + }) as Box)); TestEnvironment { unifier, diff --git a/nac3standalone/src/basic_symbol_resolver.rs b/nac3standalone/src/basic_symbol_resolver.rs index 8b0a16d2..b74c27fd 100644 --- a/nac3standalone/src/basic_symbol_resolver.rs +++ b/nac3standalone/src/basic_symbol_resolver.rs @@ -32,4 +32,8 @@ impl SymbolResolver for Resolver { fn get_identifier_def(&self, id: &str) -> Option { self.id_to_def.get(id).cloned() } + + fn add_id_def(&mut self, _: String, _: DefinitionId) { + unimplemented!(); + } } diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 3378d386..682ca303 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -6,7 +6,7 @@ use inkwell::{ targets::*, OptimizationLevel, }; -use parking_lot::RwLock; +use parking_lot::{Mutex, RwLock}; use rustpython_parser::{ ast::{fold::Fold, StmtKind}, parser, @@ -164,11 +164,11 @@ fn main() { .collect(); id_to_type.insert("output".into(), output_fun); - let resolver = Arc::new(Box::new(basic_symbol_resolver::Resolver { + let resolver = Arc::new(Mutex::new(Box::new(basic_symbol_resolver::Resolver { class_names: Default::default(), id_to_type, id_to_def, - }) as Box); + }) as Box)); for (_, (id, ast, signature)) in functions.into_iter() { if let TopLevelDef::Function { @@ -253,7 +253,7 @@ fn main() { let instance = { let defs = top_level.definitions.read(); - let mut instance = defs[resolver.get_identifier_def("run").unwrap().0].write(); + let mut instance = defs[resolver.lock().get_identifier_def("run").unwrap().0].write(); if let TopLevelDef::Function { instance_to_stmt, instance_to_symbol,