diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 58ae94fd..b4617cf0 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -23,7 +23,7 @@ impl Default for ComposerConfig { } } -type DefAst = (Arc>, Option>); +pub type DefAst = (Arc>, Option>); pub struct TopLevelComposer { // list of top level definitions, same as top level context pub definition_ast_list: Vec, @@ -1723,7 +1723,13 @@ impl TopLevelComposer { if *name != init_str_id { unreachable!("must be init function here") } - let all_inited = Self::get_all_assigned_field(body.as_slice())?; + // let all_inited = Self::get_all_assigned_field(body.as_slice())?; + let all_inited = Self::get_all_assigned_field( + definition_ast_list, + def, + body.as_slice(), + )?; + for (f, _, _) in fields { if !all_inited.contains(f) { return Err(HashSet::from([ diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 538e653e..550b19ce 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -3,6 +3,7 @@ use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap}; +use ast::ExprKind; use nac3parser::ast::{Constant, Location}; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -677,7 +678,11 @@ impl TopLevelComposer { ) } - pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result, HashSet> { + pub fn get_all_assigned_field( + definition_ast_list: &Vec, + def: &Arc>, + stmts: &[Stmt<()>], + ) -> Result, HashSet> { let mut result = HashSet::new(); for s in stmts { match &s.node { @@ -713,32 +718,151 @@ impl TopLevelComposer { // TODO: do not check for For and While? ast::StmtKind::For { body, orelse, .. } | ast::StmtKind::While { body, orelse, .. } => { - result.extend(Self::get_all_assigned_field(body.as_slice())?); - result.extend(Self::get_all_assigned_field(orelse.as_slice())?); + result.extend(Self::get_all_assigned_field( + definition_ast_list, + def, + body.as_slice(), + )?); + result.extend(Self::get_all_assigned_field( + definition_ast_list, + def, + orelse.as_slice(), + )?); } ast::StmtKind::If { body, orelse, .. } => { - let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? - .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) - .copied() - .collect::>(); + let inited_for_sure = + Self::get_all_assigned_field(definition_ast_list, def, body.as_slice())? + .intersection(&Self::get_all_assigned_field( + definition_ast_list, + def, + orelse.as_slice(), + )?) + .copied() + .collect::>(); result.extend(inited_for_sure); } ast::StmtKind::Try { body, orelse, finalbody, .. } => { - let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? - .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) - .copied() - .collect::>(); + let inited_for_sure = + Self::get_all_assigned_field(definition_ast_list, def, body.as_slice())? + .intersection(&Self::get_all_assigned_field( + definition_ast_list, + def, + orelse.as_slice(), + )?) + .copied() + .collect::>(); result.extend(inited_for_sure); - result.extend(Self::get_all_assigned_field(finalbody.as_slice())?); + result.extend(Self::get_all_assigned_field( + definition_ast_list, + def, + finalbody.as_slice(), + )?); } ast::StmtKind::With { body, .. } => { - result.extend(Self::get_all_assigned_field(body.as_slice())?); + result.extend(Self::get_all_assigned_field( + definition_ast_list, + def, + body.as_slice(), + )?); } - ast::StmtKind::Pass { .. } - | ast::StmtKind::Assert { .. } - | ast::StmtKind::Expr { .. } => {} + // If its a call to __init__function of ancestor extend with ancestor fields + ast::StmtKind::Expr { value, .. } => { + // Check if Expression is a function call to self + if let ExprKind::Call { func, args, .. } = &value.node { + if let ExprKind::Attribute { value, attr: fn_name, .. } = &func.node { + let class_def = def.read(); + let (ancestors, methods) = { + let mut class_methods: HashMap = + HashMap::new(); + let mut class_ancestors: HashMap< + StrRef, + HashMap, + > = HashMap::new(); + + if let TopLevelDef::Class { methods, ancestors, .. } = &*class_def { + for m in methods { + class_methods.insert(m.0, m.2); + } + ancestors.iter().skip(1).for_each(|a| { + if let TypeAnnotation::CustomClass { id, .. } = a { + let anc_def = + definition_ast_list.get(id.0).unwrap().0.read(); + if let TopLevelDef::Class { name, methods, .. } = + &*anc_def + { + let mut temp: HashMap = + HashMap::new(); + for m in methods { + temp.insert(m.0, m.2); + } + // Remove module name suffix from name + let mut name_string = name.to_string(); + let split_loc = + name_string.find(|c| c == '.').unwrap() + 1; + class_ancestors.insert( + name_string.split_off(split_loc).into(), + temp, + ); + } + } + }); + } + (class_ancestors, class_methods) + }; + if let ExprKind::Name { id, .. } = value.node { + if id == "self".into() { + // Get Class methods and fields + let method_id = methods.get(fn_name); + if method_id.is_some() { + if let Some(fn_ast) = &definition_ast_list + .get(method_id.unwrap().0) + .unwrap() + .1 + { + if let ast::StmtKind::FunctionDef { body, .. } = + &fn_ast.node + { + result.extend(Self::get_all_assigned_field( + definition_ast_list, + def, + body.as_slice(), + )?); + } + } + } + } else if let Some(ancestor_methods) = ancestors.get(&id) { + // First arg must be `self` when calling ancestor function + if let ExprKind::Name { id, .. } = args[0].node { + if id == "self".into() { + if let Some(method_id) = ancestor_methods.get(fn_name) { + if let Some(fn_ast) = + &definition_ast_list.get(method_id.0).unwrap().1 + { + if let ast::StmtKind::FunctionDef { + body, .. + } = &fn_ast.node + { + result.extend( + Self::get_all_assigned_field( + definition_ast_list, + def, + body.as_slice(), + )?, + ); + } + } + }; + } + } + } + } + } + } + } + ast::StmtKind::Pass { .. } | ast::StmtKind::Assert { .. } => {} _ => { + println!("{:?}", s.node); unimplemented!() } }