hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
2 changed files with 36 additions and 20 deletions
Showing only changes of commit bf31c48bba - Show all commits

View File

@ -204,7 +204,7 @@ impl Unifier {
} }
for v1 in old_range2.iter() { for v1 in old_range2.iter() {
for v2 in range1.iter() { for v2 in range1.iter() {
if let Ok(result) = self.shape_match(*v1, *v2){ if let Ok(result) = self.get_intersection(*v1, *v2){
range2.push(result.unwrap_or(*v2)); range2.push(result.unwrap_or(*v2));
} }
} }
@ -219,8 +219,9 @@ impl Unifier {
} }
(TVar { meta: Generic, id, range, .. }, _) => { (TVar { meta: Generic, id, range, .. }, _) => {
self.occur_check(a, b)?; self.occur_check(a, b)?;
self.check_var_compatible(*id, b, &range.borrow())?; let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.set_a_to_b(a, b); self.unify(x, b)?;
self.set_a_to_b(a, x);
} }
(TVar { meta: Sequence(map), id, range, .. }, TTuple { ty }) => { (TVar { meta: Sequence(map), id, range, .. }, TTuple { ty }) => {
self.occur_check(a, b)?; self.occur_check(a, b)?;
@ -236,16 +237,18 @@ impl Unifier {
} }
self.unify(*v, ty[ind as usize])?; self.unify(*v, ty[ind as usize])?;
} }
self.check_var_compatible(*id, b, &range.borrow())?; let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.set_a_to_b(a, b); self.unify(x, b)?;
self.set_a_to_b(a, x);
} }
(TVar { meta: Sequence(map), id, range, .. }, TList { ty }) => { (TVar { meta: Sequence(map), id, range, .. }, TList { ty }) => {
self.occur_check(a, b)?; self.occur_check(a, b)?;
for v in map.borrow().values() { for v in map.borrow().values() {
self.unify(*v, *ty)?; self.unify(*v, *ty)?;
} }
self.check_var_compatible(*id, b, &range.borrow())?; let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.set_a_to_b(a, b); self.unify(x, b)?;
self.set_a_to_b(a, x);
} }
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
if ty1.len() != ty2.len() { if ty1.len() != ty2.len() {
@ -273,8 +276,9 @@ impl Unifier {
return Err(format!("No such attribute {}", k)); return Err(format!("No such attribute {}", k));
} }
} }
self.check_var_compatible(*id, b, &range.borrow())?; let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.set_a_to_b(a, b); self.unify(x, b)?;
self.set_a_to_b(a, x);
} }
(TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => { (TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => {
self.occur_check(a, b)?; self.occur_check(a, b)?;
@ -294,8 +298,9 @@ impl Unifier {
// require annotation... // require annotation...
return Err("Requires type annotation for virtual".to_string()); return Err("Requires type annotation for virtual".to_string());
} }
self.check_var_compatible(*id, b, &range.borrow())?; let x = self.check_var_compatibility(*id, b, &range.borrow())?.unwrap_or(b);
self.set_a_to_b(a, b); self.unify(x, b)?;
self.set_a_to_b(a, x);
} }
( (
TObj { obj_id: id1, params: params1, .. }, TObj { obj_id: id1, params: params1, .. },
@ -647,7 +652,7 @@ impl Unifier {
Ok(()) Ok(())
} }
fn shape_match(&mut self, a: Type, b: Type) -> Result<Option<Type>, ()> { fn get_intersection(&mut self, a: Type, b: Type) -> Result<Option<Type>, ()> {
use TypeEnum::*; use TypeEnum::*;
let x = self.get_ty(a); let x = self.get_ty(a);
let y = self.get_ty(b); let y = self.get_ty(b);
@ -665,7 +670,7 @@ impl Unifier {
} }
for v1 in range2.iter() { for v1 in range2.iter() {
for v2 in range1.iter() { for v2 in range1.iter() {
let result = self.shape_match(*v1, *v2); let result = self.get_intersection(*v1, *v2);
if let Ok(result) = result { if let Ok(result) = result {
range.push(result.unwrap_or(*v2)); range.push(result.unwrap_or(*v2));
} }
@ -690,7 +695,7 @@ impl Unifier {
Ok(Some(a)) Ok(Some(a))
} else { } else {
for v in range.iter() { for v in range.iter() {
let result = self.shape_match(a, *v); let result = self.get_intersection(a, *v);
if let Ok(result) = result { if let Ok(result) = result {
return Ok(result.or(Some(a))); return Ok(result.or(Some(a)));
} }
@ -699,7 +704,7 @@ impl Unifier {
} }
} }
(TVar { id, range, .. }, _) => { (TVar { id, range, .. }, _) => {
self.check_var_compatible(*id, b, &range.borrow()).or(Err(())) self.check_var_compatibility(*id, b, &range.borrow()).or(Err(()))
} }
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
if ty1.len() != ty2.len() { if ty1.len() != ty2.len() {
@ -708,7 +713,7 @@ impl Unifier {
let mut need_new = false; let mut need_new = false;
let mut ty = ty1.clone(); let mut ty = ty1.clone();
for (a, b) in zip(ty1.iter(), ty2.iter()) { for (a, b) in zip(ty1.iter(), ty2.iter()) {
let result = self.shape_match(*a, *b)?; let result = self.get_intersection(*a, *b)?;
ty.push(result.unwrap_or(*a)); ty.push(result.unwrap_or(*a));
if result.is_some() { if result.is_some() {
need_new = true; need_new = true;
@ -721,10 +726,10 @@ impl Unifier {
} }
} }
(TList { ty: ty1 }, TList { ty: ty2 }) => { (TList { ty: ty1 }, TList { ty: ty2 }) => {
Ok(self.shape_match(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty }))) Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty })))
} }
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => { (TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
Ok(self.shape_match(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty }))) Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty })))
} }
(TObj { obj_id: id1, .. }, TObj { obj_id: id2, .. }) => { (TObj { obj_id: id1, .. }, TObj { obj_id: id2, .. }) => {
if id1 == id2 { if id1 == id2 {
@ -738,7 +743,7 @@ impl Unifier {
} }
} }
fn check_var_compatible( fn check_var_compatibility(
&mut self, &mut self,
id: u32, id: u32,
b: Type, b: Type,
@ -748,7 +753,7 @@ impl Unifier {
return Ok(None); return Ok(None);
} }
for t in range.iter() { for t in range.iter() {
let result = self.shape_match(*t, b); let result = self.get_intersection(*t, b);
if let Ok(result) = result { if let Ok(result) = result {
return Ok(result); return Ok(result);
} }

View File

@ -439,4 +439,15 @@ fn test_typevar_range() {
env.unifier.unify(a_list, int_list), env.unifier.unify(a_list, int_list),
Err("Cannot unify type variable 19 with TObj due to incompatible value range".into()) Err("Cannot unify type variable 19 with TObj due to incompatible value range".into())
); );
let a = env.unifier.get_fresh_var_with_range(&[int, float]).0;
let b = env.unifier.get_fresh_var().0;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a});
let a_list = env.unifier.get_fresh_var_with_range(&[a_list]).0;
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b});
env.unifier.unify(a_list, b_list).unwrap();
assert_eq!(
env.unifier.unify(b, boolean),
Err("Cannot unify type variable 21 with TObj due to incompatible value range".into())
);
} }