diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 882af5a9..4e3a2136 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -3236,6 +3236,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) }; // Handle Class Method calls + // The attribute will be `DefinitionId` of the method if the call is to one of the parent methods + let func_id = attr.to_string().parse::(); + let id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(value.custom.unwrap()) { @@ -3243,7 +3246,11 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { codegen_unreachable!(ctx) }; - let fun_id = { + + // Use the `DefinitionID` from attribute if it is available + let fun_id = if let Ok(func_id) = func_id { + DefinitionId(func_id) + } else { let defs = ctx.top_level.definitions.read(); let obj_def = defs.get(id.0).unwrap().read(); let TopLevelDef::Class { methods, .. } = &*obj_def else { diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 0408cf1c..fc9ce2ec 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -12,6 +12,7 @@ use super::{ RecordField, RecordKey, Type, TypeEnum, TypeVar, Unifier, VarMap, }, }; +use crate::toplevel::type_annotation::TypeAnnotation; use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ @@ -102,6 +103,7 @@ pub struct Inferencer<'a> { } type InferenceError = HashSet; +type OverrideResult = Result>>, InferenceError>; struct NaiveFolder(); impl Fold<()> for NaiveFolder { @@ -1672,6 +1674,86 @@ impl<'a> Inferencer<'a> { Ok(None) } + /// Checks whether a class method is calling parent function + /// Returns [`None`] if its not a call to parent method, otherwise + /// returns a new `func` with class name replaced by `self` and method resolved to its `DefinitionID` + /// + /// e.g. A.f1(self, ...) returns Some(self.{DefintionID(f1)}) + fn check_overriding(&mut self, func: &ast::Expr<()>, args: &[ast::Expr<()>]) -> OverrideResult { + // `self` must be first argument for call to parent method + if let Some(Located { node: ExprKind::Name { id, .. }, .. }) = &args.first() { + if *id != "self".into() { + return Ok(None); + } + } else { + return Ok(None); + } + + let Located { + node: ExprKind::Attribute { value, attr: method_name, ctx }, location, .. + } = func + else { + return Ok(None); + }; + let ExprKind::Name { id: class_name, ctx: class_ctx } = &value.node else { + return Ok(None); + }; + let zelf = &self.fold_expr(args[0].clone())?; + + // Check whether the method belongs to class ancestors + let def_id = self.unifier.get_ty(zelf.custom.unwrap()); + let TypeEnum::TObj { obj_id, .. } = def_id.as_ref() else { unreachable!() }; + let defs = self.top_level.definitions.read(); + let res = { + if let TopLevelDef::Class { ancestors, .. } = &*defs[obj_id.0].read() { + let res = ancestors.iter().find_map(|f| { + let TypeAnnotation::CustomClass { id, .. } = f else { unreachable!() }; + let TopLevelDef::Class { name, methods, .. } = &*defs[id.0].read() else { + unreachable!() + }; + // Class names are stored as `__module__.class` + let name = name.to_string(); + let (_, name) = name.rsplit_once('.').unwrap(); + if name == class_name.to_string() { + return methods.iter().find_map(|f| { + if f.0 == *method_name { + return Some(*f); + } + None + }); + } + None + }); + res + } else { + None + } + }; + + match res { + Some(r) => { + let mut new_func = func.clone(); + let mut new_value = value.clone(); + new_value.node = ExprKind::Name { id: "self".into(), ctx: *class_ctx }; + new_func.node = + ExprKind::Attribute { value: new_value.clone(), attr: *method_name, ctx: *ctx }; + + let mut new_func = self.fold_expr(new_func)?; + + let ExprKind::Attribute { value, .. } = new_func.node else { unreachable!() }; + new_func.node = + ExprKind::Attribute { value, attr: r.2 .0.to_string().into(), ctx: *ctx }; + new_func.custom = Some(r.1); + + Ok(Some(new_func)) + } + None => report_error( + format!("Method {class_name}.{method_name} not found in ancestor list").as_str(), + *location, + ), + } + } + fn fold_call( &mut self, location: Location, @@ -1685,8 +1767,20 @@ impl<'a> Inferencer<'a> { return Ok(spec_call_func); } - let func = Box::new(self.fold_expr(func)?); - let args = args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; + // Check for call to parent method + let override_res = self.check_overriding(&func, &args)?; + let is_override = override_res.is_some(); + let func = if is_override { override_res.unwrap() } else { self.fold_expr(func)? }; + let func = Box::new(func); + + let mut args = + args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; + + // TODO: Handle passing of self to functions to allow runtime lookup of functions to be called + // Currently removing `self` and using compile time function definitions + if is_override { + args.remove(0); + } let keywords = keywords .into_iter() .map(|v| fold::fold_keyword(self, v))