diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index e11f17b19..62ecc247e 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1,3 +1,7 @@ +use rustpython_parser::ast::fold::Fold; + +use crate::typecheck::type_inferencer::{FunctionData, Inferencer}; + use super::*; type DefAst = (Arc>, Option>); @@ -14,6 +18,8 @@ pub struct TopLevelComposer { pub defined_class_name: HashSet, pub defined_class_method_name: HashSet, pub defined_function_name: HashSet, + // get the class def id of a class method + pub method_class: HashMap, } impl Default for TopLevelComposer { @@ -60,13 +66,14 @@ impl TopLevelComposer { defined_class_method_name: Default::default(), defined_class_name: Default::default(), defined_function_name: Default::default(), + method_class: Default::default(), } } - pub fn make_top_level_context(self) -> TopLevelContext { + pub fn make_top_level_context(&self) -> TopLevelContext { TopLevelContext { definitions: RwLock::new( - self.definition_ast_list.into_iter().map(|(x, ..)| x).collect_vec(), + self.definition_ast_list.iter().map(|(x, ..)| x.clone()).collect_vec(), ) .into(), // FIXME: all the big unifier or? @@ -186,7 +193,8 @@ impl TopLevelComposer { for (name, _, id, ty, ..) in &class_method_name_def_ids { let mut class_def = class_def_ast.0.write(); if let TopLevelDef::Class { methods, .. } = class_def.deref_mut() { - methods.push((name.clone(), *ty, *id)) + methods.push((name.clone(), *ty, *id)); + self.method_class.insert(*id, DefinitionId(class_def_id)); } else { unreachable!() } @@ -240,11 +248,14 @@ impl TopLevelComposer { } } - pub fn start_analysis(&mut self) -> Result<(), String> { + pub fn start_analysis(&mut self, inference: bool) -> Result<(), String> { self.analyze_top_level_class_type_var()?; self.analyze_top_level_class_bases()?; self.analyze_top_level_class_fields_methods()?; self.analyze_top_level_function()?; + if inference { + self.analyze_function_instance()?; + } Ok(()) } @@ -1096,4 +1107,162 @@ impl TopLevelComposer { Ok(()) } + + /// step 5, analyze and call type inferecer to fill the `instance_to_stmt` of topleveldef::function + fn analyze_function_instance(&mut self) -> Result<(), String> { + for (id, (def, ast)) in self.definition_ast_list.iter().enumerate() { + + let mut function_def = def.write(); + if let TopLevelDef::Function { + instance_to_stmt, + name, + signature, + var_id, + resolver, + .. + } = &mut *function_def { + if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() { + let FunSignature { args, ret, vars } = &*func_sig.borrow(); + // None if is not class method + let self_type = { + if let Some(class_id) = self.method_class.get(&DefinitionId(id)) { + let class_def = self.definition_ast_list.get(class_id.0).unwrap(); + let class_def = class_def.0.read(); + if let TopLevelDef::Class { type_vars, .. } = &*class_def { + let ty_ann = make_self_type_annotation(type_vars, *class_id); + Some(get_type_from_type_annotation_kinds( + self.extract_def_list().as_slice(), + &mut self.unifier, + &self.primitives_ty, + &ty_ann + )?) + } else { + unreachable!("must be class def") + } + } else { + None + } + }; + let type_var_subst_comb = { + let unifier = &mut self.unifier; + let var_ids = vars + .iter() + .map(|(id, _)| *id); + let var_combs = vars + .iter() + .map(|(_, ty)| unifier.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])) + .multi_cartesian_product() + .collect_vec(); + let mut result: Vec> = Default::default(); + for comb in var_combs { + result.push(var_ids.clone().zip(comb).collect()); + } + // NOTE: if is empty, means no type var, append a empty subst, ok to do this? + if result.is_empty() { + result.push(HashMap::new()) + } + result + }; + + for subst in type_var_subst_comb { + // for each instance + let unifier = &mut self.unifier; + let inst_ret = unifier.subst(*ret, &subst).unwrap_or(*ret); + let inst_args = args + .iter() + .map(|a| FuncArg { + name: a.name.clone(), + ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty), + default_value: a.default_value.clone() + }) + .collect_vec(); + let self_type = self_type.map(|x| unifier.subst(x, &subst).unwrap_or(x)); + + let mut identifiers = { + // NOTE: none and function args? + let mut result: HashSet = HashSet::new(); + result.insert("None".into()); + if self_type.is_some(){ + result.insert("self".into()); + } + result.extend(inst_args.iter().map(|x| x.name.clone())); + result + }; + let mut inferencer = { + Inferencer { + top_level: &self.make_top_level_context(), + defined_identifiers: identifiers.clone(), + function_data: &mut FunctionData { + resolver: resolver.as_ref().unwrap().clone(), + return_type: if self.unifier.unioned(inst_ret, self.primitives_ty.none) { + None + } else { + Some(inst_ret) + }, + // NOTE: allowed type vars: leave blank? + bound_variables: Vec::new(), + }, + unifier: &mut self.unifier, + variable_mapping: { + // NOTE: none and function args? + let mut result: HashMap = HashMap::new(); + result.insert("None".into(), self.primitives_ty.none); + if let Some(self_ty) = self_type { + result.insert("self".into(), self_ty); + } + result.extend(inst_args.iter().map(|x| (x.name.clone(), x.ty))); + result + }, + primitives: &self.primitives_ty, + virtual_checks: &mut Vec::new(), + calls: &mut HashMap::new(), + } + }; + + let fun_body = if let ast::StmtKind::FunctionDef { body, .. } = ast.clone().unwrap().node { + body + } else { + unreachable!("must be function def ast") + } + .into_iter() + .map(|b| inferencer.fold_stmt(b)) + .collect::, _>>()?; + + let returned = inferencer + .check_block(fun_body.as_slice(), &mut identifiers)?; + + if !self.unifier.unioned(inst_ret, self.primitives_ty.none) && !returned { + let ret_str = self.unifier.stringify( + inst_ret, + &mut |id| format!("class{}", id), + &mut |id| format!("tvar{}", id) + ); + return Err(format!( + "expected return type of {} in function `{}`", + ret_str, + name + )); + } + + instance_to_stmt.insert( + // FIXME: how? + "".to_string(), + FunInstance { + body: fun_body, + unifier_id: 0, + calls: HashMap::new(), + subst + } + ); + } + } else { + unreachable!("must be typeenum::tfunc") + } + } else { + continue + } + + } + Ok(()) + } } diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 0518fffe2..206aefe9c 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -36,7 +36,7 @@ pub struct FunInstance { pub unifier_id: usize, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum TopLevelDef { Class { // name for error messages and symbols diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 0f0c14337..6e39ce471 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -138,7 +138,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s internal_resolver.add_id_def(id, def_id); } - composer.start_analysis().unwrap(); + composer.start_analysis(true).unwrap(); for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() { let def = &*def.read(); @@ -802,7 +802,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { internal_resolver.add_id_def(id, def_id); } - if let Err(msg) = composer.start_analysis() { + if let Err(msg) = composer.start_analysis(false) { if print { println!("{}", msg); } else { @@ -840,3 +840,103 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { } } } + +#[test_case( + vec![ + indoc! {" + def fun(a: int32, b: int32) -> int32: + return a + b + "} + ], + vec![]; + "simple function" +)] +#[test_case( + vec![ + indoc! {" + class A: + a: int32 + def __init__(self): + self.a = 3 + def fun(self) -> int32: + b = self.a + 3 + return b * self.a + def dup(self) -> A: + SELF = self + return SELF + + "}, + indoc! {" + def fun(a: A) -> int32: + return a.fun() + "} + ], + vec![]; + "simple class body" +)] +fn test_inference(source: Vec<&str>, res: Vec<&str>) { + let print = true; + let mut composer = TopLevelComposer::new(); + + let tvar_t = composer.unifier.get_fresh_var(); + let tvar_v = composer + .unifier + .get_fresh_var_with_range(&[composer.primitives_ty.bool, composer.primitives_ty.int32]); + + if print { + println!("t: {}, {:?}", tvar_t.1, tvar_t.0); + println!("v: {}, {:?}\n", tvar_v.1, tvar_v.0); + } + + let internal_resolver = Arc::new(ResolverInternal { + id_to_def: Default::default(), + id_to_type: Mutex::new( + vec![("T".to_string(), tvar_t.0), ("V".to_string(), tvar_v.0)].into_iter().collect(), + ), + class_names: Default::default(), + }); + let resolver = Arc::new( + Box::new(Resolver(internal_resolver.clone())) as Box + ); + + for s in source { + let ast = parse_program(s).unwrap(); + let ast = ast[0].clone(); + + let (id, def_id) = { + match composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()) { + Ok(x) => x, + Err(msg) => { + if print { + println!("{}", msg); + } else { + assert_eq!(res[0], msg); + } + return; + } + } + }; + internal_resolver.add_id_def(id, def_id); + } + + if let Err(msg) = composer.start_analysis(true) { + if print { + // println!("err2:"); + println!("{}", msg); + } else { + assert_eq!(res[0], msg); + } + } else { + // skip 5 to skip primitives + for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() { + let def = &*def.read(); + + if let TopLevelDef::Function { instance_to_stmt, .. } = def { + for inst in instance_to_stmt.iter() { + let ast = &inst.1.body; + println!("{:?}", ast) + } + } + } + } +} \ No newline at end of file diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 9d6ebeeb8..18cc880eb 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -237,7 +237,7 @@ pub fn get_type_from_type_annotation_kinds( let subst = { // check for compatible range - // TODO: if allow type var to be applied, need more check + // TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check let mut result: HashMap = HashMap::new(); for (tvar, p) in type_vars.iter().zip(param_ty) { if let TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic } =