From d5fa3fafa182bc5ed8bdb3ef9fcaeef19c20ec7b Mon Sep 17 00:00:00 2001 From: abdul124 Date: Wed, 14 Aug 2024 17:51:52 +0800 Subject: [PATCH] Handle polymorphism with virtual tables --- flake.nix | 2 + nac3core/src/codegen/expr.rs | 26 +++- nac3core/src/toplevel/builtins.rs | 6 + nac3core/src/toplevel/composer.rs | 104 +++++++++++----- nac3core/src/toplevel/helper.rs | 1 + nac3core/src/toplevel/mod.rs | 2 + nac3core/src/typecheck/type_inferencer/mod.rs | 116 +++++++++++++++++- .../src/typecheck/type_inferencer/test.rs | 4 + nac3standalone/demo/interpreted.log | 3 + nac3standalone/demo/run_64.log | 3 + nac3standalone/demo/src/inheritance.py | 45 ++++--- 11 files changed, 263 insertions(+), 49 deletions(-) create mode 100644 nac3standalone/demo/interpreted.log create mode 100644 nac3standalone/demo/run_64.log diff --git a/flake.nix b/flake.nix index 07cea776..7987b259 100644 --- a/flake.nix +++ b/flake.nix @@ -180,7 +180,9 @@ clippy pre-commit rustfmt + rust-analyzer ]; + RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}"; shellHook = '' export DEMO_LINALG_STUB=${packages.x86_64-linux.demo-linalg-stub}/lib/liblinalg.a diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 0817fce2..a39b6528 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2982,8 +2982,23 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } } ExprKind::Call { func, args, keywords } => { + // Check if expr is override or not + let mut is_override = false; + if let Some(arg) = args.first() { + if let ExprKind::Name { id, .. } = arg.node { + if id == "self".into() { + is_override = true; + } + } + } + let mut args = args.clone(); + if is_override { + args.remove(0); + } + let mut params = args .iter() + .skip(if is_override { 1 } else { 0 }) .map(|arg| generator.gen_expr(ctx, arg)) .take_while(|expr| !matches!(expr, Ok(None))) .map(|expr| Ok((None, expr?.unwrap())) as Result<_, String>) @@ -3035,9 +3050,14 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let fun_id = { let defs = ctx.top_level.definitions.read(); let obj_def = defs.get(id.0).unwrap().read(); - let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() }; - - methods.iter().find(|method| method.0 == *attr).unwrap().2 + let TopLevelDef::Class { methods, virtual_table, .. } = &*obj_def else { + unreachable!() + }; + if is_override { + virtual_table.get(attr).unwrap().1 + } else { + methods.iter().find(|method| method.0 == *attr).unwrap().2 + } }; // directly generate code for option.unwrap // since it needs to return static value to optimize for kernel invariant diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 2fef5eb4..109d1aa6 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -96,6 +96,7 @@ pub fn get_exn_constructor( fields: exception_fields, attributes: Vec::default(), methods: vec![("__init__".into(), signature, DefinitionId(cons_id))], + virtual_table: HashMap::default(), ancestors: vec![ TypeAnnotation::CustomClass { id: DefinitionId(class_id), params: Vec::default() }, TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() }, @@ -689,6 +690,7 @@ impl<'a> BuiltinBuilder<'a> { fields, attributes: Vec::default(), methods: vec![("__init__".into(), ctor_signature, PrimDef::FunRangeInit.id())], + virtual_table: HashMap::default(), ancestors: Vec::default(), constructor: Some(ctor_signature), resolver: None, @@ -821,6 +823,7 @@ impl<'a> BuiltinBuilder<'a> { fields: make_exception_fields(int32, int64, str), attributes: Vec::default(), methods: Vec::default(), + virtual_table: HashMap::default(), ancestors: vec![], constructor: None, resolver: None, @@ -855,6 +858,7 @@ impl<'a> BuiltinBuilder<'a> { Self::create_method(PrimDef::FunOptionIsNone, self.is_some_ty.0), Self::create_method(PrimDef::FunOptionUnwrap, self.unwrap_ty.0), ], + virtual_table: HashMap::default(), ancestors: vec![TypeAnnotation::CustomClass { id: prim.id(), params: Vec::default(), @@ -962,6 +966,7 @@ impl<'a> BuiltinBuilder<'a> { fields: Vec::default(), attributes: Vec::default(), methods: Vec::default(), + virtual_table: HashMap::default(), ancestors: Vec::default(), constructor: None, resolver: None, @@ -990,6 +995,7 @@ impl<'a> BuiltinBuilder<'a> { Self::create_method(PrimDef::FunNDArrayCopy, self.ndarray_copy_ty.0), Self::create_method(PrimDef::FunNDArrayFill, self.ndarray_fill_ty.0), ], + virtual_table: HashMap::default(), ancestors: Vec::default(), constructor: None, resolver: None, diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 2f0f7e87..0f79cab8 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1528,9 +1528,9 @@ impl TopLevelComposer { fn analyze_single_class_ancestors( class_def: &mut TopLevelDef, temp_def_list: &[Arc>], - unifier: &mut Unifier, + _unifier: &mut Unifier, _primitives: &PrimitiveStore, - type_var_to_concrete_def: &mut HashMap, + _type_var_to_concrete_def: &mut HashMap, ) -> Result<(), HashSet> { let TopLevelDef::Class { object_id, @@ -1538,6 +1538,7 @@ impl TopLevelComposer { fields, attributes, methods, + virtual_table, resolver, type_vars, .. @@ -1551,9 +1552,19 @@ impl TopLevelComposer { class_fields_def, class_attribute_def, class_methods_def, + class_virtual_table, _class_type_vars_def, _class_resolver, - ) = (*object_id, ancestors, fields, attributes, methods, type_vars, resolver); + ) = ( + *object_id, + ancestors, + fields, + attributes, + methods, + virtual_table, + type_vars, + resolver, + ); // since when this function is called, the ancestors of the direct parent // are supposed to be already handled, so we only need to deal with the direct parent @@ -1564,51 +1575,88 @@ impl TopLevelComposer { let base = temp_def_list.get(id.0).unwrap(); let base = base.read(); - let TopLevelDef::Class { methods, fields, attributes, .. } = &*base else { + let TopLevelDef::Class { methods, virtual_table, fields, attributes, .. } = &*base else { unreachable!("must be top level class def") }; // handle methods override // since we need to maintain the order, create a new list + + // handle methods override + // Since we are following python and its lax syntax, signature is ignored in overriding + // Mark the overrided methods and add them to the child overrides + let mut new_child_methods: Vec<(StrRef, Type, DefinitionId)> = Vec::new(); + let mut new_child_virtual_table: HashMap = + virtual_table.clone(); let mut is_override: HashSet = HashSet::new(); + for (anc_method_name, anc_method_ty, anc_method_def_id) in methods { - // find if there is a method with same name in the child class let mut to_be_added = (*anc_method_name, *anc_method_ty, *anc_method_def_id); - for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def { - if class_method_name == anc_method_name { - // ignore and handle self - // if is __init__ method, no need to check return type - let ok = class_method_name == &"__init__".into() - || Self::check_overload_function_type( - *class_method_ty, - *anc_method_ty, - unifier, - type_var_to_concrete_def, - ); - if !ok { - return Err(HashSet::from([format!( - "method {class_method_name} has same name as ancestors' method, but incompatible type"), - ])); - } - // mark it as added - is_override.insert(*class_method_name); - to_be_added = (*class_method_name, *class_method_ty, *class_method_defid); + for class_method in &*class_methods_def { + if class_method.0 == *anc_method_name { + to_be_added = *class_method; + // Add to virtual table + new_child_virtual_table + .insert(class_method.0, (class_method.1, class_method.2)); + is_override.insert(class_method.0); break; } } new_child_methods.push(to_be_added); } - // add those that are not overriding method to the new_child_methods - for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def { - if !is_override.contains(class_method_name) { - new_child_methods.push((*class_method_name, *class_method_ty, *class_method_defid)); + for class_method in &*class_methods_def { + if !is_override.contains(&class_method.0) { + new_child_methods.push(*class_method); } } + /* + + Call => class method + super() or class_name A.f1() => method, virtual_tables + + */ + + // for (anc_method_name, anc_method_ty, anc_method_def_id) in methods { + // // find if there is a method with same name in the child class + // let mut to_be_added = (*anc_method_name, *anc_method_ty, *anc_method_def_id); + // for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def { + // if class_method_name == anc_method_name { + // // ignore and handle self + // // if is __init__ method, no need to check return type + // let ok = class_method_name == &"__init__".into() + // || Self::check_overload_function_type( + // *class_method_ty, + // *anc_method_ty, + // unifier, + // type_var_to_concrete_def, + // ); + // if !ok { + // return Err(HashSet::from([format!( + // "method {class_method_name} has same name as ancestors' method, but incompatible type"), + // ])); + // } + // // mark it as added + // is_override.insert(*class_method_name); + // to_be_added = (*class_method_name, *class_method_ty, *class_method_defid); + // break; + // } + // } + // new_child_methods.push(to_be_added); + // } + // // add those that are not overriding method to the new_child_methods + // for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def { + // if !is_override.contains(class_method_name) { + // new_child_methods.push((*class_method_name, *class_method_ty, *class_method_defid)); + // } + // } // use the new_child_methods to replace all the elements in `class_methods_def` class_methods_def.clear(); class_methods_def.extend(new_child_methods); + class_virtual_table.clear(); + class_virtual_table.extend(new_child_virtual_table); + // handle class fields let mut new_child_fields: Vec<(StrRef, Type, bool)> = Vec::new(); // let mut is_override: HashSet<_> = HashSet::new(); diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 21aeb9db..280750d4 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -559,6 +559,7 @@ impl TopLevelComposer { fields: Vec::default(), attributes: Vec::default(), methods: Vec::default(), + virtual_table: HashMap::default(), ancestors: Vec::default(), constructor, resolver, diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 7dfd8373..db0d2e16 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -109,6 +109,8 @@ pub enum TopLevelDef { attributes: Vec<(StrRef, Type, ast::Constant)>, /// Class methods, pointing to the corresponding function definition. methods: Vec<(StrRef, Type, DefinitionId)>, + /// Overridden class methods + virtual_table: HashMap, /// Ancestor classes, including itself. ancestors: Vec, /// Symbol resolver of the module defined the class; [None] if it is built-in type. diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 0408cf1c..5cc141c4 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -12,6 +12,7 @@ use super::{ RecordField, RecordKey, Type, TypeEnum, TypeVar, Unifier, VarMap, }, }; +use crate::toplevel::type_annotation::TypeAnnotation; use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ @@ -1672,6 +1673,86 @@ impl<'a> Inferencer<'a> { Ok(None) } + fn try_overriding( + &mut self, + func: &ast::Expr<()>, + args: &mut [ast::Expr<()>], + ) -> Result, Type)>, InferenceError> { + // Allow Overriding + + // Must have self as first input + if args.is_empty() { + return Ok(None); + } + if let Located { node: ExprKind::Name { id, .. }, .. } = &args[0] { + if *id != "self".into() { + return Ok(None); + } + } else { + return Ok(None); + } + + let Located { + node: ExprKind::Attribute { value, attr: method_name, ctx }, location, .. + } = func + else { + return Ok(None); + }; + let ExprKind::Name { id: class_name, ctx: class_ctx } = &value.node else { + return Ok(None); + }; + + // Do not Remove self from args (will move it class name instead after activating necessary flags) + let zelf = &self.fold_expr(args[0].clone())?; + let def_id = self.unifier.get_ty(zelf.custom.unwrap()); + let TypeEnum::TObj { obj_id, .. } = def_id.as_ref() else { unreachable!() }; + let defs = self.top_level.definitions.read(); + let res = { + if let TopLevelDef::Class { ancestors, .. } = &*defs[obj_id.0].read() { + let res = ancestors.iter().find_map(|f| { + let TypeAnnotation::CustomClass { id, .. } = f else { unreachable!() }; + let TopLevelDef::Class { name, methods, .. } = &*defs[id.0].read() else { + unreachable!() + }; + let name = name.to_string(); + let (_, name) = name.split_once('.').unwrap(); + if name == class_name.to_string() { + return methods.iter().find_map(|f| { + if f.0 == *method_name { + return Some(*f); + } + None + }); + } + None + }); + res + } else { + None + } + }; + + if let Some(f) = res { + if let TopLevelDef::Class { virtual_table, .. } = &mut *defs[obj_id.0].write() { + virtual_table.insert(f.0, (f.1, f.2)); + } + } else { + return report_error( + format!("No such function found in parent class {}", method_name).as_str(), + *location, + ); + } + // let Located { node: ExprKind::Attribute { value, attr: method_name, .. }, location, .. } = func + // Change the class name to self to refer to correct part of code + // let new_func = Located { node: ExprKind::Attribute { value, attr: method_name, ctx: () }, location, custom} + let mut new_func = func.clone(); + let mut new_value = value.clone(); + new_value.node = ExprKind::Name { id: "self".into(), ctx: *class_ctx }; + new_func.node = ExprKind::Attribute { value: new_value, attr: *method_name, ctx: *ctx }; + + Ok(Some((new_func, res.unwrap().1))) + } + fn fold_call( &mut self, location: Location, @@ -1685,14 +1766,32 @@ impl<'a> Inferencer<'a> { return Ok(spec_call_func); } + let mut first_arg = None; + let mut is_override = false; + let mut func_sign_key = None; + let override_res = self.try_overriding(&func, &mut args)?; + let func = match override_res { + Some(res) => { + is_override = true; + func_sign_key = Some(res.1); + res.0 + } + None => func, + }; + let func = Box::new(self.fold_expr(func)?); - let args = args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; + let mut args = + args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; + if is_override { + first_arg = Some(args.remove(0)); + } let keywords = keywords .into_iter() .map(|v| fold::fold_keyword(self, v)) .collect::, _>>()?; - if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(func.custom.unwrap()) { + let func_key = if is_override { func_sign_key.unwrap() } else { func.custom.unwrap() }; + if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(func_key) { if sign.vars.is_empty() { let call = Call { posargs: args.iter().map(|v| v.custom.unwrap()).collect(), @@ -1705,9 +1804,15 @@ impl<'a> Inferencer<'a> { loc: Some(location), operator_info: None, }; - self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| { + self.unifier.unify_call(&call, func_key, sign).map_err(|e| { HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]) })?; + + // First parameter is self to indicate override + if let Some(mut arg) = first_arg { + arg.node = ExprKind::Name { id: "self".into(), ctx: ExprContext::Load }; + args.insert(0, arg); + } return Ok(Located { location, custom: Some(sign.ret), @@ -1731,7 +1836,10 @@ impl<'a> Inferencer<'a> { self.calls.insert(location.into(), call); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call])); self.unify(func.custom.unwrap(), call, &func.location)?; - + println!("Here"); + for k in keywords.iter() { + println!("keyword {}", k.node.value.node.name()); + } Ok(Located { location, custom: Some(ret), node: ExprKind::Call { func, args, keywords } }) } diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 75293a3a..7ddfc27d 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -328,6 +328,7 @@ impl TestEnvironment { fields: Vec::default(), attributes: Vec::default(), methods: Vec::default(), + virtual_table: HashMap::default(), ancestors: Vec::default(), resolver: None, constructor: None, @@ -372,6 +373,7 @@ impl TestEnvironment { fields: [("a".into(), tvar.ty, true)].into(), attributes: Vec::default(), methods: Vec::default(), + virtual_table: HashMap::default(), ancestors: Vec::default(), resolver: None, constructor: None, @@ -407,6 +409,7 @@ impl TestEnvironment { fields: [("a".into(), int32, true), ("b".into(), fun, true)].into(), attributes: Vec::default(), methods: Vec::default(), + virtual_table: HashMap::default(), ancestors: Vec::default(), resolver: None, constructor: None, @@ -436,6 +439,7 @@ impl TestEnvironment { fields: [("a".into(), bool, true), ("b".into(), fun, false)].into(), attributes: Vec::default(), methods: Vec::default(), + virtual_table: HashMap::default(), ancestors: Vec::default(), resolver: None, constructor: None, diff --git a/nac3standalone/demo/interpreted.log b/nac3standalone/demo/interpreted.log new file mode 100644 index 00000000..84e96230 --- /dev/null +++ b/nac3standalone/demo/interpreted.log @@ -0,0 +1,3 @@ +12 +12 +17 diff --git a/nac3standalone/demo/run_64.log b/nac3standalone/demo/run_64.log new file mode 100644 index 00000000..23f3f8dc --- /dev/null +++ b/nac3standalone/demo/run_64.log @@ -0,0 +1,3 @@ +12 +12 +15 diff --git a/nac3standalone/demo/src/inheritance.py b/nac3standalone/demo/src/inheritance.py index d280e3a5..8c22b6bd 100644 --- a/nac3standalone/demo/src/inheritance.py +++ b/nac3standalone/demo/src/inheritance.py @@ -6,27 +6,44 @@ def output_int32(x: int32): class A: a: int32 - - def __init__(self, a: int32): - self.a = a + + def __init__(self, param_a: int32): + self.a = param_a def f1(self): - self.f2() - - def f2(self): - output_int32(self.a) + output_int32(12) class B(A): b: int32 - - def __init__(self, b: int32): - self.a = b + 1 + + def __init__(self, param_a: int32, param_b: int32): + self.a = param_a + self.b = param_b + + def f1(self): + output_int32(15) + + def f2(self): + A.f1(self) + self.f1() + +class C(B): + def __init__(self, a: int32, b: int32): + self.a = a self.b = b + + def f1(self): + output_int32(17) + + def f3(self): + B.f2(self) + + def f4(self): + A.f1(self) + def run() -> int32: - aaa = A(5) - bbb = B(2) - aaa.f1() - bbb.f1() + c = B(1, 2) + c.f2() return 0