diff --git a/nac3core/src/top_level.rs b/nac3core/src/top_level.rs index ade98dd1..681700be 100644 --- a/nac3core/src/top_level.rs +++ b/nac3core/src/top_level.rs @@ -10,7 +10,7 @@ use itertools::Itertools; use parking_lot::{Mutex, RwLock}; use rustpython_parser::ast::{self, Stmt}; -#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] pub struct DefinitionId(pub usize); pub enum TopLevelDef { @@ -462,9 +462,9 @@ impl TopLevelComposer { } /// step 3, class fields and methods - // FIXME: need analyze base classes here - // FIXME: how to deal with self type - // FIXME: how to prevent cycles + // FIXME: analyze base classes here + // FIXME: deal with self type + // NOTE: prevent cycles only roughly done fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), String> { let mut def_ast_list = self.definition_ast_list.write(); let converted_top_level = &self.to_top_level_context(); @@ -472,20 +472,27 @@ impl TopLevelComposer { let to_be_analyzed_class = &mut self.to_be_analyzed_class; let unifier = &mut self.unifier; + // NOTE: roughly prevent infinite loop + let mut max_iter = to_be_analyzed_class.len() * 4; 'class: loop { - if to_be_analyzed_class.is_empty() { + if to_be_analyzed_class.is_empty() && { max_iter -= 1; max_iter > 0 } { break; } let class_ind = to_be_analyzed_class.remove(0).0; - let (class_name, class_body, class_resolver) = { + let (class_name, + class_body_ast, + class_bases_ast, + class_resolver, + class_ancestors + ) = { let (class_def, class_ast) = &mut def_ast_list[class_ind]; if let Some(ast::Located { - node: ast::StmtKind::ClassDef { name, body, .. }, .. + node: ast::StmtKind::ClassDef { name, body, bases, .. }, .. }) = class_ast.as_ref() { - if let TopLevelDef::Class { resolver, .. } = class_def.write().deref() { - (name, body, resolver.as_ref().unwrap().clone()) + if let TopLevelDef::Class { resolver, ancestors, .. } = class_def.write().deref() { + (name, body, bases, resolver.as_ref().unwrap().clone(), ancestors.clone()) } else { unreachable!() } @@ -494,11 +501,35 @@ impl TopLevelComposer { } }; + let all_base_class_analyzed = { + let not_yet_analyzed = to_be_analyzed_class.clone().into_iter().collect::>(); + let base = class_ancestors.clone().into_iter().collect::>(); + let intersection = not_yet_analyzed.intersection(&base).collect_vec(); + intersection.is_empty() + }; + if !all_base_class_analyzed { + to_be_analyzed_class.push(DefinitionId(class_ind)); + continue 'class; + } + + // get the bases type, can directly do this since it + // already pass the check in the previous stages + let class_bases_ty = class_bases_ast + .iter() + .filter_map(|x| { + class_resolver.as_ref().lock().parse_type_annotation( + converted_top_level, + unifier.borrow_mut(), + primitives, + x).ok() + }) + .collect_vec(); + // need these vectors to check re-defining methods, class fields // and store the parsed result in case some method cannot be typed for now let mut class_methods_parsing_result: Vec<(String, Type, DefinitionId)> = vec![]; let mut class_fields_parsing_result: Vec<(String, Type)> = vec![]; - for b in class_body { + for b in class_body_ast { if let ast::StmtKind::FunctionDef { args: method_args_ast, body: method_body_ast,