From a0662c58e6a12d818664076b225cd62d813ea982 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Tue, 14 Sep 2021 22:49:20 +0800 Subject: [PATCH] nac3core: fix recursive top level function call --- nac3core/src/toplevel/composer.rs | 85 +++++++++++++++++-------------- nac3core/src/toplevel/test.rs | 78 ++++++++++++++++++---------- 2 files changed, 99 insertions(+), 64 deletions(-) diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 62ecc247..0fa5befd 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -77,7 +77,10 @@ impl TopLevelComposer { ) .into(), // FIXME: all the big unifier or? - unifiers: Arc::new(RwLock::new(vec![(self.unifier.get_shared_unifier(), self.primitives_ty)])), + unifiers: Arc::new(RwLock::new(vec![( + self.unifier.get_shared_unifier(), + self.primitives_ty, + )])), } } @@ -92,7 +95,7 @@ impl TopLevelComposer { ast: ast::Stmt<()>, resolver: Option>>, mod_path: String, - ) -> Result<(String, DefinitionId), String> { + ) -> Result<(String, DefinitionId, Option), String> { let defined_class_name = &mut self.defined_class_name; let defined_class_method_name = &mut self.defined_class_method_name; let defined_function_name = &mut self.defined_function_name; @@ -212,7 +215,7 @@ impl TopLevelComposer { None, )); - Ok((class_name, DefinitionId(class_def_id))) + Ok((class_name, DefinitionId(class_def_id), None)) } ast::StmtKind::FunctionDef { name, .. } => { @@ -228,12 +231,13 @@ impl TopLevelComposer { return Err("duplicate top level function define".into()); } + let ty_to_be_unified = self.unifier.get_fresh_var().0; // add to the definition list self.definition_ast_list.push(( RwLock::new(Self::make_top_level_function_def( name.into(), // dummy here, unify with correct type later - self.unifier.get_fresh_var().0, + ty_to_be_unified, resolver, )) .into(), @@ -241,7 +245,11 @@ impl TopLevelComposer { )); // return - Ok((fun_name, DefinitionId(self.definition_ast_list.len() - 1))) + Ok(( + fun_name, + DefinitionId(self.definition_ast_list.len() - 1), + Some(ty_to_be_unified), + )) } _ => Err("only registrations of top level classes/functions are supprted".into()), @@ -1111,7 +1119,6 @@ impl TopLevelComposer { /// step 5, analyze and call type inferecer to fill the `instance_to_stmt` of topleveldef::function fn analyze_function_instance(&mut self) -> Result<(), String> { for (id, (def, ast)) in self.definition_ast_list.iter().enumerate() { - let mut function_def = def.write(); if let TopLevelDef::Function { instance_to_stmt, @@ -1120,7 +1127,8 @@ impl TopLevelComposer { var_id, resolver, .. - } = &mut *function_def { + } = &mut *function_def + { if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() { let FunSignature { args, ret, vars } = &*func_sig.borrow(); // None if is not class method @@ -1134,7 +1142,7 @@ impl TopLevelComposer { self.extract_def_list().as_slice(), &mut self.unifier, &self.primitives_ty, - &ty_ann + &ty_ann, )?) } else { unreachable!("must be class def") @@ -1145,12 +1153,12 @@ impl TopLevelComposer { }; let type_var_subst_comb = { let unifier = &mut self.unifier; - let var_ids = vars - .iter() - .map(|(id, _)| *id); + let var_ids = vars.iter().map(|(id, _)| *id); let var_combs = vars .iter() - .map(|(_, ty)| unifier.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])) + .map(|(_, ty)| { + unifier.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]) + }) .multi_cartesian_product() .collect_vec(); let mut result: Vec> = Default::default(); @@ -1173,16 +1181,16 @@ impl TopLevelComposer { .map(|a| FuncArg { name: a.name.clone(), ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty), - default_value: a.default_value.clone() + default_value: a.default_value.clone(), }) .collect_vec(); let self_type = self_type.map(|x| unifier.subst(x, &subst).unwrap_or(x)); - + let mut identifiers = { // NOTE: none and function args? let mut result: HashSet = HashSet::new(); result.insert("None".into()); - if self_type.is_some(){ + if self_type.is_some() { result.insert("self".into()); } result.extend(inst_args.iter().map(|x| x.name.clone())); @@ -1194,7 +1202,10 @@ impl TopLevelComposer { defined_identifiers: identifiers.clone(), function_data: &mut FunctionData { resolver: resolver.as_ref().unwrap().clone(), - return_type: if self.unifier.unioned(inst_ret, self.primitives_ty.none) { + return_type: if self + .unifier + .unioned(inst_ret, self.primitives_ty.none) + { None } else { Some(inst_ret) @@ -1219,28 +1230,29 @@ impl TopLevelComposer { } }; - let fun_body = if let ast::StmtKind::FunctionDef { body, .. } = ast.clone().unwrap().node { - body - } else { - unreachable!("must be function def ast") - } - .into_iter() - .map(|b| inferencer.fold_stmt(b)) - .collect::, _>>()?; - - let returned = inferencer - .check_block(fun_body.as_slice(), &mut identifiers)?; - + let fun_body = if let ast::StmtKind::FunctionDef { body, .. } = + ast.clone().unwrap().node + { + body + } else { + unreachable!("must be function def ast") + } + .into_iter() + .map(|b| inferencer.fold_stmt(b)) + .collect::, _>>()?; + + let returned = + inferencer.check_block(fun_body.as_slice(), &mut identifiers)?; + if !self.unifier.unioned(inst_ret, self.primitives_ty.none) && !returned { let ret_str = self.unifier.stringify( - inst_ret, + inst_ret, &mut |id| format!("class{}", id), - &mut |id| format!("tvar{}", id) + &mut |id| format!("tvar{}", id), ); return Err(format!( "expected return type of {} in function `{}`", - ret_str, - name + ret_str, name )); } @@ -1251,17 +1263,16 @@ impl TopLevelComposer { body: fun_body, unifier_id: 0, calls: HashMap::new(), - subst - } + subst, + }, ); - } + } } else { unreachable!("must be typeenum::tfunc") } } else { - continue + continue; } - } Ok(()) } diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 6e39ce47..504ee6b7 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -25,12 +25,17 @@ impl ResolverInternal { fn add_id_def(&self, id: String, def: DefinitionId) { self.id_to_def.lock().insert(id, def); } + + fn add_id_type(&self, id: String, ty: Type) { + self.id_to_type.lock().insert(id, ty); + } } struct Resolver(Arc); impl SymbolResolver for Resolver { fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option { + println!("unkonw here resolver {}", str); self.0.id_to_type.lock().get(str).cloned() } @@ -133,9 +138,12 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s let ast = parse_program(s).unwrap(); let ast = ast[0].clone(); - let (id, def_id) = + let (id, def_id, ty) = composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()).unwrap(); - internal_resolver.add_id_def(id, def_id); + internal_resolver.add_id_def(id.clone(), def_id); + if let Some(ty) = ty { + internal_resolver.add_id_type(id, ty); + } } composer.start_analysis(true).unwrap(); @@ -786,7 +794,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { let ast = parse_program(s).unwrap(); let ast = ast[0].clone(); - let (id, def_id) = { + let (id, def_id, ty) = { match composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()) { Ok(x) => x, Err(msg) => { @@ -799,7 +807,10 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { } } }; - internal_resolver.add_id_def(id, def_id); + internal_resolver.add_id_def(id.clone(), def_id); + if let Some(ty) = ty { + internal_resolver.add_id_type(id, ty); + } } if let Err(msg) = composer.start_analysis(false) { @@ -846,6 +857,14 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { indoc! {" def fun(a: int32, b: int32) -> int32: return a + b + "}, + indoc! {" + def fib(n: int32) -> int32: + if n <= 2: + return 1 + a = fib(n - 1) + b = fib(n - 2) + return fib(n - 1) "} ], vec![]; @@ -861,14 +880,24 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { def fun(self) -> int32: b = self.a + 3 return b * self.a - def dup(self) -> A: + def clone(self) -> A: SELF = self return SELF - + def sum(self) -> int32: + if self.a == 0: + return self.a + else: + a = self.a + self.a = self.a - 1 + return a + self.sum() + def fib(self, a: int32) -> int32: + if a <= 2: + return 1 + return self.fib(a - 1) + self.fib(a - 2) "}, indoc! {" def fun(a: A) -> int32: - return a.fun() + return a.fun() + 2 "} ], vec![]; @@ -878,21 +907,9 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) { let print = true; let mut composer = TopLevelComposer::new(); - let tvar_t = composer.unifier.get_fresh_var(); - let tvar_v = composer - .unifier - .get_fresh_var_with_range(&[composer.primitives_ty.bool, composer.primitives_ty.int32]); - - if print { - println!("t: {}, {:?}", tvar_t.1, tvar_t.0); - println!("v: {}, {:?}\n", tvar_v.1, tvar_v.0); - } - let internal_resolver = Arc::new(ResolverInternal { id_to_def: Default::default(), - id_to_type: Mutex::new( - vec![("T".to_string(), tvar_t.0), ("V".to_string(), tvar_v.0)].into_iter().collect(), - ), + id_to_type: Default::default(), class_names: Default::default(), }); let resolver = Arc::new( @@ -903,7 +920,7 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) { let ast = parse_program(s).unwrap(); let ast = ast[0].clone(); - let (id, def_id) = { + let (id, def_id, ty) = { match composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()) { Ok(x) => x, Err(msg) => { @@ -916,12 +933,14 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) { } } }; - internal_resolver.add_id_def(id, def_id); + internal_resolver.add_id_def(id.clone(), def_id); + if let Some(ty) = ty { + internal_resolver.add_id_type(id, ty); + } } - + if let Err(msg) = composer.start_analysis(true) { if print { - // println!("err2:"); println!("{}", msg); } else { assert_eq!(res[0], msg); @@ -931,12 +950,17 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) { for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() { let def = &*def.read(); - if let TopLevelDef::Function { instance_to_stmt, .. } = def { + if let TopLevelDef::Function { instance_to_stmt, name, .. } = def { for inst in instance_to_stmt.iter() { let ast = &inst.1.body; - println!("{:?}", ast) + println!("{}:", name); + for b in ast { + println!("{:?}", b); + println!("--------------------"); + } + println!("\n"); } } } } -} \ No newline at end of file +}