From 6a0fb4daa17f4baaca05c2bfb8dd14482db1c427 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 16 Aug 2024 17:42:09 +0800 Subject: [PATCH 1/4] core: allow Call and AnnAssign in init block --- nac3core/src/toplevel/composer.rs | 9 +- nac3core/src/toplevel/helper.rs | 146 +++++++++++++++++++++++++++--- 2 files changed, 139 insertions(+), 16 deletions(-) diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 2f0f7e87..603a508e 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -23,7 +23,7 @@ impl Default for ComposerConfig { } } -type DefAst = (Arc>, Option>); +pub type DefAst = (Arc>, Option>); pub struct TopLevelComposer { // list of top level definitions, same as top level context pub definition_ast_list: Vec, @@ -1822,7 +1822,12 @@ impl TopLevelComposer { if *name != init_str_id { unreachable!("must be init function here") } - let all_inited = Self::get_all_assigned_field(body.as_slice())?; + + let all_inited = Self::get_all_assigned_field( + object_id.0, + definition_ast_list, + body.as_slice(), + )?; for (f, _, _) in fields { if !all_inited.contains(f) { return Err(HashSet::from([ diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 21aeb9db..29a662c5 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -3,6 +3,7 @@ use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap}; +use ast::ExprKind; use nac3parser::ast::{Constant, Location}; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -733,7 +734,16 @@ impl TopLevelComposer { ) } - pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result, HashSet> { + /// This function returns the fields that have been initialized in the `__init__` function of a class + /// The function takes as input: + /// * `class_id`: The `object_id` of the class whose function is being evaluated (check `TopLevelDef::Class`) + /// * `definition_ast_list`: A list of ast definitions and statements defined in `TopLevelComposer` + /// * `stmts`: The body of function being parsed. Each statment is analyzed to check varaible initialization statements + pub fn get_all_assigned_field( + class_id: usize, + definition_ast_list: &Vec, + stmts: &[Stmt<()>], + ) -> Result, HashSet> { let mut result = HashSet::new(); for s in stmts { match &s.node { @@ -769,30 +779,138 @@ impl TopLevelComposer { // TODO: do not check for For and While? ast::StmtKind::For { body, orelse, .. } | ast::StmtKind::While { body, orelse, .. } => { - result.extend(Self::get_all_assigned_field(body.as_slice())?); - result.extend(Self::get_all_assigned_field(orelse.as_slice())?); + result.extend(Self::get_all_assigned_field( + class_id, + definition_ast_list, + body.as_slice(), + )?); + result.extend(Self::get_all_assigned_field( + class_id, + definition_ast_list, + orelse.as_slice(), + )?); } ast::StmtKind::If { body, orelse, .. } => { - let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? - .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) - .copied() - .collect::>(); + let inited_for_sure = Self::get_all_assigned_field( + class_id, + definition_ast_list, + body.as_slice(), + )? + .intersection(&Self::get_all_assigned_field( + class_id, + definition_ast_list, + orelse.as_slice(), + )?) + .copied() + .collect::>(); result.extend(inited_for_sure); } ast::StmtKind::Try { body, orelse, finalbody, .. } => { - let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? - .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) - .copied() - .collect::>(); + let inited_for_sure = Self::get_all_assigned_field( + class_id, + definition_ast_list, + body.as_slice(), + )? + .intersection(&Self::get_all_assigned_field( + class_id, + definition_ast_list, + orelse.as_slice(), + )?) + .copied() + .collect::>(); result.extend(inited_for_sure); - result.extend(Self::get_all_assigned_field(finalbody.as_slice())?); + result.extend(Self::get_all_assigned_field( + class_id, + definition_ast_list, + finalbody.as_slice(), + )?); } ast::StmtKind::With { body, .. } => { - result.extend(Self::get_all_assigned_field(body.as_slice())?); + result.extend(Self::get_all_assigned_field( + class_id, + definition_ast_list, + body.as_slice(), + )?); + } + // Variables Initialized in function calls + ast::StmtKind::Expr { value, .. } => { + let ExprKind::Call { func, .. } = &value.node else { + continue; + }; + let ExprKind::Attribute { value, attr, .. } = &func.node else { + continue; + }; + let ExprKind::Name { id, .. } = &value.node else { + continue; + }; + // Need to consider the two cases: + // Case 1) Call to class function i.e. id = `self` + // Case 2) Call to class ancestor function i.e. id = ancestor_name + // We leave checking whether function in case 2 belonged to class ancestor or not to type checker + // + // According to current handling of `self`, function definition are fixed and do not change regardless + // of which object is passed as `self` i.e. virtual polymorphism is not supported + // Therefore, we change class id for case 2 to reflect behavior of our compiler + + let class_name = if *id == "self".into() { + let ast::StmtKind::ClassDef { name, .. } = + &definition_ast_list[class_id].1.as_ref().unwrap().node + else { + unreachable!() + }; + name + } else { + id + }; + + let parent_method = definition_ast_list.iter().find_map(|def| { + let ( + class_def, + Some(ast::Located { + node: ast::StmtKind::ClassDef { name, body, .. }, + .. + }), + ) = &def + else { + return None; + }; + let TopLevelDef::Class { object_id: class_id, .. } = &*class_def.read() + else { + unreachable!() + }; + + if name == class_name { + body.iter().find_map(|m| { + let ast::StmtKind::FunctionDef { name, body, .. } = &m.node else { + return None; + }; + if *name == *attr { + return Some((body.clone(), class_id.0)); + } + None + }) + } else { + None + } + }); + + // If method body is none then method does not exist + if let Some((method_body, class_id)) = parent_method { + result.extend(Self::get_all_assigned_field( + class_id, + definition_ast_list, + method_body.as_slice(), + )?); + } else { + return Err(HashSet::from([format!( + "{}.{} not found in class {class_name} at {}", + *id, *attr, value.location + )])); + } } ast::StmtKind::Pass { .. } | ast::StmtKind::Assert { .. } - | ast::StmtKind::Expr { .. } => {} + | ast::StmtKind::AnnAssign { .. } => {} _ => { unimplemented!() -- 2.44.2 From d1a833097afe3496160a2aff974de6e99ec959f0 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 16 Aug 2024 17:43:05 +0800 Subject: [PATCH 2/4] core: add support for simple polymorphism --- nac3core/src/codegen/expr.rs | 9 +- nac3core/src/typecheck/type_inferencer/mod.rs | 98 ++++++++++++++++++- 2 files changed, 104 insertions(+), 3 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 0817fce2..2fcf365d 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -3025,6 +3025,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) }; // Handle Class Method calls + // The attribute will be `DefinitionId` of the method if the call is to one of the parent methods + let func_id = attr.to_string().parse::(); + let id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(value.custom.unwrap()) { @@ -3032,7 +3035,11 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { unreachable!() }; - let fun_id = { + + // Use the `DefinitionID` from attribute if it is available + let fun_id = if let Ok(func_id) = func_id { + DefinitionId(func_id) + } else { let defs = ctx.top_level.definitions.read(); let obj_def = defs.get(id.0).unwrap().read(); let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() }; diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 0408cf1c..fc9ce2ec 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -12,6 +12,7 @@ use super::{ RecordField, RecordKey, Type, TypeEnum, TypeVar, Unifier, VarMap, }, }; +use crate::toplevel::type_annotation::TypeAnnotation; use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ @@ -102,6 +103,7 @@ pub struct Inferencer<'a> { } type InferenceError = HashSet; +type OverrideResult = Result>>, InferenceError>; struct NaiveFolder(); impl Fold<()> for NaiveFolder { @@ -1672,6 +1674,86 @@ impl<'a> Inferencer<'a> { Ok(None) } + /// Checks whether a class method is calling parent function + /// Returns [`None`] if its not a call to parent method, otherwise + /// returns a new `func` with class name replaced by `self` and method resolved to its `DefinitionID` + /// + /// e.g. A.f1(self, ...) returns Some(self.{DefintionID(f1)}) + fn check_overriding(&mut self, func: &ast::Expr<()>, args: &[ast::Expr<()>]) -> OverrideResult { + // `self` must be first argument for call to parent method + if let Some(Located { node: ExprKind::Name { id, .. }, .. }) = &args.first() { + if *id != "self".into() { + return Ok(None); + } + } else { + return Ok(None); + } + + let Located { + node: ExprKind::Attribute { value, attr: method_name, ctx }, location, .. + } = func + else { + return Ok(None); + }; + let ExprKind::Name { id: class_name, ctx: class_ctx } = &value.node else { + return Ok(None); + }; + let zelf = &self.fold_expr(args[0].clone())?; + + // Check whether the method belongs to class ancestors + let def_id = self.unifier.get_ty(zelf.custom.unwrap()); + let TypeEnum::TObj { obj_id, .. } = def_id.as_ref() else { unreachable!() }; + let defs = self.top_level.definitions.read(); + let res = { + if let TopLevelDef::Class { ancestors, .. } = &*defs[obj_id.0].read() { + let res = ancestors.iter().find_map(|f| { + let TypeAnnotation::CustomClass { id, .. } = f else { unreachable!() }; + let TopLevelDef::Class { name, methods, .. } = &*defs[id.0].read() else { + unreachable!() + }; + // Class names are stored as `__module__.class` + let name = name.to_string(); + let (_, name) = name.rsplit_once('.').unwrap(); + if name == class_name.to_string() { + return methods.iter().find_map(|f| { + if f.0 == *method_name { + return Some(*f); + } + None + }); + } + None + }); + res + } else { + None + } + }; + + match res { + Some(r) => { + let mut new_func = func.clone(); + let mut new_value = value.clone(); + new_value.node = ExprKind::Name { id: "self".into(), ctx: *class_ctx }; + new_func.node = + ExprKind::Attribute { value: new_value.clone(), attr: *method_name, ctx: *ctx }; + + let mut new_func = self.fold_expr(new_func)?; + + let ExprKind::Attribute { value, .. } = new_func.node else { unreachable!() }; + new_func.node = + ExprKind::Attribute { value, attr: r.2 .0.to_string().into(), ctx: *ctx }; + new_func.custom = Some(r.1); + + Ok(Some(new_func)) + } + None => report_error( + format!("Method {class_name}.{method_name} not found in ancestor list").as_str(), + *location, + ), + } + } + fn fold_call( &mut self, location: Location, @@ -1685,8 +1767,20 @@ impl<'a> Inferencer<'a> { return Ok(spec_call_func); } - let func = Box::new(self.fold_expr(func)?); - let args = args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; + // Check for call to parent method + let override_res = self.check_overriding(&func, &args)?; + let is_override = override_res.is_some(); + let func = if is_override { override_res.unwrap() } else { self.fold_expr(func)? }; + let func = Box::new(func); + + let mut args = + args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; + + // TODO: Handle passing of self to functions to allow runtime lookup of functions to be called + // Currently removing `self` and using compile time function definitions + if is_override { + args.remove(0); + } let keywords = keywords .into_iter() .map(|v| fold::fold_keyword(self, v)) -- 2.44.2 From 415c78d23b710dc8ea0b49993f2cb8dcda3a2102 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 16 Aug 2024 17:43:42 +0800 Subject: [PATCH 3/4] standalone: add tests for polymorphism --- nac3standalone/demo/src/inheritance.py | 53 +++++++++++++++++++++----- 1 file changed, 44 insertions(+), 9 deletions(-) diff --git a/nac3standalone/demo/src/inheritance.py b/nac3standalone/demo/src/inheritance.py index d280e3a5..c908c322 100644 --- a/nac3standalone/demo/src/inheritance.py +++ b/nac3standalone/demo/src/inheritance.py @@ -10,23 +10,58 @@ class A: def __init__(self, a: int32): self.a = a - def f1(self): - self.f2() - - def f2(self): + def output_all_fields(self): output_int32(self.a) + + def set_a(self, a: int32): + self.a = a class B(A): b: int32 def __init__(self, b: int32): - self.a = b + 1 + A.__init__(self, b + 1) + self.set_b(b) + + def output_parent_fields(self): + A.output_all_fields(self) + + def output_all_fields(self): + A.output_all_fields(self) + output_int32(self.b) + + def set_b(self, b: int32): self.b = b +class C(B): + c: int32 + + def __init__(self, c: int32): + B.__init__(self, c + 1) + self.c = c + + def output_parent_fields(self): + B.output_all_fields(self) + + def output_all_fields(self): + B.output_all_fields(self) + output_int32(self.c) + + def set_c(self, c: int32): + self.c = c def run() -> int32: - aaa = A(5) - bbb = B(2) - aaa.f1() - bbb.f1() + ccc = C(10) + ccc.output_all_fields() + ccc.set_a(1) + ccc.set_b(2) + ccc.set_c(3) + ccc.output_all_fields() + + bbb = B(10) + bbb.set_a(9) + bbb.set_b(8) + bbb.output_all_fields() + ccc.output_all_fields() + return 0 -- 2.44.2 From d26c75837fcfcb633185fc8ce12bb891f400840d Mon Sep 17 00:00:00 2001 From: abdul124 Date: Thu, 22 Aug 2024 16:44:31 +0800 Subject: [PATCH 4/4] core: improve error messages --- nac3core/src/typecheck/type_inferencer/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index fc9ce2ec..a5b8cd49 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1748,7 +1748,7 @@ impl<'a> Inferencer<'a> { Ok(Some(new_func)) } None => report_error( - format!("Method {class_name}.{method_name} not found in ancestor list").as_str(), + format!("Ancestor method [{class_name}.{method_name}] should be defined with same decorator as its overridden version").as_str(), *location, ), } -- 2.44.2