From 51f9f9c1e31c2855db98b134cdf47ffe2ffbb502 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 16 Aug 2024 17:19:09 +0800 Subject: [PATCH] WIP --- nac3core/src/codegen/expr.rs | 13 ++- nac3core/src/toplevel/composer.rs | 12 +- nac3core/src/toplevel/helper.rs | 109 +++++++++++++++--- nac3core/src/typecheck/type_inferencer/mod.rs | 98 +++++++++++++++- nac3standalone/demo/src/inheritance.py | 53 +++++++-- 5 files changed, 254 insertions(+), 31 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 0817fce2..e9f6ccd1 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -3025,14 +3025,19 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) }; // Handle Class Method calls - let id = if let TypeEnum::TObj { obj_id, .. } = - &*ctx.unifier.get_ty(value.custom.unwrap()) - { + // The attribute will be `DefinitionId` of the method if the call is to one of the parent methods + let func_id = attr.to_string().parse::(); + + let id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(value.custom.unwrap()) { *obj_id } else { unreachable!() }; - let fun_id = { + + // Use the `DefinitionID` from attribute if it is available + let fun_id = if func_id.is_ok() { + DefinitionId(func_id.unwrap()) + } else { let defs = ctx.top_level.definitions.read(); let obj_def = defs.get(id.0).unwrap().read(); let TopLevelDef::Class { methods, .. 
} = &*obj_def else { unreachable!() }; diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 2f0f7e87..460d9a30 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -23,7 +23,7 @@ impl Default for ComposerConfig { } } -type DefAst = (Arc>, Option>); +pub type DefAst = (Arc>, Option>); pub struct TopLevelComposer { // list of top level definitions, same as top level context pub definition_ast_list: Vec, @@ -1822,7 +1822,15 @@ impl TopLevelComposer { if *name != init_str_id { unreachable!("must be init function here") } - let all_inited = Self::get_all_assigned_field(body.as_slice())?; + // Since AST stores class names without prepending `__module__.`, we split the name for search purposes + let class_name_only = class_name.to_string(); + let (_, class_name_only) = class_name_only.split_once('.').unwrap(); + + let all_inited = Self::get_all_assigned_field( + class_name_only.into(), + definition_ast_list, + body.as_slice(), + )?; for (f, _, _) in fields { if !all_inited.contains(f) { return Err(HashSet::from([ diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 21aeb9db..94cba0f6 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -3,6 +3,7 @@ use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap}; +use ast::ExprKind; use nac3parser::ast::{Constant, Location}; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -733,7 +734,11 @@ impl TopLevelComposer { ) } - pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result, HashSet> { + pub fn get_all_assigned_field( + class_name: StrRef, + ast: &Vec, + stmts: &[Stmt<()>], + ) -> Result, HashSet> { let mut result = HashSet::new(); for s in stmts { match &s.node { @@ -769,30 +774,106 @@ impl TopLevelComposer { // TODO: do not 
check for For and While? ast::StmtKind::For { body, orelse, .. } | ast::StmtKind::While { body, orelse, .. } => { - result.extend(Self::get_all_assigned_field(body.as_slice())?); - result.extend(Self::get_all_assigned_field(orelse.as_slice())?); + result.extend(Self::get_all_assigned_field(class_name, ast, body.as_slice())?); + result.extend(Self::get_all_assigned_field( + class_name, + ast, + orelse.as_slice(), + )?); } ast::StmtKind::If { body, orelse, .. } => { - let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? - .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) - .copied() - .collect::>(); + let inited_for_sure = + Self::get_all_assigned_field(class_name, ast, body.as_slice())? + .intersection(&Self::get_all_assigned_field( + class_name, + ast, + orelse.as_slice(), + )?) + .copied() + .collect::>(); result.extend(inited_for_sure); } ast::StmtKind::Try { body, orelse, finalbody, .. } => { - let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? - .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) - .copied() - .collect::>(); + let inited_for_sure = + Self::get_all_assigned_field(class_name, ast, body.as_slice())? + .intersection(&Self::get_all_assigned_field( + class_name, + ast, + orelse.as_slice(), + )?) + .copied() + .collect::>(); result.extend(inited_for_sure); - result.extend(Self::get_all_assigned_field(finalbody.as_slice())?); + result.extend(Self::get_all_assigned_field( + class_name, + ast, + finalbody.as_slice(), + )?); } ast::StmtKind::With { body, .. } => { - result.extend(Self::get_all_assigned_field(body.as_slice())?); + result.extend(Self::get_all_assigned_field(class_name, ast, body.as_slice())?); + } + // Variables Initiated in function calls + ast::StmtKind::Expr { value, .. } => { + let ExprKind::Call { func, .. } = &value.node else { + continue; + }; + let ExprKind::Attribute { value, attr, .. } = &func.node else { + continue; + }; + let ExprKind::Name { id, .. 
} = &value.node else { + continue; + }; + // Need to consider the two cases: + // Case 1) Call to class function i.e. id = `self` + // Case 2) Call to class ancestor function i.e. id = ancestor_name + // We leave checking whether function in case 2 belonged to class ancestor or not to type checker + // + // According to current handling of `self`, function definitions are fixed and do not change regardless + // of which object is passed as `self` i.e. virtual polymorphism is not supported + // Therefore, we change class name for case 2 to reflect behavior of our compiler + let new_class_name = if *id == "self".into() { class_name } else { *id }; + + let method_body = ast.iter().find_map(|def| { + let Some(ast::Located { + node: ast::StmtKind::ClassDef { name, body, .. }, + .. + }) = &def.1 + else { + return None; + }; + if *name == new_class_name { + body.iter().find_map(|m| { + let ast::StmtKind::FunctionDef { name, body, .. } = &m.node else { + return None; + }; + if *name == *attr { + return Some(body.clone()); + } + None + }) + } else { + None + } + }); + + // If method body is none then method does not exist + if let Some(method_body) = method_body { + result.extend(Self::get_all_assigned_field( + new_class_name, + ast, + method_body.as_slice(), + )?); + } else { + return Err(HashSet::from([format!( + "{}.{} not found in class {new_class_name} at {}", + *id, *attr, value.location + )])); + } } ast::StmtKind::Pass { .. } | ast::StmtKind::Assert { .. } - | ast::StmtKind::Expr { .. } => {} + | ast::StmtKind::AnnAssign { .. 
} => {} _ => { unimplemented!() diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 0408cf1c..7816170f 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -12,6 +12,7 @@ use super::{ RecordField, RecordKey, Type, TypeEnum, TypeVar, Unifier, VarMap, }, }; +use crate::toplevel::type_annotation::TypeAnnotation; use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ @@ -102,6 +103,7 @@ pub struct Inferencer<'a> { } type InferenceError = HashSet; +type OverrideResult = Result>>, InferenceError>; struct NaiveFolder(); impl Fold<()> for NaiveFolder { @@ -1672,6 +1674,86 @@ impl<'a> Inferencer<'a> { Ok(None) } + /// Checks whether a class method is calling parent function + /// Returns [`None`] if it is not a call to parent method, otherwise + /// returns a new `func` with class name replaced by `self` and method resolved to its `DefinitionId` + /// + /// e.g. A.f1(self, ...) returns Some(self.DefinitionId(f1)) + fn check_overriding(&mut self, func: &ast::Expr<()>, args: &[ast::Expr<()>]) -> OverrideResult { + // `self` must be first argument for call to parent method + if let Some(Located { node: ExprKind::Name { id, .. }, .. }) = &args.first() { + if *id != "self".into() { + return Ok(None); + } + } else { + return Ok(None); + } + + let Located { + node: ExprKind::Attribute { value, attr: method_name, ctx }, location, .. + } = func + else { + return Ok(None); + }; + let ExprKind::Name { id: class_name, ctx: class_ctx } = &value.node else { + return Ok(None); + }; + let zelf = &self.fold_expr(args[0].clone())?; + + // Check whether the method belongs to class ancestors + let def_id = self.unifier.get_ty(zelf.custom.unwrap()); + let TypeEnum::TObj { obj_id, .. } = def_id.as_ref() else { unreachable!() }; + let defs = self.top_level.definitions.read(); + let res = { + if let TopLevelDef::Class { ancestors, .. 
} = &*defs[obj_id.0].read() { + let res = ancestors.iter().find_map(|f| { + let TypeAnnotation::CustomClass { id, .. } = f else { unreachable!() }; + let TopLevelDef::Class { name, methods, .. } = &*defs[id.0].read() else { + unreachable!() + }; + // Class names are stored as `__module__.class` + let name = name.to_string(); + let (_, name) = name.split_once('.').unwrap(); + if name == class_name.to_string() { + return methods.iter().find_map(|f| { + if f.0 == *method_name { + return Some(*f); + } + None + }); + } + None + }); + res + } else { + None + } + }; + + match res { + Some(r) => { + + let mut new_func = func.clone(); + let mut new_value = value.clone(); + new_value.node = ExprKind::Name { id: "self".into(), ctx: *class_ctx }; + new_func.node = + ExprKind::Attribute { value: new_value.clone(), attr: *method_name, ctx: *ctx }; + + let mut new_func = self.fold_expr(new_func)?; + + let ExprKind::Attribute { value, .. } = new_func.node else { unreachable!() }; + new_func.node = ExprKind::Attribute { value, attr: r.2.0.to_string().into(), ctx: *ctx }; + new_func.custom = Some(r.1); + + Ok(Some(new_func)) + } + None => report_error( + format!("Method {class_name}.{method_name} not found in ancestor list").as_str(), + *location, + ), + } + } + fn fold_call( &mut self, location: Location, @@ -1685,8 +1767,20 @@ impl<'a> Inferencer<'a> { return Ok(spec_call_func); } - let func = Box::new(self.fold_expr(func)?); - let args = args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; + // Check for call to parent method + let override_res = self.check_overriding(&func, &args)?; + let is_override = override_res.is_some(); + let func = if is_override { override_res.unwrap() } else { self.fold_expr(func)? 
}; + let func = Box::new(func); + + let mut args = + args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; + + // TODO: Handle passing of self to functions to allow runtime lookup of functions to be called + // Currently removing `self` and using compile time function definitions + if is_override { + args.remove(0); + } let keywords = keywords .into_iter() .map(|v| fold::fold_keyword(self, v)) diff --git a/nac3standalone/demo/src/inheritance.py b/nac3standalone/demo/src/inheritance.py index d280e3a5..c908c322 100644 --- a/nac3standalone/demo/src/inheritance.py +++ b/nac3standalone/demo/src/inheritance.py @@ -10,23 +10,58 @@ class A: def __init__(self, a: int32): self.a = a - def f1(self): - self.f2() - - def f2(self): + def output_all_fields(self): output_int32(self.a) + + def set_a(self, a: int32): + self.a = a class B(A): b: int32 def __init__(self, b: int32): - self.a = b + 1 + A.__init__(self, b + 1) + self.set_b(b) + + def output_parent_fields(self): + A.output_all_fields(self) + + def output_all_fields(self): + A.output_all_fields(self) + output_int32(self.b) + + def set_b(self, b: int32): self.b = b +class C(B): + c: int32 + + def __init__(self, c: int32): + B.__init__(self, c + 1) + self.c = c + + def output_parent_fields(self): + B.output_all_fields(self) + + def output_all_fields(self): + B.output_all_fields(self) + output_int32(self.c) + + def set_c(self, c: int32): + self.c = c def run() -> int32: - aaa = A(5) - bbb = B(2) - aaa.f1() - bbb.f1() + ccc = C(10) + ccc.output_all_fields() + ccc.set_a(1) + ccc.set_b(2) + ccc.set_c(3) + ccc.output_all_fields() + + bbb = B(10) + bbb.set_a(9) + bbb.set_b(8) + bbb.output_all_fields() + ccc.output_all_fields() + return 0