diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index c708c86c3..608ed40e9 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -794,7 +794,19 @@ impl Unifier { /// If this returns None, the result type would be the original type /// (no substitution has to be done). pub fn subst(&mut self, a: Type, mapping: &VarMap) -> Option { + self.subst_impl(a, mapping, &mut HashMap::new()) + } + + fn subst_impl(&mut self, a: Type, mapping: &VarMap, cache: &mut HashMap>) -> Option { use TypeVarMeta::*; + let cached = cache.get_mut(&a); + if let Some(cached) = cached { + if cached.is_none() { + *cached = Some(self.get_fresh_var().0); + } + return *cached; + } + let ty = self.unification_table.probe_value(a).clone(); // this function would only be called when we instantiate functions. // function type signature should ONLY contain concrete types and type @@ -806,7 +818,7 @@ impl Unifier { TypeEnum::TTuple { ty } => { let mut new_ty = Cow::from(ty); for (i, t) in ty.iter().enumerate() { - if let Some(t1) = self.subst(*t, mapping) { + if let Some(t1) = self.subst_impl(*t, mapping, cache) { new_ty.to_mut()[i] = t1; } } @@ -817,10 +829,10 @@ impl Unifier { } } TypeEnum::TList { ty } => { - self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TList { ty: t })) + self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TList { ty: t })) } TypeEnum::TVirtual { ty } => { - self.subst(*ty, mapping).map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })) + self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })) } TypeEnum::TObj { obj_id, fields, params } => { // Type variables in field types must be present in the type parameter. @@ -837,27 +849,32 @@ impl Unifier { } }); if need_subst { + cache.insert(a, None); let obj_id = *obj_id; - let params = self.subst_map(¶ms, mapping).unwrap_or_else(|| params.clone()); + let params = self.subst_map(¶ms, mapping, cache).unwrap_or_else(|| params.clone()); let fields = self - .subst_map(&fields.borrow(), mapping) + .subst_map(&fields.borrow(), mapping, cache) .unwrap_or_else(|| fields.borrow().clone()); - Some(self.add_ty(TypeEnum::TObj { + let new_ty = self.add_ty(TypeEnum::TObj { obj_id, params: params.into(), fields: fields.into(), - })) + }); + if let Some(var) = cache.get(&a).unwrap() { + self.unify(new_ty, *var).unwrap(); + } + Some(new_ty) } else { None } } TypeEnum::TFunc(sig) => { let FunSignature { args, ret, vars: params } = &*sig.borrow(); - let new_params = self.subst_map(params, mapping); - let new_ret = self.subst(*ret, mapping); + let new_params = self.subst_map(params, mapping, cache); + let new_ret = self.subst_impl(*ret, mapping, cache); let mut new_args = Cow::from(args); for (i, t) in args.iter().enumerate() { - if let Some(t1) = self.subst(t.ty, mapping) { + if let Some(t1) = self.subst_impl(t.ty, mapping, cache) { let mut t = t.clone(); t.ty = t1; new_args.to_mut()[i] = t; @@ -880,13 +897,13 @@ impl Unifier { } } - fn subst_map(&mut self, map: &Mapping, mapping: &VarMap) -> Option> + fn subst_map(&mut self, map: &Mapping, mapping: &VarMap, cache: &mut HashMap>) -> Option> where K: std::hash::Hash + std::cmp::Eq + std::clone::Clone, { let mut map2 = None; for (k, v) in map.iter() { - if let Some(v1) = self.subst(*v, mapping) { + if let Some(v1) = self.subst_impl(*v, mapping, cache) { if map2.is_none() { map2 = Some(map.clone()); } diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index 2940c2767..d652ceca4 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -67,7 +67,7 @@ impl Unifier { struct TestEnvironment { pub unifier: Unifier, - type_mapping: HashMap, + pub type_mapping: HashMap, } impl TestEnvironment { @@ -325,6 +325,30 @@ fn test_invalid_unification( assert_eq!(env.unifier.unify(t1, t2), Err(errornous_pair.1.to_string())); } +#[test] +fn test_recursive_subst() { + let mut env = TestEnvironment::new(); + let int = *env.type_mapping.get("int").unwrap(); + let foo_id = *env.type_mapping.get("Foo").unwrap(); + let foo_ty = env.unifier.get_ty(foo_id); + let mapping: HashMap<_, _>; + if let TypeEnum::TObj { fields, params, .. } = &*foo_ty { + fields.borrow_mut().insert("rec".into(), foo_id); + mapping = params.borrow().iter().map(|(id, _)| (*id, int)).collect(); + } else { + unreachable!() + } + let instantiated = env.unifier.subst(foo_id, &mapping).unwrap(); + let instantiated_ty = env.unifier.get_ty(instantiated); + if let TypeEnum::TObj { fields, .. } = &*instantiated_ty { + let fields = fields.borrow(); + assert!(env.unifier.unioned(*fields.get("a").unwrap(), int)); + assert!(env.unifier.unioned(*fields.get("rec").unwrap(), instantiated)); + } else { + unreachable!() + } +} + #[test] fn test_virtual() { let mut env = TestEnvironment::new();