diff --git a/flake.nix b/flake.nix index 07cea776..7987b259 100644 --- a/flake.nix +++ b/flake.nix @@ -180,7 +180,9 @@ clippy pre-commit rustfmt + rust-analyzer ]; + RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}"; shellHook = '' export DEMO_LINALG_STUB=${packages.x86_64-linux.demo-linalg-stub}/lib/liblinalg.a diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 0817fce2..ffb34caf 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2982,6 +2982,29 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } } ExprKind::Call { func, args, keywords } => { + // Check if call is to a parent method + let mut is_override = false; + if let Some(arg) = args.last() { + if let ExprKind::Name { id, .. } = arg.node { + if id == "self".into() { + is_override = true; + } + } + } + + let mut args = args.clone(); + let (zelf, func_id) = if is_override { + let zelf = args.pop(); + let ExprKind::Constant { value: ast::Constant::Int(func_id), .. } = + args.pop().unwrap().node + else { + unreachable!() + }; + (zelf, Some(func_id)) + } else { + (None, None) + }; + let mut params = args .iter() .map(|arg| generator.gen_expr(ctx, arg)) @@ -3025,14 +3048,21 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) }; // Handle Class Method calls - let id = if let TypeEnum::TObj { obj_id, .. } = - &*ctx.unifier.get_ty(value.custom.unwrap()) - { + let class_ty = if is_override { + zelf.unwrap().custom.unwrap() + } else { + value.custom.unwrap() + }; + let id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(class_ty) { *obj_id } else { unreachable!() }; - let fun_id = { + + // Get function definition + let fun_id = if is_override { + DefinitionId(func_id.unwrap() as usize) + } else { let defs = ctx.top_level.definitions.read(); let obj_def = defs.get(id.0).unwrap().read(); let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() }; diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 2f0f7e87..4e0bf4b9 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -23,7 +23,7 @@ impl Default for ComposerConfig { } } -type DefAst = (Arc>, Option>); +pub type DefAst = (Arc>, Option>); pub struct TopLevelComposer { // list of top level definitions, same as top level context pub definition_ast_list: Vec, @@ -1822,7 +1822,11 @@ impl TopLevelComposer { if *name != init_str_id { unreachable!("must be init function here") } - let all_inited = Self::get_all_assigned_field(body.as_slice())?; + let all_inited = Self::get_all_assigned_field( + class_name.to_string().into(), + definition_ast_list, + body.as_slice(), + )?; for (f, _, _) in fields { if !all_inited.contains(f) { return Err(HashSet::from([ diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 21aeb9db..209760ae 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -3,6 +3,7 @@ use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap}; +use ast::ExprKind; use nac3parser::ast::{Constant, Location}; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -733,7 +734,12 @@ impl TopLevelComposer { ) } - pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result, HashSet> { + #[allow(clippy::only_used_in_recursion)] + pub fn get_all_assigned_field( + class_name: StrRef, + ast: &Vec, + stmts: &[Stmt<()>], + ) -> Result, HashSet> { let mut result = HashSet::new(); for s in stmts { match &s.node { @@ -769,30 +775,77 @@ impl TopLevelComposer { // TODO: do not check for For and While? ast::StmtKind::For { body, orelse, .. } | ast::StmtKind::While { body, orelse, .. } => { - result.extend(Self::get_all_assigned_field(body.as_slice())?); - result.extend(Self::get_all_assigned_field(orelse.as_slice())?); + result.extend(Self::get_all_assigned_field(class_name, ast, body.as_slice())?); + result.extend(Self::get_all_assigned_field( + class_name, + ast, + orelse.as_slice(), + )?); } ast::StmtKind::If { body, orelse, .. } => { - let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? - .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) - .copied() - .collect::>(); + let inited_for_sure = + Self::get_all_assigned_field(class_name, ast, body.as_slice())? + .intersection(&Self::get_all_assigned_field( + class_name, + ast, + orelse.as_slice(), + )?) + .copied() + .collect::>(); result.extend(inited_for_sure); } ast::StmtKind::Try { body, orelse, finalbody, .. } => { - let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? - .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) - .copied() - .collect::>(); + let inited_for_sure = + Self::get_all_assigned_field(class_name, ast, body.as_slice())? + .intersection(&Self::get_all_assigned_field( + class_name, + ast, + orelse.as_slice(), + )?) + .copied() + .collect::>(); result.extend(inited_for_sure); - result.extend(Self::get_all_assigned_field(finalbody.as_slice())?); + result.extend(Self::get_all_assigned_field( + class_name, + ast, + finalbody.as_slice(), + )?); } ast::StmtKind::With { body, .. } => { - result.extend(Self::get_all_assigned_field(body.as_slice())?); + result.extend(Self::get_all_assigned_field(class_name, ast, body.as_slice())?); + } + // Variables Initiated in function calls + ast::StmtKind::Expr { value, .. } => { + let ExprKind::Call { func, .. } = &value.node else { + continue; + }; + let ExprKind::Attribute { value, attr, .. } = &func.node else { + continue; + }; + let ExprKind::Name { id, .. } = &value.node else { + continue; + }; + // Need to conside the two cases: + // Case 1) Call to class function i.e. id = `self` + // Case 2) Call to class ancestor function i.e. id = ancestor_name + // We leave checking whether ancestor is called to type checker + // if *id == "self".into() { + // ast.iter().find_map(|def| { + // let Some(ast::Located { + // node: ast::StmtKind::ClassDef { name, body, .. }, + // .. + // }) = def.1 + // else { + // return None; + // }; + // if *name == class_name {} + // None + // }); + // } } ast::StmtKind::Pass { .. } | ast::StmtKind::Assert { .. } - | ast::StmtKind::Expr { .. } => {} + | ast::StmtKind::AnnAssign { .. } => {} _ => { unimplemented!() diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 0408cf1c..8091e0d8 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -12,6 +12,7 @@ use super::{ RecordField, RecordKey, Type, TypeEnum, TypeVar, Unifier, VarMap, }, }; +use crate::toplevel::type_annotation::TypeAnnotation; use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ @@ -1672,6 +1673,91 @@ impl<'a> Inferencer<'a> { Ok(None) } + /// Checks whether a class method is calling parent function + /// Returns [`None`] if its not a call to parent method, otherwise + /// returns a new `func` with class name replaced by `self` and class name store as `ExprKind::Constant` + /// + /// e.g. A.f1(self, ...) returns Some(self.f1, Some(ExprKind::Constant(A)) + #[allow(clippy::type_complexity)] + fn check_overriding( + &mut self, + func: &ast::Expr<()>, + args: &[ast::Expr<()>], + ) -> Result, Option>>)>, InferenceError> { + // `self` must be first argument for call to parent method + if let Some(Located { node: ExprKind::Name { id, .. }, .. }) = &args.first() { + if *id != "self".into() { + return Ok(None); + } + } else { + return Ok(None); + } + + let Located { + node: ExprKind::Attribute { value, attr: method_name, ctx }, location, .. + } = func + else { + return Ok(None); + }; + let ExprKind::Name { id: class_name, ctx: class_ctx } = &value.node else { + return Ok(None); + }; + let zelf = &self.fold_expr(args[0].clone())?; + + // Check whether the method belongs to class ancestors + let def_id = self.unifier.get_ty(zelf.custom.unwrap()); + let TypeEnum::TObj { obj_id, .. } = def_id.as_ref() else { unreachable!() }; + let defs = self.top_level.definitions.read(); + let res = { + if let TopLevelDef::Class { ancestors, .. } = &*defs[obj_id.0].read() { + let res = ancestors.iter().find_map(|f| { + let TypeAnnotation::CustomClass { id, .. } = f else { unreachable!() }; + let TopLevelDef::Class { name, methods, .. } = &*defs[id.0].read() else { + unreachable!() + }; + // Class names are stored as `__module__.class` + let name = name.to_string(); + let (_, name) = name.split_once('.').unwrap(); + if name == class_name.to_string() { + return methods.iter().find_map(|f| { + if f.0 == *method_name { + return Some(ast::Constant::Int(f.2 .0.try_into().unwrap())); + } + None + }); + } + None + }); + res + } else { + None + } + }; + + match res { + Some(r) => { + let mut new_func = func.clone(); + let mut new_value = value.clone(); + new_value.node = ExprKind::Name { id: "self".into(), ctx: *class_ctx }; + new_func.node = + ExprKind::Attribute { value: new_value, attr: *method_name, ctx: *ctx }; + + let dummy_arg = self.fold_expr(Located { + location: *location, + custom: (), + node: ExprKind::Constant::<()> { value: r, kind: None }, + })?; + + // args.remove (dummy_arg); + Ok(Some((new_func, Some(dummy_arg)))) + } + None => report_error( + format!("Method {class_name}.{method_name} not found in ancestor list").as_str(), + *location, + ), + } + } + fn fold_call( &mut self, location: Location, @@ -1685,8 +1771,21 @@ impl<'a> Inferencer<'a> { return Ok(spec_call_func); } + let mut zelf = None; + + // Check for call to parent method + let override_res = self.check_overriding(&func, &args)?; + let is_override = override_res.is_some(); + let (func, dummy_var) = override_res.unwrap_or((func, None)); + let func = Box::new(self.fold_expr(func)?); - let args = args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; + let mut args = + args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; + + // Remove self from arguments + if is_override { + zelf = Some(args.remove(0)); + } let keywords = keywords .into_iter() .map(|v| fold::fold_keyword(self, v)) @@ -1708,6 +1807,13 @@ impl<'a> Inferencer<'a> { self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| { HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]) })?; + + // Add `class_name` and `self` to arguments for `gen_expr` to generate call to parent method + if let Some(mut arg) = zelf { + arg.node = ExprKind::Name { id: "self".into(), ctx: ExprContext::Load }; + args.push(dummy_var.unwrap()); + args.push(arg); + } return Ok(Located { location, custom: Some(sign.ret), diff --git a/nac3standalone/demo/src/inheritance.py b/nac3standalone/demo/src/inheritance.py index d280e3a5..e71b389c 100644 --- a/nac3standalone/demo/src/inheritance.py +++ b/nac3standalone/demo/src/inheritance.py @@ -6,27 +6,57 @@ def output_int32(x: int32): class A: a: int32 - - def __init__(self, a: int32): - self.a = a + + def __init__(self, param_a: int32): + self.a = param_a def f1(self): - self.f2() + output_int32(12) def f2(self): - output_int32(self.a) + output_int32(124) class B(A): b: int32 + + def __init__(self, param_a: int32, param_b: int32): + self.a = param_a + self.b = param_b + + def f3(self): + output_int32(20) - def __init__(self, b: int32): - self.a = b + 1 + def f1(self): + output_int32(15) + + def f2(self): + self.b = 12 + A.f1(self) + +class C(B): + def __init__(self, a: int32, b: int32): + self.a = a self.b = b - + + def f1(self): + output_int32(17) + + def f3(self): + self.a = 2 + A.f2(self) + + def f4(self): + A.f1(self) + B.f2(self) def run() -> int32: - aaa = A(5) - bbb = B(2) - aaa.f1() - bbb.f1() + c = C(1, 2) + c.f3() + c.f4() + + a = A(1) + + output_int32(c.a) + output_int32(c.b) + return 0