diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 318615f..f4acb80 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -116,6 +116,80 @@ impl Unifier { } } + // copy concrete type from a type in another unifier + // note that we are constructing a new type this way + // this can handle recursive types + pub fn copy_from( + &mut self, + unifier: &mut Unifier, + ty: Type, + type_cache: &mut HashMap, + ) -> Type { + let representative = self.get_representative(ty); + type_cache.get(&representative).cloned().unwrap_or_else(|| { + // put in a placeholder first to handle possible recursive type + let placeholder = self.get_fresh_var().0; + type_cache.insert(representative, placeholder); + let ty = match &*self.get_ty(ty) { + TypeEnum::TVar { .. } | TypeEnum::TRigidVar { .. } | TypeEnum::TCall(..) => { + unreachable!() + } + TypeEnum::TObj { obj_id, fields, params } => TypeEnum::TObj { + obj_id: *obj_id, + fields: RefCell::new( + fields + .borrow() + .iter() + .map(|(name, ty)| { + (name.clone(), self.copy_from(unifier, *ty, type_cache)) + }) + .collect(), + ), + params: RefCell::new( + params + .borrow() + .iter() + .map(|(id, ty)| (*id, self.copy_from(unifier, *ty, type_cache))) + .collect(), + ), + }, + TypeEnum::TList { ty } => { + TypeEnum::TList { ty: self.copy_from(unifier, *ty, type_cache) } + } + TypeEnum::TFunc(fun) => { + let fun = fun.borrow(); + TypeEnum::TFunc(RefCell::new(FunSignature { + args: fun + .args + .iter() + .map(|arg| FuncArg { + name: arg.name.clone(), + ty: self.copy_from(unifier, arg.ty, type_cache), + default_value: arg.default_value.clone(), + }) + .collect(), + ret: self.copy_from(unifier, fun.ret, type_cache), + vars: fun + .vars + .iter() + .map(|(id, ty)| (*id, self.copy_from(unifier, *ty, type_cache))) + .collect(), + })) + } + TypeEnum::TTuple { ty } => TypeEnum::TTuple { + ty: ty.iter().map(|ty| self.copy_from(unifier, *ty, type_cache)).collect(), + }, + TypeEnum::TVirtual { ty } => { + TypeEnum::TVirtual { ty: self.copy_from(unifier, *ty, type_cache) } + } + }; + let ty = unifier.add_ty(ty); + self.unify(placeholder, ty).unwrap(); + type_cache.insert(representative, ty); + ty + }) + } + /// Determine if the two types are the same pub fn unioned(&mut self, a: Type, b: Type) -> bool { self.unification_table.unioned(a, b)