forked from M-Labs/nac3
typecheck: fixed recursive substitution
This commit is contained in:
parent
471547855e
commit
180392e2ab
|
@ -794,7 +794,19 @@ impl Unifier {
|
||||||
/// If this returns None, the result type would be the original type
|
/// If this returns None, the result type would be the original type
|
||||||
/// (no substitution has to be done).
|
/// (no substitution has to be done).
|
||||||
pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> {
|
pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option<Type> {
|
||||||
|
self.subst_impl(a, mapping, &mut HashMap::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn subst_impl(&mut self, a: Type, mapping: &VarMap, cache: &mut HashMap<Type, Option<Type>>) -> Option<Type> {
|
||||||
use TypeVarMeta::*;
|
use TypeVarMeta::*;
|
||||||
|
let cached = cache.get_mut(&a);
|
||||||
|
if let Some(cached) = cached {
|
||||||
|
if cached.is_none() {
|
||||||
|
*cached = Some(self.get_fresh_var().0);
|
||||||
|
}
|
||||||
|
return *cached;
|
||||||
|
}
|
||||||
|
|
||||||
let ty = self.unification_table.probe_value(a).clone();
|
let ty = self.unification_table.probe_value(a).clone();
|
||||||
// this function would only be called when we instantiate functions.
|
// this function would only be called when we instantiate functions.
|
||||||
// function type signature should ONLY contain concrete types and type
|
// function type signature should ONLY contain concrete types and type
|
||||||
|
@ -806,7 +818,7 @@ impl Unifier {
|
||||||
TypeEnum::TTuple { ty } => {
|
TypeEnum::TTuple { ty } => {
|
||||||
let mut new_ty = Cow::from(ty);
|
let mut new_ty = Cow::from(ty);
|
||||||
for (i, t) in ty.iter().enumerate() {
|
for (i, t) in ty.iter().enumerate() {
|
||||||
if let Some(t1) = self.subst(*t, mapping) {
|
if let Some(t1) = self.subst_impl(*t, mapping, cache) {
|
||||||
new_ty.to_mut()[i] = t1;
|
new_ty.to_mut()[i] = t1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -817,10 +829,10 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TypeEnum::TList { ty } => {
|
TypeEnum::TList { ty } => {
|
||||||
self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TList { ty: t }))
|
self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TList { ty: t }))
|
||||||
}
|
}
|
||||||
TypeEnum::TVirtual { ty } => {
|
TypeEnum::TVirtual { ty } => {
|
||||||
self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TVirtual { ty: t }))
|
self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TVirtual { ty: t }))
|
||||||
}
|
}
|
||||||
TypeEnum::TObj { obj_id, fields, params } => {
|
TypeEnum::TObj { obj_id, fields, params } => {
|
||||||
// Type variables in field types must be present in the type parameter.
|
// Type variables in field types must be present in the type parameter.
|
||||||
|
@ -837,27 +849,32 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
if need_subst {
|
if need_subst {
|
||||||
|
cache.insert(a, None);
|
||||||
let obj_id = *obj_id;
|
let obj_id = *obj_id;
|
||||||
let params = self.subst_map(¶ms, mapping).unwrap_or_else(|| params.clone());
|
let params = self.subst_map(¶ms, mapping, cache).unwrap_or_else(|| params.clone());
|
||||||
let fields = self
|
let fields = self
|
||||||
.subst_map(&fields.borrow(), mapping)
|
.subst_map(&fields.borrow(), mapping, cache)
|
||||||
.unwrap_or_else(|| fields.borrow().clone());
|
.unwrap_or_else(|| fields.borrow().clone());
|
||||||
Some(self.add_ty(TypeEnum::TObj {
|
let new_ty = self.add_ty(TypeEnum::TObj {
|
||||||
obj_id,
|
obj_id,
|
||||||
params: params.into(),
|
params: params.into(),
|
||||||
fields: fields.into(),
|
fields: fields.into(),
|
||||||
}))
|
});
|
||||||
|
if let Some(var) = cache.get(&a).unwrap() {
|
||||||
|
self.unify(new_ty, *var).unwrap();
|
||||||
|
}
|
||||||
|
Some(new_ty)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TypeEnum::TFunc(sig) => {
|
TypeEnum::TFunc(sig) => {
|
||||||
let FunSignature { args, ret, vars: params } = &*sig.borrow();
|
let FunSignature { args, ret, vars: params } = &*sig.borrow();
|
||||||
let new_params = self.subst_map(params, mapping);
|
let new_params = self.subst_map(params, mapping, cache);
|
||||||
let new_ret = self.subst(*ret, mapping);
|
let new_ret = self.subst_impl(*ret, mapping, cache);
|
||||||
let mut new_args = Cow::from(args);
|
let mut new_args = Cow::from(args);
|
||||||
for (i, t) in args.iter().enumerate() {
|
for (i, t) in args.iter().enumerate() {
|
||||||
if let Some(t1) = self.subst(t.ty, mapping) {
|
if let Some(t1) = self.subst_impl(t.ty, mapping, cache) {
|
||||||
let mut t = t.clone();
|
let mut t = t.clone();
|
||||||
t.ty = t1;
|
t.ty = t1;
|
||||||
new_args.to_mut()[i] = t;
|
new_args.to_mut()[i] = t;
|
||||||
|
@ -880,13 +897,13 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn subst_map<K>(&mut self, map: &Mapping<K>, mapping: &VarMap) -> Option<Mapping<K>>
|
fn subst_map<K>(&mut self, map: &Mapping<K>, mapping: &VarMap, cache: &mut HashMap<Type, Option<Type>>) -> Option<Mapping<K>>
|
||||||
where
|
where
|
||||||
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||||
{
|
{
|
||||||
let mut map2 = None;
|
let mut map2 = None;
|
||||||
for (k, v) in map.iter() {
|
for (k, v) in map.iter() {
|
||||||
if let Some(v1) = self.subst(*v, mapping) {
|
if let Some(v1) = self.subst_impl(*v, mapping, cache) {
|
||||||
if map2.is_none() {
|
if map2.is_none() {
|
||||||
map2 = Some(map.clone());
|
map2 = Some(map.clone());
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,7 +67,7 @@ impl Unifier {
|
||||||
|
|
||||||
struct TestEnvironment {
|
struct TestEnvironment {
|
||||||
pub unifier: Unifier,
|
pub unifier: Unifier,
|
||||||
type_mapping: HashMap<String, Type>,
|
pub type_mapping: HashMap<String, Type>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TestEnvironment {
|
impl TestEnvironment {
|
||||||
|
@ -325,6 +325,30 @@ fn test_invalid_unification(
|
||||||
assert_eq!(env.unifier.unify(t1, t2), Err(errornous_pair.1.to_string()));
|
assert_eq!(env.unifier.unify(t1, t2), Err(errornous_pair.1.to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_recursive_subst() {
|
||||||
|
let mut env = TestEnvironment::new();
|
||||||
|
let int = *env.type_mapping.get("int").unwrap();
|
||||||
|
let foo_id = *env.type_mapping.get("Foo").unwrap();
|
||||||
|
let foo_ty = env.unifier.get_ty(foo_id);
|
||||||
|
let mapping: HashMap<_, _>;
|
||||||
|
if let TypeEnum::TObj { fields, params, .. } = &*foo_ty {
|
||||||
|
fields.borrow_mut().insert("rec".into(), foo_id);
|
||||||
|
mapping = params.borrow().iter().map(|(id, _)| (*id, int)).collect();
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
let instantiated = env.unifier.subst(foo_id, &mapping).unwrap();
|
||||||
|
let instantiated_ty = env.unifier.get_ty(instantiated);
|
||||||
|
if let TypeEnum::TObj { fields, .. } = &*instantiated_ty {
|
||||||
|
let fields = fields.borrow();
|
||||||
|
assert!(env.unifier.unioned(*fields.get("a").unwrap(), int));
|
||||||
|
assert!(env.unifier.unioned(*fields.get("rec").unwrap(), instantiated));
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_virtual() {
|
fn test_virtual() {
|
||||||
let mut env = TestEnvironment::new();
|
let mut env = TestEnvironment::new();
|
||||||
|
|
Loading…
Reference in New Issue