From d94f25583bdcb4f85b6bae5ccaab7a4cdaba2c5b Mon Sep 17 00:00:00 2001 From: pca006132 Date: Thu, 15 Jul 2021 16:00:23 +0800 Subject: [PATCH] added tests --- Cargo.lock | 15 ++ nac3core/Cargo.toml | 4 + nac3core/src/typecheck/mod.rs | 1 + nac3core/src/typecheck/test_typedef.rs | 273 +++++++++++++++++++++++++ nac3core/src/typecheck/typedef.rs | 251 +++++++++++++---------- 5 files changed, 439 insertions(+), 105 deletions(-) create mode 100644 nac3core/src/typecheck/test_typedef.rs diff --git a/Cargo.lock b/Cargo.lock index c09cc2c62..afc45460f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -403,9 +403,11 @@ dependencies = [ "generational-arena", "indoc 1.0.3", "inkwell", + "itertools", "num-bigint 0.3.2", "num-traits", "rustpython-parser", + "test-case", ] [[package]] @@ -844,6 +846,19 @@ dependencies = [ "winapi", ] +[[package]] +name = "test-case" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b114ece25254e97bf48dd4bfc2a12bad0647adacfe4cae1247a9ca6ad302cec" +dependencies = [ + "cfg-if 1.0.0", + "proc-macro2", + "quote", + "syn", + "version_check", +] + [[package]] name = "tiny-keccak" version = "2.0.2" diff --git a/nac3core/Cargo.toml b/nac3core/Cargo.toml index 79ad09978..d3ae77446 100644 --- a/nac3core/Cargo.toml +++ b/nac3core/Cargo.toml @@ -13,3 +13,7 @@ indoc = "1.0" generational-arena = "0.2" ena = "0.14" +[dev-dependencies] +test-case = "1.2.0" +itertools = "0.10.1" + diff --git a/nac3core/src/typecheck/mod.rs b/nac3core/src/typecheck/mod.rs index 118a79ab6..7b30426ee 100644 --- a/nac3core/src/typecheck/mod.rs +++ b/nac3core/src/typecheck/mod.rs @@ -4,4 +4,5 @@ // mod magic_methods; // mod primitives; // pub mod symbol_resolver; +mod test_typedef; pub mod typedef; diff --git a/nac3core/src/typecheck/test_typedef.rs b/nac3core/src/typecheck/test_typedef.rs new file mode 100644 index 000000000..4eeb1cd80 --- /dev/null +++ b/nac3core/src/typecheck/test_typedef.rs @@ -0,0 +1,273 @@ +#[cfg(test)] +mod test { + use super::super::typedef::*; + use itertools::Itertools; + use std::collections::HashMap; + use test_case::test_case; + + struct TestEnvironment { + pub unifier: Unifier, + type_mapping: HashMap, + var_max_id: u32, + } + + impl TestEnvironment { + fn new() -> TestEnvironment { + let unifier = Unifier::new(); + let mut type_mapping = HashMap::new(); + let mut var_max_id = 0; + + type_mapping.insert( + "int".into(), + unifier.add_ty(TypeEnum::TObj { + obj_id: 0, + fields: HashMap::new(), + params: HashMap::new(), + }), + ); + type_mapping.insert( + "float".into(), + unifier.add_ty(TypeEnum::TObj { + obj_id: 1, + fields: HashMap::new(), + params: HashMap::new(), + }), + ); + type_mapping.insert( + "bool".into(), + unifier.add_ty(TypeEnum::TObj { + obj_id: 2, + fields: HashMap::new(), + params: HashMap::new(), + }), + ); + let v0 = unifier.add_ty(TypeEnum::TVar { id: 0 }); + var_max_id += 1; + type_mapping.insert( + "Foo".into(), + unifier.add_ty(TypeEnum::TObj { + obj_id: 3, + fields: [("a".into(), v0)].iter().cloned().collect(), + params: [(0u32, v0)].iter().cloned().collect(), + }), + ); + + TestEnvironment { + unifier, + type_mapping, + var_max_id, + } + } + + fn get_fresh_var(&mut self) -> Type { + let id = self.var_max_id + 1; + self.var_max_id += 1; + self.unifier.add_ty(TypeEnum::TVar { id }) + } + + fn parse(&self, typ: &str, mapping: &Mapping) -> Type { + let result = self.internal_parse(typ, mapping); + assert!(result.1.is_empty()); + result.0 + } + + fn internal_parse<'a, 'b>( + &'a self, + typ: &'b str, + mapping: &Mapping, + ) -> (Type, &'b str) { + // for testing only, so we can just panic when the input is malformed + let end = typ + .find(|c| ['[', ',', ']', '='].contains(&c)) + .unwrap_or_else(|| typ.len()); + match &typ[..end] { + "Tuple" => { + let mut s = &typ[end..]; + assert!(&s[0..1] == "["); + let mut ty = Vec::new(); + while &s[0..1] != "]" { + let result = self.internal_parse(&s[1..], mapping); + ty.push(result.0); + s = result.1; + } + (self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..]) + } + "List" => { + assert!(&typ[end..end + 1] == "["); + let (ty, s) = self.internal_parse(&typ[end + 1..], mapping); + assert!(&s[0..1] == "]"); + (self.unifier.add_ty(TypeEnum::TList { ty }), &s[1..]) + } + "Record" => { + let mut s = &typ[end..]; + assert!(&s[0..1] == "["); + let mut fields = HashMap::new(); + while &s[0..1] != "]" { + let eq = s.find('=').unwrap(); + let key = s[1..eq].to_string(); + let result = self.internal_parse(&s[eq + 1..], mapping); + fields.insert(key, result.0); + s = result.1; + } + (self.unifier.add_ty(TypeEnum::TRecord { fields }), &s[1..]) + } + x => { + let mut s = &typ[end..]; + let ty = mapping.get(x).cloned().unwrap_or_else(|| { + // mapping should be type variables, type_mapping should be concrete types + // we should not resolve the type of type variables. + let mut ty = *self.type_mapping.get(x).unwrap(); + let te = self.unifier.get_ty(ty); + if let TypeEnum::TObj { params, .. } = &*te.as_ref().borrow() { + if !params.is_empty() { + assert!(&s[0..1] == "["); + let mut p = Vec::new(); + while &s[0..1] != "]" { + let result = self.internal_parse(&s[1..], mapping); + p.push(result.0); + s = result.1; + } + s = &s[1..]; + ty = self + .unifier + .subst(ty, ¶ms.keys().cloned().zip(p.into_iter()).collect()) + .unwrap_or(ty); + } + } + ty + }); + (ty, s) + } + } + } + } + + #[test_case(2, + &[("v1", "v2"), ("v2", "float")], + &[("v1", "float"), ("v2", "float")] + ; "simple variable" + )] + #[test_case(2, + &[("v1", "List[v2]"), ("v1", "List[float]")], + &[("v1", "List[float]"), ("v2", "float")] + ; "list element" + )] + #[test_case(3, + &[ + ("v1", "Record[a=v3,b=v3]"), + ("v2", "Record[b=float,c=v3]"), + ("v1", "v2") + ], + &[ + ("v1", "Record[a=float,b=float,c=float]"), + ("v2", "Record[a=float,b=float,c=float]"), + ("v3", "float") + ] + ; "record merge" + )] + #[test_case(3, + &[ + ("v1", "Record[a=float]"), + ("v2", "Foo[v3]"), + ("v1", "v2") + ], + &[ + ("v1", "Foo[float]"), + ("v3", "float") + ] + ; "record obj merge" + )] + fn test_unify( + variable_count: u32, + unify_pairs: &[(&'static str, &'static str)], + verify_pairs: &[(&'static str, &'static str)], + ) { + let unify_count = unify_pairs.len(); + // test all permutations... + for perm in unify_pairs.iter().permutations(unify_count) { + let mut env = TestEnvironment::new(); + let mut mapping = HashMap::new(); + for i in 1..=variable_count { + let v = env.get_fresh_var(); + mapping.insert(format!("v{}", i), v); + } + // unification may have side effect when we do type resolution, so freeze the types + // before doing unification. + let mut pairs = Vec::new(); + for (a, b) in perm.iter() { + let t1 = env.parse(a, &mapping); + let t2 = env.parse(b, &mapping); + pairs.push((t1, t2)); + } + for (t1, t2) in pairs { + env.unifier.unify(t1, t2).unwrap(); + } + for (a, b) in verify_pairs.iter() { + let t1 = env.parse(a, &mapping); + let t2 = env.parse(b, &mapping); + assert!(env.unifier.eq(t1, t2)); + } + } + } + + #[test_case(2, + &[ + ("v1", "Tuple[int]"), + ("v2", "List[int]"), + ], + (("v1", "v2"), "Cannot unify TTuple with TList") + ; "kind mismatch" + )] + #[test_case(2, + &[ + ("v1", "Tuple[int]"), + ("v2", "Tuple[float]"), + ], + (("v1", "v2"), "Cannot unify objects with ID 0 and 1") + ; "tuple parameter mismatch" + )] + #[test_case(2, + &[ + ("v1", "Tuple[int,int]"), + ("v2", "Tuple[int]"), + ], + (("v1", "v2"), "Cannot unify tuples with length 1 and 2") + ; "tuple length mismatch" + )] + #[test_case(3, + &[ + ("v1", "Record[a=float,b=int]"), + ("v2", "Foo[v3]"), + ], + (("v1", "v2"), "No such attribute b") + ; "record obj merge" + )] + fn test_invalid_unification( + variable_count: u32, + unify_pairs: &[(&'static str, &'static str)], + errornous_pair: ((&'static str, &'static str), &'static str), + ) { + let mut env = TestEnvironment::new(); + let mut mapping = HashMap::new(); + for i in 1..=variable_count { + let v = env.get_fresh_var(); + mapping.insert(format!("v{}", i), v); + } + // unification may have side effect when we do type resolution, so freeze the types + // before doing unification. + let mut pairs = Vec::new(); + for (a, b) in unify_pairs.iter() { + let t1 = env.parse(a, &mapping); + let t2 = env.parse(b, &mapping); + pairs.push((t1, t2)); + } + let (t1, t2) = ( + env.parse(errornous_pair.0 .0, &mapping), + env.parse(errornous_pair.0 .1, &mapping), + ); + for (a, b) in pairs { + env.unifier.unify(a, b).unwrap(); + } + assert_eq!(env.unifier.unify(t1, t2), Err(errornous_pair.1.to_string())); + } +} diff --git a/nac3core/src/typecheck/typedef.rs b/nac3core/src/typecheck/typedef.rs index 1fb5b932f..03b3ae280 100644 --- a/nac3core/src/typecheck/typedef.rs +++ b/nac3core/src/typecheck/typedef.rs @@ -1,7 +1,7 @@ use ena::unify::{InPlaceUnificationTable, NoError, UnifyKey, UnifyValue}; use generational_arena::{Arena, Index}; use std::cell::RefCell; -use std::collections::BTreeMap; +use std::collections::HashMap; use std::iter::once; use std::mem::swap; use std::rc::Rc; @@ -18,10 +18,10 @@ use std::rc::Rc; // `--> TFunc #[derive(Copy, Clone, PartialEq, Eq, Debug)] -struct Type(u32); +pub struct Type(u32); #[derive(Copy, Clone, Debug, PartialEq, Eq)] -struct TypeIndex(Index); +pub struct TypeIndex(Index); impl UnifyValue for TypeIndex { type Error = NoError; @@ -48,19 +48,19 @@ impl UnifyKey for Type { } } -type Mapping = BTreeMap; -type VarMap = Mapping; +pub type Mapping = HashMap; +pub type VarMap = Mapping; #[derive(Clone)] -struct Call { +pub struct Call { posargs: Vec, - kwargs: BTreeMap, + kwargs: HashMap, ret: Type, - fn_id: usize, + fun: RefCell>, } #[derive(Clone)] -struct FuncArg { +pub struct FuncArg { name: String, ty: Type, is_optional: bool, @@ -69,7 +69,7 @@ struct FuncArg { // We use a lot of `Rc`/`RefCell`s here as we want to simplify our code. // We may not really need so much `Rc`s, but we would have to do complicated // stuffs otherwise. -enum TypeEnum { +pub enum TypeEnum { TVar { // TODO: upper/lower bound id: u32, @@ -95,7 +95,7 @@ enum TypeEnum { ty: Type, }, TCall { - calls: Vec, + calls: Vec>, }, TFunc { args: Vec, @@ -143,19 +143,40 @@ impl TypeEnum { } } -struct ObjDef { +pub struct ObjDef { name: String, fields: Mapping, } -struct Unifier { +pub struct Unifier { unification_table: RefCell>, type_arena: RefCell>>>, obj_def_table: Vec, } impl Unifier { - fn unify(&self, mut a: Type, mut b: Type) -> Result<(), String> { + pub fn new() -> Unifier { + Unifier { + unification_table: RefCell::new(InPlaceUnificationTable::new()), + type_arena: RefCell::new(Arena::new()), + obj_def_table: Vec::new(), + } + } + + pub fn add_ty(&self, a: TypeEnum) -> Type { + let index = self.type_arena.borrow_mut().insert(Rc::new(a.into())); + self.unification_table + .borrow_mut() + .new_key(TypeIndex(index)) + } + + pub fn get_ty(&self, a: Type) -> Rc> { + let mut table = self.unification_table.borrow_mut(); + let arena = self.type_arena.borrow(); + arena.get(table.probe_value(a).0).unwrap().clone() + } + + pub fn unify(&self, mut a: Type, mut b: Type) -> Result<(), String> { let (mut i_a, mut i_b) = { let mut table = self.unification_table.borrow_mut(); (table.probe_value(a), table.probe_value(b)) @@ -186,38 +207,21 @@ impl Unifier { self.occur_check(i_a, b)?; match &*ty_a { TypeEnum::TVar { .. } => { - match *ty_b { - TypeEnum::TVar { .. } => { - // TODO: type variables bound check - let old = { - let mut table = self.unification_table.borrow_mut(); - table.union(a, b); - if table.find(a) == a { - i_b - } else { - i_a - } - }; - self.type_arena.borrow_mut().remove(old.0); - } - _ => { - // TODO: type variables bound check - self.set_a_to_b(a, b); - } - } + // TODO: type variables bound check... + self.set_a_to_b(a, b); } TypeEnum::TSeq { map: map1 } => { match &*ty_b { - TypeEnum::TSeq { map: map2 } => { - drop(ty_a); - if let TypeEnum::TSeq { map: map1 } = &mut *ty_a_cell.as_ref().borrow_mut() + TypeEnum::TSeq { .. } => { + drop(ty_b); + if let TypeEnum::TSeq { map: map2 } = &mut *ty_b_cell.as_ref().borrow_mut() { // unify them to map1 - for (key, value) in map2.iter() { - if let Some(ty) = map1.get(key) { + for (key, value) in map1.iter() { + if let Some(ty) = map2.get(key) { self.unify(*ty, *value)?; } else { - map1.insert(*key, *value); + map2.insert(*key, *value); } } } else { @@ -277,16 +281,16 @@ impl Unifier { } TypeEnum::TRecord { fields: fields1 } => { match &*ty_b { - TypeEnum::TRecord { fields: fields2 } => { - drop(ty_a); - if let TypeEnum::TRecord { fields: fields1 } = - &mut *ty_a_cell.as_ref().borrow_mut() + TypeEnum::TRecord { .. } => { + drop(ty_b); + if let TypeEnum::TRecord { fields: fields2 } = + &mut *ty_b_cell.as_ref().borrow_mut() { - for (key, value) in fields2.iter() { - if let Some(ty) = fields1.get(key) { + for (key, value) in fields1.iter() { + if let Some(ty) = fields2.get(key) { self.unify(*ty, *value)?; } else { - fields1.insert(key.clone(), *value); + fields2.insert(key.clone(), *value); } } } else { @@ -341,6 +345,7 @@ impl Unifier { TypeEnum::TVirtual { ty: ty1 } => { if let TypeEnum::TVirtual { ty: ty2 } = &*ty_b { self.unify(*ty1, *ty2)?; + self.set_a_to_b(a, b); } else { return self.report_kind_error(&*ty_a, &*ty_b); } @@ -427,7 +432,7 @@ impl Unifier { Ok(()) } - fn subst(&self, a: Type, mapping: &VarMap) -> Option { + pub fn subst(&self, a: Type, mapping: &VarMap) -> Option { let index = self.unification_table.borrow_mut().probe_value(a); let ty_cell = { let arena = self.type_arena.borrow(); @@ -459,34 +464,14 @@ impl Unifier { new_ty.as_mut().unwrap()[i] = t1; } } - new_ty.map(|t| { - let index = self - .type_arena - .borrow_mut() - .insert(Rc::new(TypeEnum::TTuple { ty: t }.into())); - self.unification_table - .borrow_mut() - .new_key(TypeIndex(index)) - }) + new_ty.map(|t| self.add_ty(TypeEnum::TTuple { ty: t })) } - TypeEnum::TList { ty } => self.subst(*ty, mapping).map(|t| { - let index = self - .type_arena - .borrow_mut() - .insert(Rc::new(TypeEnum::TList { ty: t }.into())); - self.unification_table - .borrow_mut() - .new_key(TypeIndex(index)) - }), - TypeEnum::TVirtual { ty } => self.subst(*ty, mapping).map(|t| { - let index = self - .type_arena - .borrow_mut() - .insert(Rc::new(TypeEnum::TVirtual { ty: t }.into())); - self.unification_table - .borrow_mut() - .new_key(TypeIndex(index)) - }), + TypeEnum::TList { ty } => self + .subst(*ty, mapping) + .map(|t| self.add_ty(TypeEnum::TList { ty: t })), + TypeEnum::TVirtual { ty } => self + .subst(*ty, mapping) + .map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })), TypeEnum::TObj { obj_id, fields, @@ -508,23 +493,18 @@ impl Unifier { } }); if need_subst { - let index = self.type_arena.borrow_mut().insert(Rc::new( - TypeEnum::TObj { - obj_id: *obj_id, - params: self - .subst_map(¶ms, mapping) - .unwrap_or_else(|| params.clone()), - fields: self - .subst_map(&fields, mapping) - .unwrap_or_else(|| fields.clone()), - } - .into(), - )); - Some( - self.unification_table - .borrow_mut() - .new_key(TypeIndex(index)), - ) + let obj_id = *obj_id; + let params = self + .subst_map(¶ms, mapping) + .unwrap_or_else(|| params.clone()); + let fields = self + .subst_map(&fields, mapping) + .unwrap_or_else(|| fields.clone()); + Some(self.add_ty(TypeEnum::TObj { + obj_id, + params, + fields, + })) } else { None } @@ -546,19 +526,10 @@ impl Unifier { } } if new_params.is_some() || new_ret.is_some() || new_args.is_some() { - let index = self.type_arena.borrow_mut().insert(Rc::new( - TypeEnum::TFunc { - params: new_params.unwrap_or_else(|| params.clone()), - ret: new_ret.unwrap_or_else(|| *ret), - args: new_args.unwrap_or_else(|| args.clone()), - } - .into(), - )); - Some( - self.unification_table - .borrow_mut() - .new_key(TypeIndex(index)), - ) + let params = new_params.unwrap_or_else(|| params.clone()); + let ret = new_ret.unwrap_or_else(|| *ret); + let args = new_args.unwrap_or_else(|| args.clone()); + Some(self.add_ty(TypeEnum::TFunc { params, ret, args })) } else { None } @@ -569,7 +540,7 @@ impl Unifier { fn subst_map(&self, map: &Mapping, mapping: &VarMap) -> Option> where - K: std::cmp::Ord + std::clone::Clone, + K: std::hash::Hash + std::cmp::Eq + std::clone::Clone, { let mut map2 = None; for (k, v) in map.iter() { @@ -582,4 +553,74 @@ impl Unifier { } map2 } + + pub fn eq(&self, a: Type, b: Type) -> bool { + if a == b { + return true; + } + let (i_a, i_b) = { + let mut table = self.unification_table.borrow_mut(); + (table.probe_value(a), table.probe_value(b)) + }; + + if i_a == i_b { + return true; + } + + let (ty_a, ty_b) = { + let arena = self.type_arena.borrow(); + ( + arena.get(i_a.0).unwrap().clone(), + arena.get(i_b.0).unwrap().clone(), + ) + }; + + let ty_a = ty_a.borrow(); + let ty_b = ty_b.borrow(); + + match (&*ty_a, &*ty_b) { + (TypeEnum::TVar { id: id1 }, TypeEnum::TVar { id: id2 }) => id1 == id2, + (TypeEnum::TSeq { map: map1 }, TypeEnum::TSeq { map: map2 }) => self.map_eq(map1, map2), + (TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => { + ty1.len() == ty2.len() + && ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2)) + } + (TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 }) + | (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => { + self.eq(*ty1, *ty2) + } + (TypeEnum::TRecord { fields: fields1 }, TypeEnum::TRecord { fields: fields2 }) => { + self.map_eq(fields1, fields2) + } + ( + TypeEnum::TObj { + obj_id: id1, + params: params1, + .. + }, + TypeEnum::TObj { + obj_id: id2, + params: params2, + .. + }, + ) => id1 == id2 && self.map_eq(params1, params2), + // TCall and TFunc are not yet implemented + _ => false, + } + } + + fn map_eq(&self, map1: &Mapping, map2: &Mapping) -> bool + where + K: std::hash::Hash + std::cmp::Eq + std::clone::Clone, + { + if map1.len() != map2.len() { + return false; + } + for (k, v) in map1.iter() { + if !map2.get(k).map(|v1| self.eq(*v, *v1)).unwrap_or(false) { + return false; + } + } + true + } }