forked from M-Labs/nac3
added method to get all instantiations
This commit is contained in:
parent
7ad8e2d81d
commit
eba92ed8bd
|
@ -156,6 +156,71 @@ impl Unifier {
|
|||
self.set_a_to_b(rigid, b);
|
||||
}
|
||||
|
||||
pub fn get_instantiations(&mut self, ty: Type) -> Option<Vec<Type>> {
|
||||
match &*self.get_ty(ty) {
|
||||
TypeEnum::TVar { range, .. } => {
|
||||
let range = range.borrow();
|
||||
if range.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
range
|
||||
.iter()
|
||||
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
|
||||
.flatten()
|
||||
.collect_vec(),
|
||||
)
|
||||
}
|
||||
}
|
||||
TypeEnum::TList { ty } => self
|
||||
.get_instantiations(*ty)
|
||||
.map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TList { ty })).collect_vec()),
|
||||
TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| {
|
||||
ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec()
|
||||
}),
|
||||
TypeEnum::TTuple { ty } => {
|
||||
let tuples = ty
|
||||
.iter()
|
||||
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
|
||||
.multi_cartesian_product()
|
||||
.collect_vec();
|
||||
if tuples.len() == 1 {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
tuples.into_iter().map(|ty| self.add_ty(TypeEnum::TTuple { ty })).collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
TypeEnum::TObj { params, .. } => {
|
||||
let (keys, params): (Vec<&u32>, Vec<&Type>) = params.iter().unzip();
|
||||
let params = params
|
||||
.into_iter()
|
||||
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
|
||||
.multi_cartesian_product()
|
||||
.collect_vec();
|
||||
if params.len() <= 1 {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
params
|
||||
.into_iter()
|
||||
.map(|params| {
|
||||
self.subst(
|
||||
ty,
|
||||
&zip(keys.iter().cloned().cloned(), params.iter().cloned())
|
||||
.collect(),
|
||||
)
|
||||
.unwrap_or(ty)
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
|
||||
use TypeEnum::*;
|
||||
match &*self.get_ty(a) {
|
||||
|
|
|
@ -2,6 +2,7 @@ use super::*;
|
|||
use itertools::Itertools;
|
||||
use std::collections::HashMap;
|
||||
use test_case::test_case;
|
||||
use indoc::indoc;
|
||||
|
||||
impl Unifier {
|
||||
/// Check whether two types are equal.
|
||||
|
@ -473,3 +474,50 @@ fn test_rigid_var() {
|
|||
env.unifier.replace_rigid_var(a, int);
|
||||
env.unifier.unify(list_x, list_int).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_instantiation() {
|
||||
let mut env = TestEnvironment::new();
|
||||
let int = env.parse("int", &HashMap::new());
|
||||
let boolean = env.parse("bool", &HashMap::new());
|
||||
let float = env.parse("float", &HashMap::new());
|
||||
let list_int = env.parse("List[int]", &HashMap::new());
|
||||
|
||||
let obj_map: HashMap<_, _> =
|
||||
[(0usize, "int"), (1, "float"), (2, "bool")].iter().cloned().collect();
|
||||
|
||||
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
|
||||
let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v });
|
||||
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int]).0;
|
||||
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float]).0;
|
||||
let t = env.unifier.get_fresh_rigid_var().0;
|
||||
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] });
|
||||
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t]).0;
|
||||
|
||||
let types = env.unifier.get_instantiations(v3).unwrap();
|
||||
let expected_types = indoc! {"
|
||||
tuple[bool, int, float]
|
||||
tuple[bool, int, list[int]]
|
||||
tuple[bool, list[bool], float]
|
||||
tuple[bool, list[bool], list[int]]
|
||||
tuple[bool, list[int], float]
|
||||
tuple[bool, list[int], list[int]]
|
||||
tuple[int, int, float]
|
||||
tuple[int, int, list[int]]
|
||||
tuple[int, list[bool], float]
|
||||
tuple[int, list[bool], list[int]]
|
||||
tuple[int, list[int], float]
|
||||
tuple[int, list[int], list[int]]
|
||||
v5"
|
||||
}.split('\n').collect_vec();
|
||||
let types = types
|
||||
.iter()
|
||||
.map(|ty| {
|
||||
env.unifier.stringify(*ty, &mut |i| obj_map.get(&i).unwrap().to_string(), &mut |i| {
|
||||
format!("v{}", i)
|
||||
})
|
||||
})
|
||||
.sorted()
|
||||
.collect_vec();
|
||||
assert_eq!(expected_types, types);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue