diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index ac1db79..0164641 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -207,7 +207,6 @@ impl Unifier { if !self.shape_match(*v1, *v2) { continue; } - self.unify(*v1, *v2)?; range2.push(*v2); } } @@ -221,7 +220,7 @@ impl Unifier { } (TVar { meta: Generic, id, range, .. }, _) => { self.occur_check(a, b)?; - self.check_var_range(*id, b, &range.borrow())?; + self.check_var_compatible(*id, b, &range.borrow())?; self.set_a_to_b(a, b); } (TVar { meta: Sequence(map), id, range, .. }, TTuple { ty }) => { @@ -238,7 +237,7 @@ impl Unifier { } self.unify(*v, ty[ind as usize])?; } - self.check_var_range(*id, b, &range.borrow())?; + self.check_var_compatible(*id, b, &range.borrow())?; self.set_a_to_b(a, b); } (TVar { meta: Sequence(map), id, range, .. }, TList { ty }) => { @@ -246,7 +245,7 @@ impl Unifier { for v in map.borrow().values() { self.unify(*v, *ty)?; } - self.check_var_range(*id, b, &range.borrow())?; + self.check_var_compatible(*id, b, &range.borrow())?; self.set_a_to_b(a, b); } (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { @@ -275,7 +274,7 @@ impl Unifier { return Err(format!("No such attribute {}", k)); } } - self.check_var_range(*id, b, &range.borrow())?; + self.check_var_compatible(*id, b, &range.borrow())?; self.set_a_to_b(a, b); } (TVar { meta: Record(map), id, range, .. }, TVirtual { ty }) => { @@ -288,14 +287,16 @@ impl Unifier { return Err(format!("Cannot access field {} for virtual type", k)); } self.unify(*v, *ty)?; + } else { + return Err(format!("No such attribute {}", k)); } } } else { // require annotation... return Err("Requires type annotation for virtual".to_string()); } - self.check_var_range(*id, b, &range.borrow())?; - self.unify(a, b)?; + self.check_var_compatible(*id, b, &range.borrow())?; + self.set_a_to_b(a, b); } ( TObj { obj_id: id1, params: params1, .. }, @@ -457,24 +458,6 @@ impl Unifier { } } - fn check_var_range(&mut self, id: u32, b: Type, range: &[Type]) -> Result<(), String> { - let mut in_range = range.is_empty(); - for t in range.iter() { - if self.shape_match(*t, b) { - self.unify(*t, b)?; - in_range = true; - } - } - if !in_range { - return Err(format!( - "Cannot unify {} with {} due to incompatible value range", - id, - self.get_ty(b).get_type_name() - )); - } - Ok(()) - } - fn set_a_to_b(&mut self, a: Type, b: Type) { // unify a and b together, and set the value to b's value. let table = &mut self.unification_table; @@ -665,13 +648,13 @@ impl Unifier { Ok(()) } - pub fn shape_match(&mut self, a: Type, b: Type) -> bool { + fn shape_match(&mut self, a: Type, b: Type) -> bool { use TypeEnum::*; - let a = self.get_ty(a); - let b = self.get_ty(b); - match (a.as_ref(), b.as_ref()) { - (TVar { .. }, _) => true, - (_, TVar { .. }) => true, + let x = self.get_ty(a); + let y = self.get_ty(b); + match (x.as_ref(), y.as_ref()) { + (TVar { id, range, .. }, _) => self.check_var_compatible(*id, b, &range.borrow()).is_ok(), + (_, TVar { id, range, .. }) => self.check_var_compatible(*id, a, &range.borrow()).is_ok(), (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { ty1.len() == ty2.len() && zip(ty1.iter(), ty2.iter()).all(|(a, b)| self.shape_match(*a, *b)) @@ -683,4 +666,21 @@ impl Unifier { _ => false, } } + + fn check_var_compatible(&mut self, id: u32, b: Type, range: &[Type]) -> Result<(), String> { + let mut in_range = range.is_empty(); + for t in range.iter() { + if self.shape_match(*t, b) { + in_range = true; + } + } + if !in_range { + return Err(format!( + "Cannot unify type variable {} with {} due to incompatible value range", + id, + self.get_ty(b).get_type_name() + )); + } + Ok(()) + } } diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index 8732b1d..f7e0908 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -322,3 +322,85 @@ fn test_invalid_unification( } assert_eq!(env.unifier.unify(t1, t2), Err(errornous_pair.1.to_string())); } + +#[test] +fn test_virtual() { + let mut env = TestEnvironment::new(); + let int = env.parse("int", &HashMap::new()); + let fun = env.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![], + ret: int, + vars: HashMap::new(), + })); + let bar = env.unifier.add_ty(TypeEnum::TObj { + obj_id: 5, + fields: [("f".to_string(), fun), ("a".to_string(), int)].iter().cloned().collect(), + params: HashMap::new(), + }); + let v0 = env.unifier.get_fresh_var().0; + let v1 = env.unifier.get_fresh_var().0; + + let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar }); + let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 }); + let c = env.unifier.add_record([("f".to_string(), v1)].iter().cloned().collect()); + env.unifier.unify(a, b).unwrap(); + env.unifier.unify(b, c).unwrap(); + assert!(env.unifier.eq(v1, fun)); + + let d = env.unifier.add_record([("a".to_string(), v1)].iter().cloned().collect()); + assert_eq!(env.unifier.unify(b, d), Err("Cannot access field a for virtual type".to_string())); + + let d = env.unifier.add_record([("b".to_string(), v1)].iter().cloned().collect()); + assert_eq!(env.unifier.unify(b, d), Err("No such attribute b".to_string())); +} + +#[test] +fn test_typevar_range() { + let mut env = TestEnvironment::new(); + let int = env.parse("int", &HashMap::new()); + let boolean = env.parse("bool", &HashMap::new()); + let float = env.parse("float", &HashMap::new()); + let int_list = env.parse("List[int]", &HashMap::new()); + let float_list = env.parse("List[float]", &HashMap::new()); + + // unification between v and int + // where v in (int, bool) + let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0; + env.unifier.unify(int, v).unwrap(); + + // unification between v and List[int] + // where v in (int, bool) + let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0; + assert_eq!( + env.unifier.unify(int_list, v), + Err("Cannot unify type variable 3 with TList due to incompatible value range".to_string()) + ); + + // unification between v and float + // where v in (int, bool) + let v = env.unifier.get_fresh_var_with_range(&[int, boolean]).0; + assert_eq!( + env.unifier.unify(float, v), + Err("Cannot unify type variable 4 with TObj due to incompatible value range".to_string()) + ); + + let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean]).0; + let v1_list = env.unifier.add_ty(TypeEnum::TList { ty: v1 }); + let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0; + // unification between v and int + // where v in (int, List[v1]), v1 in (int, bool) + env.unifier.unify(int, v).unwrap(); + + let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0; + // unification between v and List[int] + // where v in (int, List[v1]), v1 in (int, bool) + env.unifier.unify(int_list, v).unwrap(); + + let v = env.unifier.get_fresh_var_with_range(&[int, v1_list]).0; + // unification between v and List[float] + // where v in (int, List[v1]), v1 in (int, bool) + assert_eq!( + env.unifier.unify(float_list, v), + Err("Cannot unify type variable 8 with TList due to incompatible value range".to_string()) + ); +}