diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 36809e5..e3dc833 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -156,6 +156,71 @@ impl Unifier { self.set_a_to_b(rigid, b); } + pub fn get_instantiations(&mut self, ty: Type) -> Option> { + match &*self.get_ty(ty) { + TypeEnum::TVar { range, .. } => { + let range = range.borrow(); + if range.is_empty() { + None + } else { + Some( + range + .iter() + .map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])) + .flatten() + .collect_vec(), + ) + } + } + TypeEnum::TList { ty } => self + .get_instantiations(*ty) + .map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TList { ty })).collect_vec()), + TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| { + ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec() + }), + TypeEnum::TTuple { ty } => { + let tuples = ty + .iter() + .map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])) + .multi_cartesian_product() + .collect_vec(); + if tuples.len() == 1 { + None + } else { + Some( + tuples.into_iter().map(|ty| self.add_ty(TypeEnum::TTuple { ty })).collect(), + ) + } + } + TypeEnum::TObj { params, .. } => { + let (keys, params): (Vec<&u32>, Vec<&Type>) = params.iter().unzip(); + let params = params + .into_iter() + .map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])) + .multi_cartesian_product() + .collect_vec(); + if params.len() <= 1 { + None + } else { + Some( + params + .into_iter() + .map(|params| { + self.subst( + ty, + &zip(keys.iter().cloned().cloned(), params.iter().cloned()) + .collect(), + ) + .unwrap_or(ty) + }) + .collect(), + ) + } + } + _ => None, + } + } + pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool { use TypeEnum::*; match &*self.get_ty(a) { diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index cf0cc9c..0f9128e 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -2,6 +2,7 @@ use super::*; use itertools::Itertools; use std::collections::HashMap; use test_case::test_case; +use indoc::indoc; impl Unifier { /// Check whether two types are equal. @@ -473,3 +474,50 @@ fn test_rigid_var() { env.unifier.replace_rigid_var(a, int); env.unifier.unify(list_x, list_int).unwrap(); } + +#[test] +fn test_instantiation() { + let mut env = TestEnvironment::new(); + let int = env.parse("int", &HashMap::new()); + let boolean = env.parse("bool", &HashMap::new()); + let float = env.parse("float", &HashMap::new()); + let list_int = env.parse("List[int]", &HashMap::new()); + + let obj_map: HashMap<_, _> = + [(0usize, "int"), (1, "float"), (2, "bool")].iter().cloned().collect(); + + let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0; + let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v }); + let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int]).0; + let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float]).0; + let t = env.unifier.get_fresh_rigid_var().0; + let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] }); + let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t]).0; + + let types = env.unifier.get_instantiations(v3).unwrap(); + let expected_types = indoc! {" + tuple[bool, int, float] + tuple[bool, int, list[int]] + tuple[bool, list[bool], float] + tuple[bool, list[bool], list[int]] + tuple[bool, list[int], float] + tuple[bool, list[int], list[int]] + tuple[int, int, float] + tuple[int, int, list[int]] + tuple[int, list[bool], float] + tuple[int, list[bool], list[int]] + tuple[int, list[int], float] + tuple[int, list[int], list[int]] + v5" + }.split('\n').collect_vec(); + let types = types + .iter() + .map(|ty| { + env.unifier.stringify(*ty, &mut |i| obj_map.get(&i).unwrap().to_string(), &mut |i| { + format!("v{}", i) + }) + }) + .sorted() + .collect_vec(); + assert_eq!(expected_types, types); +}