From 40698525030f2b5ab8eceac6eeba29018bdf04ca Mon Sep 17 00:00:00 2001 From: abdul124 Date: Tue, 13 Aug 2024 17:34:00 +0800 Subject: [PATCH] Handle polymorphism as special calls --- nac3core/src/toplevel/composer.rs | 9 +- nac3core/src/toplevel/helper.rs | 157 ++++++++++++++-- nac3core/src/typecheck/type_inferencer/mod.rs | 169 +++++++++++++++++- nac3standalone/demo/interpreted.log | 0 nac3standalone/demo/src/inheritance.py | 34 ++-- nac3standalone/src/basic_symbol_resolver.rs | 2 +- 6 files changed, 333 insertions(+), 38 deletions(-) create mode 100644 nac3standalone/demo/interpreted.log diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 547c7e27..a9ca8784 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -23,7 +23,7 @@ impl Default for ComposerConfig { } } -type DefAst = (Arc>, Option>); +pub type DefAst = (Arc>, Option>); pub struct TopLevelComposer { // list of top level definitions, same as top level context pub definition_ast_list: Vec, @@ -1801,7 +1801,12 @@ impl TopLevelComposer { if *name != init_str_id { unreachable!("must be init function here") } - let all_inited = Self::get_all_assigned_field(body.as_slice())?; + // let all_inited = Self::get_all_assigned_field(body.as_slice())?; + let all_inited = Self::get_all_assigned_field( + definition_ast_list, + def, + body.as_slice(), + )?; for (f, _, _) in fields { if !all_inited.contains(f) { return Err(HashSet::from([ diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 21aeb9db..272e718d 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -3,6 +3,7 @@ use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap}; +use ast::ExprKind; use nac3parser::ast::{Constant, Location}; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -732,8 +733,11 @@ impl TopLevelComposer { unifier, ) } - - pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result, HashSet> { + pub fn get_all_assigned_field( + definition_ast_list: &Vec, + def: &Arc>, + stmts: &[Stmt<()>], + ) -> Result, HashSet> { let mut result = HashSet::new(); for s in stmts { match &s.node { @@ -769,32 +773,151 @@ impl TopLevelComposer { // TODO: do not check for For and While? ast::StmtKind::For { body, orelse, .. } | ast::StmtKind::While { body, orelse, .. } => { - result.extend(Self::get_all_assigned_field(body.as_slice())?); - result.extend(Self::get_all_assigned_field(orelse.as_slice())?); + result.extend(Self::get_all_assigned_field( + definition_ast_list, + def, + body.as_slice(), + )?); + result.extend(Self::get_all_assigned_field( + definition_ast_list, + def, + orelse.as_slice(), + )?); } ast::StmtKind::If { body, orelse, .. } => { - let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? - .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) - .copied() - .collect::>(); + let inited_for_sure = + Self::get_all_assigned_field(definition_ast_list, def, body.as_slice())? + .intersection(&Self::get_all_assigned_field( + definition_ast_list, + def, + orelse.as_slice(), + )?) + .copied() + .collect::>(); result.extend(inited_for_sure); } ast::StmtKind::Try { body, orelse, finalbody, .. } => { - let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? - .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) - .copied() - .collect::>(); + let inited_for_sure = + Self::get_all_assigned_field(definition_ast_list, def, body.as_slice())? + .intersection(&Self::get_all_assigned_field( + definition_ast_list, + def, + orelse.as_slice(), + )?) + .copied() + .collect::>(); result.extend(inited_for_sure); - result.extend(Self::get_all_assigned_field(finalbody.as_slice())?); + result.extend(Self::get_all_assigned_field( + definition_ast_list, + def, + finalbody.as_slice(), + )?); } ast::StmtKind::With { body, .. } => { - result.extend(Self::get_all_assigned_field(body.as_slice())?); + result.extend(Self::get_all_assigned_field( + definition_ast_list, + def, + body.as_slice(), + )?); } - ast::StmtKind::Pass { .. } - | ast::StmtKind::Assert { .. } - | ast::StmtKind::Expr { .. } => {} + // If its a call to __init__function of ancestor extend with ancestor fields + ast::StmtKind::Expr { value, .. } => { + // Check if Expression is a function call to self + if let ExprKind::Call { func, args, .. } = &value.node { + if let ExprKind::Attribute { value, attr: fn_name, .. } = &func.node { + let class_def = def.read(); + let (ancestors, methods) = { + let mut class_methods: HashMap = + HashMap::new(); + let mut class_ancestors: HashMap< + StrRef, + HashMap, + > = HashMap::new(); + + if let TopLevelDef::Class { methods, ancestors, .. } = &*class_def { + for m in methods { + class_methods.insert(m.0, m.2); + } + ancestors.iter().skip(1).for_each(|a| { + if let TypeAnnotation::CustomClass { id, .. } = a { + let anc_def = + definition_ast_list.get(id.0).unwrap().0.read(); + if let TopLevelDef::Class { name, methods, .. } = + &*anc_def + { + let mut temp: HashMap = + HashMap::new(); + for m in methods { + temp.insert(m.0, m.2); + } + // Remove module name suffix from name + let mut name_string = name.to_string(); + let split_loc = + name_string.find(|c| c == '.').unwrap() + 1; + class_ancestors.insert( + name_string.split_off(split_loc).into(), + temp, + ); + } + } + }); + } + (class_ancestors, class_methods) + }; + if let ExprKind::Name { id, .. } = value.node { + if id == "self".into() { + // Get Class methods and fields + let method_id = methods.get(fn_name); + if method_id.is_some() { + if let Some(fn_ast) = &definition_ast_list + .get(method_id.unwrap().0) + .unwrap() + .1 + { + if let ast::StmtKind::FunctionDef { body, .. } = + &fn_ast.node + { + result.extend(Self::get_all_assigned_field( + definition_ast_list, + def, + body.as_slice(), + )?); + } + } + } + } else if let Some(ancestor_methods) = ancestors.get(&id) { + // First arg must be `self` when calling ancestor function + if let ExprKind::Name { id, .. } = args[0].node { + if id == "self".into() { + if let Some(method_id) = ancestor_methods.get(fn_name) { + if let Some(fn_ast) = + &definition_ast_list.get(method_id.0).unwrap().1 + { + if let ast::StmtKind::FunctionDef { + body, .. + } = &fn_ast.node + { + result.extend( + Self::get_all_assigned_field( + definition_ast_list, + def, + body.as_slice(), + )?, + ); + } + } + }; + } + } + } + } + } + } + } + ast::StmtKind::Pass { .. } | ast::StmtKind::Assert { .. } => {} _ => { + println!("{:?}", s.node); unimplemented!() } } diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 9ac503a1..80a3d9b2 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -11,6 +11,7 @@ use super::{ RecordField, RecordKey, Type, TypeEnum, TypeVar, Unifier, VarMap, }, }; +use crate::toplevel::type_annotation::TypeAnnotation; use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ @@ -1029,7 +1030,97 @@ impl<'a> Inferencer<'a> { keywords: &[Located], ) -> Result>>, InferenceError> { let Located { location: func_location, node: ExprKind::Name { id, ctx }, .. } = func else { - return Ok(None); + // Must have self as input + if args.is_empty() { + return Ok(None); + } + + let Located { node: ExprKind::Attribute { value, attr: method_name, ctx }, .. } = func + else { + return Ok(None); + }; + let ExprKind::Name { id: class_name, .. } = &value.node else { return Ok(None) }; + + // Check whether first param is self + let first_arg = args.remove(0); + let Located { node: ExprKind::Name { id: param_name, .. }, .. } = first_arg else { + return Ok(None); + }; + if param_name != "self".into() { + return Ok(None); + } + + // Get Method from ancestors + let zelf = &self.fold_expr(first_arg)?; + let def_id = self.unifier.get_ty(zelf.custom.unwrap()); + let TypeEnum::TObj { obj_id, .. } = def_id.as_ref() else { unreachable!() }; + let defs = self.top_level.definitions.read(); + let result = { + if let TopLevelDef::Class { ancestors, .. } = &*defs[obj_id.0].read() { + ancestors.iter().find_map(|f| { + println!("{}", f.stringify(self.unifier)); + let TypeAnnotation::CustomClass { id, .. } = f else { unreachable!() }; + let TopLevelDef::Class { name, methods, .. } = &*defs[id.0].read() else { + unreachable!() + }; + let name = name.to_string(); + let (_, name) = name.split_once('.').unwrap(); + println!("Comparing against => {name}, {class_name}"); + if name == class_name.to_string() { + return methods.iter().find_map(|f| { + if f.0 == *method_name { + return Some(f.1); + } + None + }); + } + None + }) + } else { + unreachable!() + } + } + .unwrap(); + + let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(result) else { return Ok(None) }; + + let args = args + .iter_mut() + .map(|v| self.fold_expr(v.clone())) + .collect::, _>>()?; + + // let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + // args: vec![FuncArg { + // name: "n".into(), + // ty: arg0.custom.unwrap(), + // default_value: None, + // is_vararg: false, + // }], + // ret, + // vars: VarMap::new(), + // })); + + return Ok(Some(Located { + location, + custom: Some(sign.ret), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(result), + location: func.location, + node: ExprKind::Attribute { + value: Box::new(Located { + location: func.location, + custom: zelf.custom, + node: ExprKind::Name { id: *class_name, ctx: *ctx }, + }), + attr: *method_name, + ctx: *ctx, + }, + }), + args, + keywords: vec![], + }, + })); }; // handle special functions that cannot be typed in the usual way... @@ -1631,13 +1722,85 @@ impl<'a> Inferencer<'a> { return Ok(spec_call_func); } - let func = Box::new(self.fold_expr(func)?); - let args = args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; let keywords = keywords .into_iter() .map(|v| fold::fold_keyword(self, v)) .collect::, _>>()?; + println!("==============================="); + println!("=======Printing Func details======="); + println!("Fun Location => {}", func.location); + println!("Fun Node => {}", func.node.name()); + println!("Fun Args => {}", args.len()); + if !args.is_empty() { + println!("First ArgNode => {}", args[0].node.name()); + } + + if let ExprKind::Attribute { value, attr, .. } = &func.node { + println!("Function Attributes"); + println!("Attr Name => {}", attr); + println!("Value node => {}", value.node.name()); + if let ExprKind::Name { id: class_id, .. } = value.node { + println!("Value Node ID => {class_id}"); + + // This ID is the parent class name + // Resolve definition of class from self and get the ancestor list + + let zelf = &self.fold_expr(args[0].clone()).unwrap(); + println!("Unification Key => {}", self.unifier.stringify(zelf.custom.unwrap())); + let def_id = self.unifier.get_ty(zelf.custom.unwrap()); + let TypeEnum::TObj { obj_id, .. } = def_id.as_ref() else { unreachable!() }; + let defs = self.top_level.definitions.read(); + let result = { + if let TopLevelDef::Class { ancestors, .. } = &*defs[obj_id.0].read() { + ancestors.iter().find_map(|f| { + let TypeAnnotation::CustomClass { id, .. } = f else { unreachable!() }; + let TopLevelDef::Class { name, methods, .. } = &*defs[id.0].read() + else { + unreachable!() + }; + let name = name.to_string(); + let (_, name) = name.split_once('.').unwrap(); + println!("Comparing against => {name}, {class_id}"); + if name == class_id.to_string() { + return methods.iter().find_map(|f| { + if f.0 == *attr { + return Some(f.1); + } + None + }); + } + None + }) + } else { + None + } + } + .unwrap(); + + println!("Function in Selected Parent Class"); + // Construct new call add type checking later if it works + let args = args + .iter() + .map(|v| self.fold_expr(v.clone())) + .collect::, _>>()?; + // let func = Box::new(self.fold_expr(func.clone()).unwrap()); + // let ty = self.unifier.get_ty(result); + println!("Function Type => {}", self.unifier.stringify(result)); + + // Now I have the unification key of the call + // and vars for the call + // Need to make call + // Use special case for ref + + // let expr = ExprKind::Attribute { value: (), attr: (), ctx: () } + println!("======================"); + } + } + println!("=======Ending Func details======="); + + let args = args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; + let func = Box::new(self.fold_expr(func)?); if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(func.custom.unwrap()) { if sign.vars.is_empty() { let call = Call { diff --git a/nac3standalone/demo/interpreted.log b/nac3standalone/demo/interpreted.log new file mode 100644 index 00000000..e69de29b diff --git a/nac3standalone/demo/src/inheritance.py b/nac3standalone/demo/src/inheritance.py index d280e3a5..54264c7c 100644 --- a/nac3standalone/demo/src/inheritance.py +++ b/nac3standalone/demo/src/inheritance.py @@ -6,27 +6,31 @@ def output_int32(x: int32): class A: a: int32 - - def __init__(self, a: int32): - self.a = a + def __init__(self, val: int32): + self.a = val + # self.f1() def f1(self): - self.f2() - - def f2(self): output_int32(self.a) class B(A): b: int32 - - def __init__(self, b: int32): - self.a = b + 1 - self.b = b - + def __init__(self, val1: int32, val2: int32): + A.__init__(self, val1) + self.b = val2 + + def f2(self): + # A.f1(self) + output_int32(self.b) def run() -> int32: - aaa = A(5) - bbb = B(2) - aaa.f1() - bbb.f1() + c1 = B(2, 4) + # c1.f2() + + + + # aaa = A(5) + # bbb = B(2) + # aaa.f1() + # bbb.f1() return 0 diff --git a/nac3standalone/src/basic_symbol_resolver.rs b/nac3standalone/src/basic_symbol_resolver.rs index 5fe0d4f5..48f6cb0e 100644 --- a/nac3standalone/src/basic_symbol_resolver.rs +++ b/nac3standalone/src/basic_symbol_resolver.rs @@ -59,7 +59,7 @@ impl SymbolResolver for Resolver { _: StrRef, _: &mut CodeGenContext<'ctx, '_>, ) -> Option> { - unimplemented!() + None } fn get_identifier_def(&self, id: StrRef) -> Result> {