diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 0164641f..33cda9db 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -204,10 +204,9 @@ impl Unifier { } for v1 in old_range2.iter() { for v2 in range1.iter() { - if !self.shape_match(*v1, *v2) { - continue; + if let Ok(result) = self.shape_match(*v1, *v2){ + range2.push(result.unwrap_or(*v2)); } - range2.push(*v2); } } if range2.is_empty() { @@ -648,39 +647,116 @@ impl Unifier { Ok(()) } - fn shape_match(&mut self, a: Type, b: Type) -> bool { + fn shape_match(&mut self, a: Type, b: Type) -> Result, ()> { use TypeEnum::*; let x = self.get_ty(a); let y = self.get_ty(b); match (x.as_ref(), y.as_ref()) { - (TVar { id, range, .. }, _) => self.check_var_compatible(*id, b, &range.borrow()).is_ok(), - (_, TVar { id, range, .. }) => self.check_var_compatible(*id, a, &range.borrow()).is_ok(), - (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { - ty1.len() == ty2.len() - && zip(ty1.iter(), ty2.iter()).all(|(a, b)| self.shape_match(*a, *b)) + (TVar { range: range1, .. }, TVar { meta, range: range2, .. }) => { + // we should restrict range2 + let range1 = range1.borrow(); + // new range is the intersection of them + // empty range indicates no constraint + if !range1.is_empty() { + let range2 = range2.borrow(); + let mut range = Vec::new(); + if range2.is_empty() { + range.extend_from_slice(&range1); + } + for v1 in range2.iter() { + for v2 in range1.iter() { + let result = self.shape_match(*v1, *v2); + if let Ok(result) = result { + range.push(result.unwrap_or(*v2)); + } + } + } + if range.is_empty() { + Err(()) + } else { + let id = self.var_id + 1; + self.var_id += 1; + let ty = TVar { id, meta: meta.clone(), range: range.into() }; + Ok(Some(self.unification_table.new_key(ty.into()))) + } + } else { + Ok(Some(b)) + } + } + (_, TVar { range, .. }) => { + // range should be restricted to the left hand side + let range = range.borrow(); + if range.is_empty() { + Ok(Some(a)) + } else { + for v in range.iter() { + let result = self.shape_match(a, *v); + if let Ok(result) = result { + return Ok(result.or(Some(a))); + } + } + Err(()) + } + } + (TVar { id, range, .. }, _) => { + self.check_var_compatible(*id, b, &range.borrow()).or(Err(())) + } + (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { + if ty1.len() != ty2.len() { + return Err(()); + } + let mut need_new = false; + let mut ty = ty1.clone(); + for (a, b) in zip(ty1.iter(), ty2.iter()) { + let result = self.shape_match(*a, *b)?; + ty.push(result.unwrap_or(*a)); + if result.is_some() { + need_new = true; + } + } + if need_new { + Ok(Some(self.add_ty(TTuple { ty }))) + } else { + Ok(None) + } + } + (TList { ty: ty1 }, TList { ty: ty2 }) => { + Ok(self.shape_match(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty }))) + } + (TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => { + Ok(self.shape_match(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty }))) + } + (TObj { obj_id: id1, .. }, TObj { obj_id: id2, .. }) => { + if id1 == id2 { + Ok(None) + } else { + Err(()) + } } - (TList { ty: ty1 }, TList { ty: ty2 }) - | (TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => self.shape_match(*ty1, *ty2), - (TObj { obj_id: id1, .. }, TObj { obj_id: id2, .. }) => id1 == id2, // don't deal with function shape for now - _ => false, + _ => Err(()), } } - fn check_var_compatible(&mut self, id: u32, b: Type, range: &[Type]) -> Result<(), String> { - let mut in_range = range.is_empty(); + fn check_var_compatible( + &mut self, + id: u32, + b: Type, + range: &[Type], + ) -> Result, String> { + if range.is_empty() { + return Ok(None); + } for t in range.iter() { - if self.shape_match(*t, b) { - in_range = true; + let result = self.shape_match(*t, b); + if let Ok(result) = result { + return Ok(result); } } - if !in_range { - return Err(format!( - "Cannot unify type variable {} with {} due to incompatible value range", - id, - self.get_ty(b).get_type_name() - )); - } - Ok(()) + return Err(format!( + "Cannot unify type variable {} with {} due to incompatible value range", + id, + self.get_ty(b).get_type_name() + )); } } diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index f7e09084..993e7333 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -403,4 +403,40 @@ fn test_typevar_range() { env.unifier.unify(float_list, v), Err("Cannot unify type variable 8 with TList due to incompatible value range".to_string()) ); + + let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; + let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0; + env.unifier.unify(a, b).unwrap(); + env.unifier.unify(a, float).unwrap(); + + let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; + let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0; + env.unifier.unify(a, b).unwrap(); + assert_eq!( + env.unifier.unify(a, int), + Err("Cannot unify type variable 12 with TObj due to incompatible value range".into()) + ); + + let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; + let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0; + let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a}); + let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0; + let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b}); + let b_list = env.unifier.get_fresh_var_with_range(&[b_list]).0; + env.unifier.unify(a_list, b_list).unwrap(); + let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float}); + env.unifier.unify(a_list, float_list).unwrap(); + // previous unifications should not affect a and b + env.unifier.unify(a, int).unwrap(); + + let a = env.unifier.get_fresh_var_with_range(&[int, float]).0; + let b = env.unifier.get_fresh_var_with_range(&[boolean, float]).0; + let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a}); + let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b}); + env.unifier.unify(a_list, b_list).unwrap(); + let int_list = env.unifier.add_ty(TypeEnum::TList { ty: int}); + assert_eq!( + env.unifier.unify(a_list, int_list), + Err("Cannot unify type variable 19 with TObj due to incompatible value range".into()) + ); }