diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 9dd980d2..5819be38 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1169,20 +1169,20 @@ impl TopLevelComposer { methods, fields, type_vars, - name, + name: class_name, object_id, resolver: _, .. } = &*class_def { - let mut has_init = false; + let mut init_id: Option = None; // get the class contructor type correct let (contor_args, contor_type_vars) = { let mut constructor_args: Vec = Vec::new(); let mut type_vars: HashMap = HashMap::new(); - for (name, func_sig, ..) in methods { + for (name, func_sig, id) in methods { if name == "__init__" { - has_init = true; + init_id = Some(*id); if let TypeEnum::TFunc(sig) = self.unifier.get_ty(*func_sig).as_ref() { let FunSignature { args, vars, .. } = &*sig.borrow(); constructor_args.extend_from_slice(args); @@ -1207,11 +1207,22 @@ impl TopLevelComposer { self.unifier.unify(constructor.unwrap(), contor_type)?; // class field instantiation check - // TODO: this is a really simple one, more check later - if !has_init && !fields.is_empty() { - return Err(format!("fields of class {} not fully initialized", name)) + if let (Some(init_id), false) = (init_id, fields.is_empty()) { + let init_ast = + self.definition_ast_list.get(init_id.0).unwrap().1.as_ref().unwrap(); + if let ast::StmtKind::FunctionDef { name, body, .. } = &init_ast.node { + if name != "__init__" { + unreachable!("must be init function here") + } + let all_inited = Self::get_all_assigned_field(body.as_slice())?; + if fields.iter().any(|(x, _)| !all_inited.contains(x)) { + return Err(format!( + "fields of class {} not fully initialized", + class_name + )); + } + } } - } } diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 114c345d..c024631c 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -268,4 +268,73 @@ impl TopLevelComposer { unifier, ) } + + pub fn get_all_assigned_field(stmts: &[ast::Stmt<()>]) -> Result, String> { + let mut result: HashSet = HashSet::new(); + for s in stmts { + match &s.node { + ast::StmtKind::AnnAssign { target, .. } + if { + if let ast::ExprKind::Attribute { value, .. } = &target.node { + if let ast::ExprKind::Name { id, .. } = &value.node { + id == "self" + } else { + false + } + } else { + false + } + } => + { + return Err(format!( + "redundant type annotation for class fields at {}", + s.location + )) + } + ast::StmtKind::Assign { targets, .. } => { + for t in targets { + if let ast::ExprKind::Attribute { value, attr, .. } = &t.node { + if let ast::ExprKind::Name { id, .. } = &value.node { + if id == "self" { + result.insert(attr.clone()); + } + } + } + } + } + // TODO: do not check for For and While? + ast::StmtKind::For { body, orelse, .. } + | ast::StmtKind::While { body, orelse, .. } => { + result.extend(Self::get_all_assigned_field(body.as_slice())?); + result.extend(Self::get_all_assigned_field(orelse.as_slice())?); + } + ast::StmtKind::If { body, orelse, .. } => { + let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? + .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) + .cloned() + .collect::>(); + result.extend(inited_for_sure); + } + ast::StmtKind::Try { body, orelse, finalbody, .. } => { + let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? + .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) + .cloned() + .collect::>(); + result.extend(inited_for_sure); + result.extend(Self::get_all_assigned_field(finalbody.as_slice())?); + } + ast::StmtKind::With { body, .. } => { + result.extend(Self::get_all_assigned_field(body.as_slice())?); + } + ast::StmtKind::Pass => {} + ast::StmtKind::Assert { .. } => {} + ast::StmtKind::Expr { .. } => {} + + _ => { + unimplemented!() + } + } + } + Ok(result) + } } diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 730e9344..54893d36 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -870,7 +870,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { return SELF def sum(self) -> int32: if self.a == 0: - return self.a + self + return self.a else: a = self.a self.a = self.a - 1 @@ -898,7 +898,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { return ret * ret "}, indoc! {" - def sum3(l: list[V]) -> V: + def sum_three(l: list[V]) -> V: return l[0] + l[1] + l[2] "}, indoc! {" @@ -921,7 +921,11 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { b: bool def __init__(self, aa: G): self.a = aa - self.b = True + if 2 > 1: + self.b = True + else: + # self.b = False + pass def fun(self, a: G) -> list[G]: ret = [a, self.a] return ret if self.b else self.fun(self.a) @@ -930,6 +934,30 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { vec![]; "type var class" )] +#[test_case( + vec![ + indoc! {" + class A: + def fun(self): + 1 + 2 + "}, + indoc!{" + class B: + a: int32 + b: bool + def __init__(self): + # self.b = False + if 3 > 2: + self.a = 3 + self.b = False + else: + self.a = 4 + self.b = True + "} + ], + vec![]; + "no_init_inst_check" +)] fn test_inference(source: Vec<&str>, res: Vec<&str>) { let print = true; let mut composer: TopLevelComposer = Default::default();