forked from M-Labs/nac3
typecheck: fixed recursive substitution
This commit is contained in:
parent
471547855e
commit
180392e2ab
|
@ -794,7 +794,19 @@ impl Unifier {
|
|||
/// If this returns None, the result type would be the original type
|
||||
/// (no substitution has to be done).
|
||||
pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> {
|
||||
self.subst_impl(a, mapping, &mut HashMap::new())
|
||||
}
|
||||
|
||||
fn subst_impl(&mut self, a: Type, mapping: &VarMap, cache: &mut HashMap<Type, Option<Type>>) -> Option<Type> {
|
||||
use TypeVarMeta::*;
|
||||
let cached = cache.get_mut(&a);
|
||||
if let Some(cached) = cached {
|
||||
if cached.is_none() {
|
||||
*cached = Some(self.get_fresh_var().0);
|
||||
}
|
||||
return *cached;
|
||||
}
|
||||
|
||||
let ty = self.unification_table.probe_value(a).clone();
|
||||
// this function would only be called when we instantiate functions.
|
||||
// function type signature should ONLY contain concrete types and type
|
||||
|
@ -806,7 +818,7 @@ impl Unifier {
|
|||
TypeEnum::TTuple { ty } => {
|
||||
let mut new_ty = Cow::from(ty);
|
||||
for (i, t) in ty.iter().enumerate() {
|
||||
if let Some(t1) = self.subst(*t, mapping) {
|
||||
if let Some(t1) = self.subst_impl(*t, mapping, cache) {
|
||||
new_ty.to_mut()[i] = t1;
|
||||
}
|
||||
}
|
||||
|
@ -817,10 +829,10 @@ impl Unifier {
|
|||
}
|
||||
}
|
||||
TypeEnum::TList { ty } => {
|
||||
self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TList { ty: t }))
|
||||
self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TList { ty: t }))
|
||||
}
|
||||
TypeEnum::TVirtual { ty } => {
|
||||
self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TVirtual { ty: t }))
|
||||
self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TVirtual { ty: t }))
|
||||
}
|
||||
TypeEnum::TObj { obj_id, fields, params } => {
|
||||
// Type variables in field types must be present in the type parameter.
|
||||
|
@ -837,27 +849,32 @@ impl Unifier {
|
|||
}
|
||||
});
|
||||
if need_subst {
|
||||
cache.insert(a, None);
|
||||
let obj_id = *obj_id;
|
||||
let params = self.subst_map(¶ms, mapping).unwrap_or_else(|| params.clone());
|
||||
let params = self.subst_map(¶ms, mapping, cache).unwrap_or_else(|| params.clone());
|
||||
let fields = self
|
||||
.subst_map(&fields.borrow(), mapping)
|
||||
.subst_map(&fields.borrow(), mapping, cache)
|
||||
.unwrap_or_else(|| fields.borrow().clone());
|
||||
Some(self.add_ty(TypeEnum::TObj {
|
||||
let new_ty = self.add_ty(TypeEnum::TObj {
|
||||
obj_id,
|
||||
params: params.into(),
|
||||
fields: fields.into(),
|
||||
}))
|
||||
});
|
||||
if let Some(var) = cache.get(&a).unwrap() {
|
||||
self.unify(new_ty, *var).unwrap();
|
||||
}
|
||||
Some(new_ty)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
TypeEnum::TFunc(sig) => {
|
||||
let FunSignature { args, ret, vars: params } = &*sig.borrow();
|
||||
let new_params = self.subst_map(params, mapping);
|
||||
let new_ret = self.subst(*ret, mapping);
|
||||
let new_params = self.subst_map(params, mapping, cache);
|
||||
let new_ret = self.subst_impl(*ret, mapping, cache);
|
||||
let mut new_args = Cow::from(args);
|
||||
for (i, t) in args.iter().enumerate() {
|
||||
if let Some(t1) = self.subst(t.ty, mapping) {
|
||||
if let Some(t1) = self.subst_impl(t.ty, mapping, cache) {
|
||||
let mut t = t.clone();
|
||||
t.ty = t1;
|
||||
new_args.to_mut()[i] = t;
|
||||
|
@ -880,13 +897,13 @@ impl Unifier {
|
|||
}
|
||||
}
|
||||
|
||||
fn subst_map<K>(&mut self, map: &Mapping<K>, mapping: &VarMap) -> Option<Mapping<K>>
|
||||
fn subst_map<K>(&mut self, map: &Mapping<K>, mapping: &VarMap, cache: &mut HashMap<Type, Option<Type>>) -> Option<Mapping<K>>
|
||||
where
|
||||
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||
{
|
||||
let mut map2 = None;
|
||||
for (k, v) in map.iter() {
|
||||
if let Some(v1) = self.subst(*v, mapping) {
|
||||
if let Some(v1) = self.subst_impl(*v, mapping, cache) {
|
||||
if map2.is_none() {
|
||||
map2 = Some(map.clone());
|
||||
}
|
||||
|
|
|
@ -67,7 +67,7 @@ impl Unifier {
|
|||
|
||||
struct TestEnvironment {
|
||||
pub unifier: Unifier,
|
||||
type_mapping: HashMap<String, Type>,
|
||||
pub type_mapping: HashMap<String, Type>,
|
||||
}
|
||||
|
||||
impl TestEnvironment {
|
||||
|
@ -325,6 +325,30 @@ fn test_invalid_unification(
|
|||
assert_eq!(env.unifier.unify(t1, t2), Err(errornous_pair.1.to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recursive_subst() {
|
||||
let mut env = TestEnvironment::new();
|
||||
let int = *env.type_mapping.get("int").unwrap();
|
||||
let foo_id = *env.type_mapping.get("Foo").unwrap();
|
||||
let foo_ty = env.unifier.get_ty(foo_id);
|
||||
let mapping: HashMap<_, _>;
|
||||
if let TypeEnum::TObj { fields, params, .. } = &*foo_ty {
|
||||
fields.borrow_mut().insert("rec".into(), foo_id);
|
||||
mapping = params.borrow().iter().map(|(id, _)| (*id, int)).collect();
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
let instantiated = env.unifier.subst(foo_id, &mapping).unwrap();
|
||||
let instantiated_ty = env.unifier.get_ty(instantiated);
|
||||
if let TypeEnum::TObj { fields, .. } = &*instantiated_ty {
|
||||
let fields = fields.borrow();
|
||||
assert!(env.unifier.unioned(*fields.get("a").unwrap(), int));
|
||||
assert!(env.unifier.unioned(*fields.get("rec").unwrap(), instantiated));
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_virtual() {
|
||||
let mut env = TestEnvironment::new();
|
||||
|
|
Loading…
Reference in New Issue