Merge branch 'hm-inference' into hm-inference_anto

This commit is contained in:
CrescentonC 2021-08-02 11:33:36 +08:00
commit f7bbc3e10d
2 changed files with 119 additions and 0 deletions

View File

@ -156,6 +156,71 @@ impl Unifier {
self.set_a_to_b(rigid, b); self.set_a_to_b(rigid, b);
} }
pub fn get_instantiations(&mut self, ty: Type) -> Option<Vec<Type>> {
match &*self.get_ty(ty) {
TypeEnum::TVar { range, .. } => {
let range = range.borrow();
if range.is_empty() {
None
} else {
Some(
range
.iter()
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
.flatten()
.collect_vec(),
)
}
}
TypeEnum::TList { ty } => self
.get_instantiations(*ty)
.map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TList { ty })).collect_vec()),
TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| {
ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec()
}),
TypeEnum::TTuple { ty } => {
let tuples = ty
.iter()
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
.multi_cartesian_product()
.collect_vec();
if tuples.len() == 1 {
None
} else {
Some(
tuples.into_iter().map(|ty| self.add_ty(TypeEnum::TTuple { ty })).collect(),
)
}
}
TypeEnum::TObj { params, .. } => {
let (keys, params): (Vec<&u32>, Vec<&Type>) = params.iter().unzip();
let params = params
.into_iter()
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
.multi_cartesian_product()
.collect_vec();
if params.len() <= 1 {
None
} else {
Some(
params
.into_iter()
.map(|params| {
self.subst(
ty,
&zip(keys.iter().cloned().cloned(), params.iter().cloned())
.collect(),
)
.unwrap_or(ty)
})
.collect(),
)
}
}
_ => None,
}
}
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool { pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
use TypeEnum::*; use TypeEnum::*;
match &*self.get_ty(a) { match &*self.get_ty(a) {

View File

@ -2,6 +2,7 @@ use super::*;
use itertools::Itertools; use itertools::Itertools;
use std::collections::HashMap; use std::collections::HashMap;
use test_case::test_case; use test_case::test_case;
use indoc::indoc;
impl Unifier { impl Unifier {
/// Check whether two types are equal. /// Check whether two types are equal.
@ -473,3 +474,56 @@ fn test_rigid_var() {
env.unifier.replace_rigid_var(a, int); env.unifier.replace_rigid_var(a, int);
env.unifier.unify(list_x, list_int).unwrap(); env.unifier.unify(list_x, list_int).unwrap();
} }
#[test]
fn test_instantiation() {
let mut env = TestEnvironment::new();
let int = env.parse("int", &HashMap::new());
let boolean = env.parse("bool", &HashMap::new());
let float = env.parse("float", &HashMap::new());
let list_int = env.parse("List[int]", &HashMap::new());
let obj_map: HashMap<_, _> =
[(0usize, "int"), (1, "float"), (2, "bool")].iter().cloned().collect();
let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0;
let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v });
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int]).0;
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float]).0;
let t = env.unifier.get_fresh_rigid_var().0;
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] });
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t]).0;
// t = TypeVar('t')
// v = TypeVar('v', int, bool)
// v1 = TypeVar('v1', 'list[v]', int)
// v2 = TypeVar('v2', 'list[int]', float)
// v3 = TypeVar('v3', tuple[v, v1, v2], t)
// what values can v3 take?
let types = env.unifier.get_instantiations(v3).unwrap();
let expected_types = indoc! {"
tuple[bool, int, float]
tuple[bool, int, list[int]]
tuple[bool, list[bool], float]
tuple[bool, list[bool], list[int]]
tuple[bool, list[int], float]
tuple[bool, list[int], list[int]]
tuple[int, int, float]
tuple[int, int, list[int]]
tuple[int, list[bool], float]
tuple[int, list[bool], list[int]]
tuple[int, list[int], float]
tuple[int, list[int], list[int]]
v5"
}.split('\n').collect_vec();
let types = types
.iter()
.map(|ty| {
env.unifier.stringify(*ty, &mut |i| obj_map.get(&i).unwrap().to_string(), &mut |i| {
format!("v{}", i)
})
})
.sorted()
.collect_vec();
assert_eq!(expected_types, types);
}