diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index c97926da..6980f00c 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -47,6 +47,9 @@ pub enum TypeVarMeta { #[derive(Clone)] pub enum TypeEnum { + TRigidVar { + id: u32, + }, TVar { id: u32, meta: TypeVarMeta, @@ -74,6 +77,7 @@ pub enum TypeEnum { impl TypeEnum { pub fn get_type_name(&self) -> &'static str { match self { + TypeEnum::TRigidVar { .. } => "TRigidVar", TypeEnum::TVar { .. } => "TVar", TypeEnum::TTuple { .. } => "TTuple", TypeEnum::TList { .. } => "TList", @@ -127,6 +131,12 @@ impl Unifier { self.unification_table.probe_value(a).clone() } + pub fn get_fresh_rigid_var(&mut self) -> (Type, u32) { + let id = self.var_id + 1; + self.var_id += 1; + (self.add_ty(TypeEnum::TRigidVar { id }), id) + } + pub fn get_fresh_var(&mut self) -> (Type, u32) { self.get_fresh_var_with_range(&[]) } @@ -139,9 +149,17 @@ impl Unifier { (self.add_ty(TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic }), id) } + /// Unification would not unify rigid variables with other types, but we want to do this for + /// function instantiations, so we make it explicit. + pub fn replace_rigid_var(&mut self, rigid: Type, b: Type) { + assert!(matches!(&*self.get_ty(rigid), TypeEnum::TRigidVar { .. })); + self.set_a_to_b(rigid, b); + } + pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool { use TypeEnum::*; match &*self.get_ty(a) { + TRigidVar { .. } => true, TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TCall { .. } => false, TList { ty } => self.is_concrete(*ty, allowed_typevars), @@ -435,6 +453,7 @@ impl Unifier { use TypeVarMeta::*; let ty = self.unification_table.probe_value(ty).clone(); match ty.as_ref() { + TypeEnum::TRigidVar { id } => var_to_name(*id), TypeEnum::TVar { id, meta: Generic, .. } => var_to_name(*id), TypeEnum::TVar { meta: Sequence(map), .. } => { let fields = map @@ -544,6 +563,7 @@ impl Unifier { // variables, i.e. things like TRecord, TCall should not occur, and we // should be safe to not implement the substitution for those variants. match &*ty { + TypeEnum::TRigidVar { .. } => None, TypeEnum::TVar { id, meta: Generic, .. } => mapping.get(&id).cloned(), TypeEnum::TTuple { ty } => { let mut new_ty = Cow::from(ty); @@ -634,7 +654,7 @@ impl Unifier { let ty = self.unification_table.probe_value(b).clone(); match ty.as_ref() { - TypeEnum::TVar { meta: Generic, .. } => {} + TypeEnum::TRigidVar { .. } | TypeEnum::TVar { meta: Generic, .. } => {} TypeEnum::TVar { meta: Sequence(map), .. } => { for t in map.borrow().values() { self.occur_check(a, *t)?; diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index be7401f0..cf0cc9c0 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -419,22 +419,22 @@ fn test_typevar_range() { let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0; - let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a}); + let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0; - let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b}); + let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); let b_list = env.unifier.get_fresh_var_with_range(&[b_list]).0; env.unifier.unify(a_list, b_list).unwrap(); - let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float}); + let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float }); env.unifier.unify(a_list, float_list).unwrap(); // previous unifications should not affect a and b env.unifier.unify(a, int).unwrap(); let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0; - let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a}); - let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b}); + let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); + let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); env.unifier.unify(a_list, b_list).unwrap(); - let int_list = env.unifier.add_ty(TypeEnum::TList { ty: int}); + let int_list = env.unifier.add_ty(TypeEnum::TList { ty: int }); assert_eq!( env.unifier.unify(a_list, int_list), Err("Cannot unify type variable 19 with TObj due to incompatible value range".into()) @@ -442,12 +442,34 @@ fn test_typevar_range() { let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; let b = env.unifier.get_fresh_var().0; - let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a}); + let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0; - let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b}); + let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); env.unifier.unify(a_list, b_list).unwrap(); assert_eq!( env.unifier.unify(b, boolean), Err("Cannot unify type variable 21 with TObj due to incompatible value range".into()) ); } + +#[test] +fn test_rigid_var() { + let mut env = TestEnvironment::new(); + let a = env.unifier.get_fresh_rigid_var().0; + let b = env.unifier.get_fresh_rigid_var().0; + let x = env.unifier.get_fresh_var().0; + let list_a = env.unifier.add_ty(TypeEnum::TList { ty: a }); + let list_x = env.unifier.add_ty(TypeEnum::TList { ty: x }); + let int = env.parse("int", &HashMap::new()); + let list_int = env.parse("List[int]", &HashMap::new()); + + assert_eq!(env.unifier.unify(a, b), Err("Cannot unify TRigidVar with TRigidVar".to_string())); + env.unifier.unify(list_a, list_x).unwrap(); + assert_eq!( + env.unifier.unify(list_x, list_int), + Err("Cannot unify TObj with TRigidVar".to_string()) + ); + + env.unifier.replace_rigid_var(a, int); + env.unifier.unify(list_x, list_int).unwrap(); +}