typecheck: fixed recursive substitution

This commit is contained in:
pca006132 2021-09-12 21:33:21 +08:00
parent 471547855e
commit 180392e2ab
2 changed files with 54 additions and 13 deletions

View File

@ -794,7 +794,19 @@ impl Unifier {
/// If this returns None, the result type would be the original type /// If this returns None, the result type would be the original type
/// (no substitution has to be done). /// (no substitution has to be done).
pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> { pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> {
self.subst_impl(a, mapping, &mut HashMap::new())
}
fn subst_impl(&mut self, a: Type, mapping: &VarMap, cache: &mut HashMap<Type, Option<Type>>) -> Option<Type> {
use TypeVarMeta::*; use TypeVarMeta::*;
let cached = cache.get_mut(&a);
if let Some(cached) = cached {
if cached.is_none() {
*cached = Some(self.get_fresh_var().0);
}
return *cached;
}
let ty = self.unification_table.probe_value(a).clone(); let ty = self.unification_table.probe_value(a).clone();
// this function would only be called when we instantiate functions. // this function would only be called when we instantiate functions.
// function type signature should ONLY contain concrete types and type // function type signature should ONLY contain concrete types and type
@ -806,7 +818,7 @@ impl Unifier {
TypeEnum::TTuple { ty } => { TypeEnum::TTuple { ty } => {
let mut new_ty = Cow::from(ty); let mut new_ty = Cow::from(ty);
for (i, t) in ty.iter().enumerate() { for (i, t) in ty.iter().enumerate() {
if let Some(t1) = self.subst(*t, mapping) { if let Some(t1) = self.subst_impl(*t, mapping, cache) {
new_ty.to_mut()[i] = t1; new_ty.to_mut()[i] = t1;
} }
} }
@ -817,10 +829,10 @@ impl Unifier {
} }
} }
TypeEnum::TList { ty } => { TypeEnum::TList { ty } => {
self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TList { ty: t })) self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TList { ty: t }))
} }
TypeEnum::TVirtual { ty } => { TypeEnum::TVirtual { ty } => {
self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })) self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TVirtual { ty: t }))
} }
TypeEnum::TObj { obj_id, fields, params } => { TypeEnum::TObj { obj_id, fields, params } => {
// Type variables in field types must be present in the type parameter. // Type variables in field types must be present in the type parameter.
@ -837,27 +849,32 @@ impl Unifier {
} }
}); });
if need_subst { if need_subst {
cache.insert(a, None);
let obj_id = *obj_id; let obj_id = *obj_id;
let params = self.subst_map(&params, mapping).unwrap_or_else(|| params.clone()); let params = self.subst_map(&params, mapping, cache).unwrap_or_else(|| params.clone());
let fields = self let fields = self
.subst_map(&fields.borrow(), mapping) .subst_map(&fields.borrow(), mapping, cache)
.unwrap_or_else(|| fields.borrow().clone()); .unwrap_or_else(|| fields.borrow().clone());
Some(self.add_ty(TypeEnum::TObj { let new_ty = self.add_ty(TypeEnum::TObj {
obj_id, obj_id,
params: params.into(), params: params.into(),
fields: fields.into(), fields: fields.into(),
})) });
if let Some(var) = cache.get(&a).unwrap() {
self.unify(new_ty, *var).unwrap();
}
Some(new_ty)
} else { } else {
None None
} }
} }
TypeEnum::TFunc(sig) => { TypeEnum::TFunc(sig) => {
let FunSignature { args, ret, vars: params } = &*sig.borrow(); let FunSignature { args, ret, vars: params } = &*sig.borrow();
let new_params = self.subst_map(params, mapping); let new_params = self.subst_map(params, mapping, cache);
let new_ret = self.subst(*ret, mapping); let new_ret = self.subst_impl(*ret, mapping, cache);
let mut new_args = Cow::from(args); let mut new_args = Cow::from(args);
for (i, t) in args.iter().enumerate() { for (i, t) in args.iter().enumerate() {
if let Some(t1) = self.subst(t.ty, mapping) { if let Some(t1) = self.subst_impl(t.ty, mapping, cache) {
let mut t = t.clone(); let mut t = t.clone();
t.ty = t1; t.ty = t1;
new_args.to_mut()[i] = t; new_args.to_mut()[i] = t;
@ -880,13 +897,13 @@ impl Unifier {
} }
} }
fn subst_map<K>(&mut self, map: &Mapping<K>, mapping: &VarMap) -> Option<Mapping<K>> fn subst_map<K>(&mut self, map: &Mapping<K>, mapping: &VarMap, cache: &mut HashMap<Type, Option<Type>>) -> Option<Mapping<K>>
where where
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone, K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
{ {
let mut map2 = None; let mut map2 = None;
for (k, v) in map.iter() { for (k, v) in map.iter() {
if let Some(v1) = self.subst(*v, mapping) { if let Some(v1) = self.subst_impl(*v, mapping, cache) {
if map2.is_none() { if map2.is_none() {
map2 = Some(map.clone()); map2 = Some(map.clone());
} }

View File

@ -67,7 +67,7 @@ impl Unifier {
struct TestEnvironment { struct TestEnvironment {
pub unifier: Unifier, pub unifier: Unifier,
type_mapping: HashMap<String, Type>, pub type_mapping: HashMap<String, Type>,
} }
impl TestEnvironment { impl TestEnvironment {
@ -325,6 +325,30 @@ fn test_invalid_unification(
assert_eq!(env.unifier.unify(t1, t2), Err(errornous_pair.1.to_string())); assert_eq!(env.unifier.unify(t1, t2), Err(errornous_pair.1.to_string()));
} }
#[test]
fn test_recursive_subst() {
let mut env = TestEnvironment::new();
let int = *env.type_mapping.get("int").unwrap();
let foo_id = *env.type_mapping.get("Foo").unwrap();
let foo_ty = env.unifier.get_ty(foo_id);
let mapping: HashMap<_, _>;
if let TypeEnum::TObj { fields, params, .. } = &*foo_ty {
fields.borrow_mut().insert("rec".into(), foo_id);
mapping = params.borrow().iter().map(|(id, _)| (*id, int)).collect();
} else {
unreachable!()
}
let instantiated = env.unifier.subst(foo_id, &mapping).unwrap();
let instantiated_ty = env.unifier.get_ty(instantiated);
if let TypeEnum::TObj { fields, .. } = &*instantiated_ty {
let fields = fields.borrow();
assert!(env.unifier.unioned(*fields.get("a").unwrap(), int));
assert!(env.unifier.unioned(*fields.get("rec").unwrap(), instantiated));
} else {
unreachable!()
}
}
#[test] #[test]
fn test_virtual() { fn test_virtual() {
let mut env = TestEnvironment::new(); let mut env = TestEnvironment::new();