From a744b139ba275b182adce010b6c15a33f2145180 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 16 Aug 2024 17:42:09 +0800 Subject: [PATCH] core: allow Call and AnnAssign in init block --- nac3core/src/toplevel/composer.rs | 9 +- nac3core/src/toplevel/helper.rs | 146 +++++++++++++++++++++++++++--- 2 files changed, 139 insertions(+), 16 deletions(-) diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 2f0f7e875..603a508e9 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -23,7 +23,7 @@ impl Default for ComposerConfig { } } -type DefAst = (Arc>, Option>); +pub type DefAst = (Arc>, Option>); pub struct TopLevelComposer { // list of top level definitions, same as top level context pub definition_ast_list: Vec, @@ -1822,7 +1822,12 @@ impl TopLevelComposer { if *name != init_str_id { unreachable!("must be init function here") } - let all_inited = Self::get_all_assigned_field(body.as_slice())?; + + let all_inited = Self::get_all_assigned_field( + object_id.0, + definition_ast_list, + body.as_slice(), + )?; for (f, _, _) in fields { if !all_inited.contains(f) { return Err(HashSet::from([ diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 21aeb9dba..29a662c54 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -3,6 +3,7 @@ use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap}; +use ast::ExprKind; use nac3parser::ast::{Constant, Location}; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -733,7 +734,16 @@ impl TopLevelComposer { ) } - pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result, HashSet> { + /// This function returns the fields that have been initialized in the `__init__` function of a class + /// The function takes as input: + /// * `class_id`: The `object_id` of the class whose function is being evaluated (check `TopLevelDef::Class`) + /// * `definition_ast_list`: A list of ast definitions and statements defined in `TopLevelComposer` + /// * `stmts`: The body of function being parsed. Each statment is analyzed to check varaible initialization statements + pub fn get_all_assigned_field( + class_id: usize, + definition_ast_list: &Vec, + stmts: &[Stmt<()>], + ) -> Result, HashSet> { let mut result = HashSet::new(); for s in stmts { match &s.node { @@ -769,30 +779,138 @@ impl TopLevelComposer { // TODO: do not check for For and While? ast::StmtKind::For { body, orelse, .. } | ast::StmtKind::While { body, orelse, .. } => { - result.extend(Self::get_all_assigned_field(body.as_slice())?); - result.extend(Self::get_all_assigned_field(orelse.as_slice())?); + result.extend(Self::get_all_assigned_field( + class_id, + definition_ast_list, + body.as_slice(), + )?); + result.extend(Self::get_all_assigned_field( + class_id, + definition_ast_list, + orelse.as_slice(), + )?); } ast::StmtKind::If { body, orelse, .. } => { - let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? - .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) - .copied() - .collect::>(); + let inited_for_sure = Self::get_all_assigned_field( + class_id, + definition_ast_list, + body.as_slice(), + )? + .intersection(&Self::get_all_assigned_field( + class_id, + definition_ast_list, + orelse.as_slice(), + )?) + .copied() + .collect::>(); result.extend(inited_for_sure); } ast::StmtKind::Try { body, orelse, finalbody, .. } => { - let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? - .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) - .copied() - .collect::>(); + let inited_for_sure = Self::get_all_assigned_field( + class_id, + definition_ast_list, + body.as_slice(), + )? + .intersection(&Self::get_all_assigned_field( + class_id, + definition_ast_list, + orelse.as_slice(), + )?) + .copied() + .collect::>(); result.extend(inited_for_sure); - result.extend(Self::get_all_assigned_field(finalbody.as_slice())?); + result.extend(Self::get_all_assigned_field( + class_id, + definition_ast_list, + finalbody.as_slice(), + )?); } ast::StmtKind::With { body, .. } => { - result.extend(Self::get_all_assigned_field(body.as_slice())?); + result.extend(Self::get_all_assigned_field( + class_id, + definition_ast_list, + body.as_slice(), + )?); + } + // Variables Initialized in function calls + ast::StmtKind::Expr { value, .. } => { + let ExprKind::Call { func, .. } = &value.node else { + continue; + }; + let ExprKind::Attribute { value, attr, .. } = &func.node else { + continue; + }; + let ExprKind::Name { id, .. } = &value.node else { + continue; + }; + // Need to consider the two cases: + // Case 1) Call to class function i.e. id = `self` + // Case 2) Call to class ancestor function i.e. id = ancestor_name + // We leave checking whether function in case 2 belonged to class ancestor or not to type checker + // + // According to current handling of `self`, function definition are fixed and do not change regardless + // of which object is passed as `self` i.e. virtual polymorphism is not supported + // Therefore, we change class id for case 2 to reflect behavior of our compiler + + let class_name = if *id == "self".into() { + let ast::StmtKind::ClassDef { name, .. } = + &definition_ast_list[class_id].1.as_ref().unwrap().node + else { + unreachable!() + }; + name + } else { + id + }; + + let parent_method = definition_ast_list.iter().find_map(|def| { + let ( + class_def, + Some(ast::Located { + node: ast::StmtKind::ClassDef { name, body, .. }, + .. + }), + ) = &def + else { + return None; + }; + let TopLevelDef::Class { object_id: class_id, .. } = &*class_def.read() + else { + unreachable!() + }; + + if name == class_name { + body.iter().find_map(|m| { + let ast::StmtKind::FunctionDef { name, body, .. } = &m.node else { + return None; + }; + if *name == *attr { + return Some((body.clone(), class_id.0)); + } + None + }) + } else { + None + } + }); + + // If method body is none then method does not exist + if let Some((method_body, class_id)) = parent_method { + result.extend(Self::get_all_assigned_field( + class_id, + definition_ast_list, + method_body.as_slice(), + )?); + } else { + return Err(HashSet::from([format!( + "{}.{} not found in class {class_name} at {}", + *id, *attr, value.location + )])); + } } ast::StmtKind::Pass { .. } | ast::StmtKind::Assert { .. } - | ast::StmtKind::Expr { .. } => {} + | ast::StmtKind::AnnAssign { .. } => {} _ => { unimplemented!()