diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 33cda9db..25a2251a 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -204,7 +204,7 @@ impl Unifier { } for v1 in old_range2.iter() { for v2 in range1.iter() { - if let Ok(result) = self.shape_match(*v1, *v2){ + if let Ok(result) = self.get_intersection(*v1, *v2){ range2.push(result.unwrap_or(*v2)); } } @@ -219,8 +219,9 @@ impl Unifier { } (TVar { meta: Generic, id, range, .. }, _) => { self.occur_check(a, b)?; - self.check_var_compatible(*id, b, &range.borrow())?; - self.set_a_to_b(a, b); + let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); + self.unify(x, b)?; + self.set_a_to_b(a, x); } (TVar { meta: Sequence(map), id, range, .. }, TTuple { ty }) => { self.occur_check(a, b)?; @@ -236,16 +237,18 @@ impl Unifier { } self.unify(*v, ty[ind as usize])?; } - self.check_var_compatible(*id, b, &range.borrow())?; - self.set_a_to_b(a, b); + let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); + self.unify(x, b)?; + self.set_a_to_b(a, x); } (TVar { meta: Sequence(map), id, range, .. }, TList { ty }) => { self.occur_check(a, b)?; for v in map.borrow().values() { self.unify(*v, *ty)?; } - self.check_var_compatible(*id, b, &range.borrow())?; - self.set_a_to_b(a, b); + let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); + self.unify(x, b)?; + self.set_a_to_b(a, x); } (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { if ty1.len() != ty2.len() { @@ -273,8 +276,9 @@ impl Unifier { return Err(format!("No such attribute {}", k)); } } - self.check_var_compatible(*id, b, &range.borrow())?; - self.set_a_to_b(a, b); + let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); + self.unify(x, b)?; + self.set_a_to_b(a, x); } (TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => { self.occur_check(a, b)?; @@ -294,8 +298,9 @@ impl Unifier { // require annotation... return Err("Requires type annotation for virtual".to_string()); } - self.check_var_compatible(*id, b, &range.borrow())?; - self.set_a_to_b(a, b); + let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b); + self.unify(x, b)?; + self.set_a_to_b(a, x); } ( TObj { obj_id: id1, params: params1, .. }, @@ -647,7 +652,7 @@ impl Unifier { Ok(()) } - fn shape_match(&mut self, a: Type, b: Type) -> Result, ()> { + fn get_intersection(&mut self, a: Type, b: Type) -> Result, ()> { use TypeEnum::*; let x = self.get_ty(a); let y = self.get_ty(b); @@ -665,7 +670,7 @@ impl Unifier { } for v1 in range2.iter() { for v2 in range1.iter() { - let result = self.shape_match(*v1, *v2); + let result = self.get_intersection(*v1, *v2); if let Ok(result) = result { range.push(result.unwrap_or(*v2)); } @@ -690,7 +695,7 @@ impl Unifier { Ok(Some(a)) } else { for v in range.iter() { - let result = self.shape_match(a, *v); + let result = self.get_intersection(a, *v); if let Ok(result) = result { return Ok(result.or(Some(a))); } @@ -699,7 +704,7 @@ impl Unifier { } } (TVar { id, range, .. }, _) => { - self.check_var_compatible(*id, b, &range.borrow()).or(Err(())) + self.check_var_compatibility(*id, b, &range.borrow()).or(Err(())) } (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { if ty1.len() != ty2.len() { @@ -708,7 +713,7 @@ impl Unifier { let mut need_new = false; let mut ty = ty1.clone(); for (a, b) in zip(ty1.iter(), ty2.iter()) { - let result = self.shape_match(*a, *b)?; + let result = self.get_intersection(*a, *b)?; ty.push(result.unwrap_or(*a)); if result.is_some() { need_new = true; @@ -721,10 +726,10 @@ impl Unifier { } } (TList { ty: ty1 }, TList { ty: ty2 }) => { - Ok(self.shape_match(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty }))) + Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty }))) } (TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => { - Ok(self.shape_match(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty }))) + Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty }))) } (TObj { obj_id: id1, .. }, TObj { obj_id: id2, .. }) => { if id1 == id2 { @@ -738,7 +743,7 @@ impl Unifier { } } - fn check_var_compatible( + fn check_var_compatibility( &mut self, id: u32, b: Type, @@ -748,7 +753,7 @@ impl Unifier { return Ok(None); } for t in range.iter() { - let result = self.shape_match(*t, b); + let result = self.get_intersection(*t, b); if let Ok(result) = result { return Ok(result); } diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index 993e7333..be7401f0 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -439,4 +439,15 @@ fn test_typevar_range() { env.unifier.unify(a_list, int_list), Err("Cannot unify type variable 19 with TObj due to incompatible value range".into()) ); + + let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; + let b = env.unifier.get_fresh_var().0; + let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a}); + let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0; + let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b}); + env.unifier.unify(a_list, b_list).unwrap(); + assert_eq!( + env.unifier.unify(b, boolean), + Err("Cannot unify type variable 21 with TObj due to incompatible value range".into()) + ); }