diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 0fa5bef..d6b528c 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1151,9 +1151,16 @@ impl TopLevelComposer { None } }; - let type_var_subst_comb = { + let (type_var_subst_comb, no_range_vars) = { let unifier = &mut self.unifier; - let var_ids = vars.iter().map(|(id, _)| *id); + let mut no_ranges: Vec = Vec::new(); + let var_ids = vars.iter().map(|(id, ty)| { + if matches!(unifier.get_ty(*ty).as_ref(), TypeEnum::TVar { range, .. } if range.borrow().is_empty()) { + no_ranges.push(*ty); + } + *id + }) + .collect_vec(); let var_combs = vars .iter() .map(|(_, ty)| { @@ -1163,28 +1170,33 @@ impl TopLevelComposer { .collect_vec(); let mut result: Vec> = Default::default(); for comb in var_combs { - result.push(var_ids.clone().zip(comb).collect()); + result.push(var_ids.clone().into_iter().zip(comb).collect()); } // NOTE: if is empty, means no type var, append a empty subst, ok to do this? if result.is_empty() { result.push(HashMap::new()) } - result + (result, no_ranges) }; for subst in type_var_subst_comb { // for each instance - let unifier = &mut self.unifier; - let inst_ret = unifier.subst(*ret, &subst).unwrap_or(*ret); - let inst_args = args - .iter() - .map(|a| FuncArg { - name: a.name.clone(), - ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty), - default_value: a.default_value.clone(), - }) - .collect_vec(); - let self_type = self_type.map(|x| unifier.subst(x, &subst).unwrap_or(x)); + let inst_ret = self.unifier.subst(*ret, &subst).unwrap_or(*ret); + let inst_args = { + let unifier = &mut self.unifier; + args + .iter() + .map(|a| FuncArg { + name: a.name.clone(), + ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty), + default_value: a.default_value.clone(), + }) + .collect_vec() + }; + let self_type = { + let unifier = &mut self.unifier; + self_type.map(|x| unifier.subst(x, &subst).unwrap_or(x)) + }; let mut identifiers = { // NOTE: none and function args? @@ -1196,38 +1208,36 @@ impl TopLevelComposer { result.extend(inst_args.iter().map(|x| x.name.clone())); result }; - let mut inferencer = { - Inferencer { - top_level: &self.make_top_level_context(), - defined_identifiers: identifiers.clone(), - function_data: &mut FunctionData { - resolver: resolver.as_ref().unwrap().clone(), - return_type: if self - .unifier - .unioned(inst_ret, self.primitives_ty.none) - { - None - } else { - Some(inst_ret) - }, - // NOTE: allowed type vars: leave blank? - bound_variables: Vec::new(), + let mut inferencer = Inferencer { + top_level: &self.make_top_level_context(), + defined_identifiers: identifiers.clone(), + function_data: &mut FunctionData { + resolver: resolver.as_ref().unwrap().clone(), + return_type: if self + .unifier + .unioned(inst_ret, self.primitives_ty.none) + { + None + } else { + Some(inst_ret) }, - unifier: &mut self.unifier, - variable_mapping: { - // NOTE: none and function args? - let mut result: HashMap = HashMap::new(); - result.insert("None".into(), self.primitives_ty.none); - if let Some(self_ty) = self_type { - result.insert("self".into(), self_ty); - } - result.extend(inst_args.iter().map(|x| (x.name.clone(), x.ty))); - result - }, - primitives: &self.primitives_ty, - virtual_checks: &mut Vec::new(), - calls: &mut HashMap::new(), - } + // NOTE: allowed type vars + bound_variables: no_range_vars.clone(), + }, + unifier: &mut self.unifier, + variable_mapping: { + // NOTE: none and function args? + let mut result: HashMap = HashMap::new(); + result.insert("None".into(), self.primitives_ty.none); + if let Some(self_ty) = self_type { + result.insert("self".into(), self_ty); + } + result.extend(inst_args.iter().map(|x| (x.name.clone(), x.ty))); + result + }, + primitives: &self.primitives_ty, + virtual_checks: &mut Vec::new(), + calls: &mut HashMap::new(), }; let fun_body = if let ast::StmtKind::FunctionDef { body, .. } = @@ -1257,8 +1267,17 @@ impl TopLevelComposer { } instance_to_stmt.insert( - // FIXME: how? - "".to_string(), + // NOTE: refer to codegen/expr/get_subst_key function + { + let unifier = &mut self.unifier; + subst + .keys() + .sorted() + .map(|id| { + let ty = subst.get(id).unwrap(); + unifier.stringify(*ty, &mut |id| id.to_string(), &mut |id| id.to_string()) + }).join(", ") + }, FunInstance { body: fun_body, unifier_id: 0, diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index 2dfda3f..6c749b1 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -37,7 +37,7 @@ impl SymbolResolver for Resolver { fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option { let ret = self.0.id_to_type.lock().get(str).cloned(); if ret.is_none() { - println!("unknown here resolver {}", str); + // println!("unknown here resolver {}", str); } ret } @@ -772,23 +772,15 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { let print = false; let mut composer = TopLevelComposer::new(); - let tvar_t = composer.unifier.get_fresh_var(); - let tvar_v = composer - .unifier - .get_fresh_var_with_range(&[composer.primitives_ty.bool, composer.primitives_ty.int32]); - - if print { - println!("t: {}, {:?}", tvar_t.1, tvar_t.0); - println!("v: {}, {:?}\n", tvar_v.1, tvar_v.0); - } - - let internal_resolver = Arc::new(ResolverInternal { - id_to_def: Default::default(), - id_to_type: Mutex::new( - vec![("T".to_string(), tvar_t.0), ("V".to_string(), tvar_v.0)].into_iter().collect(), - ), - class_names: Default::default(), - }); + let internal_resolver = make_internal_resolver_with_tvar( + vec![ + ("T".into(), vec![]), + ("V".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int32]), + ("G".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int64]), + ], + &mut composer.unifier, + print + ); let resolver = Arc::new( Box::new(Resolver(internal_resolver.clone())) as Box ); @@ -888,7 +880,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { return SELF def sum(self) -> int32: if self.a == 0: - return self.a + return self.a + self else: a = self.a self.a = self.a - 1 @@ -909,34 +901,58 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) { #[test_case( vec![ indoc! {" - def fun(a: V) -> V: + def fun(a: V, c: G, t: T) -> V: b = a - return a + cc = c + ret = fun(b, cc, t) + return ret * ret + "}, + indoc! {" + def sum3(l: list[V]) -> V: + return l[0] + l[1] + l[2] + "}, + indoc! {" + def sum_sq_pair(p: tuple[V, V]) -> list[V]: + a = p[0] + b = p[1] + a = a**a + b = b**b + return [a, b] "} ], vec![]; "type var fun" )] +#[test_case( + vec![ + indoc! {" + class A(Generic[G]): + a: G + b: bool + def __init__(self, aa: G): + self.a = aa + self.b = True + def fun(self, a: G) -> list[G]: + ret = [a, self.a] + return ret if self.b else self.fun(self.a) + "} + ], + vec![]; + "type var class" +)] fn test_inference(source: Vec<&str>, res: Vec<&str>) { let print = true; let mut composer = TopLevelComposer::new(); - let tvar_t = composer.unifier.get_fresh_var(); - let tvar_v = composer - .unifier - .get_fresh_var_with_range(&[composer.primitives_ty.bool, composer.primitives_ty.int64]); - if print { - println!("t: {}, {:?}", tvar_t.1, tvar_t.0); - println!("v: {}, {:?}\n", tvar_v.1, tvar_v.0); - } - - let internal_resolver = Arc::new(ResolverInternal { - id_to_def: Default::default(), - id_to_type: Mutex::new( - vec![("T".to_string(), tvar_t.0), ("V".to_string(), tvar_v.0)].into_iter().collect(), - ), - class_names: Default::default(), - }); + let internal_resolver = make_internal_resolver_with_tvar( + vec![ + ("T".into(), vec![]), + ("V".into(), vec![composer.primitives_ty.float, composer.primitives_ty.int32, composer.primitives_ty.int64]), + ("G".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int64]), + ], + &mut composer.unifier, + print + ); let resolver = Arc::new( Box::new(Resolver(internal_resolver.clone())) as Box ); @@ -977,9 +993,9 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) { let def = &*def.read(); if let TopLevelDef::Function { instance_to_stmt, name, .. } = def { + println!("=========`{}`: number of instances: {}===========", name, instance_to_stmt.len()); for inst in instance_to_stmt.iter() { let ast = &inst.1.body; - println!("{}:", name); for b in ast { println!("{:?}", stringify_folder.fold_stmt(b.clone()).unwrap()); println!("--------------------"); @@ -991,6 +1007,31 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) { } } +fn make_internal_resolver_with_tvar(tvars: Vec<(String, Vec)>, unifier: &mut Unifier, print: bool) -> Arc { + let res: Arc = ResolverInternal { + id_to_def: Default::default(), + id_to_type: tvars + .into_iter() + .map(|(name, range)| ( + name.clone(), + { + let (ty, id) = unifier.get_fresh_var_with_range(range.as_slice()); + if print { + println!("{}: {:?}, tvar{}", name, ty, id); + } + ty + } + )) + .collect::>() + .into(), + class_names: Default::default() + }.into(); + if print { + println!(); + } + res +} + struct TypeToStringFolder<'a> { unifier: &'a mut Unifier }